diff --git a/.gitattributes b/.gitattributes index 1d9109d5d3f4374bb8bbe54d3b39527374d8529f..ad94a41e4af773d90e30d921cdec035d16a142cc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text erling.jpg filter=lfs diff=lfs merge=lfs -text *.jpg filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..a8736fdf09f983f34b0bd35e5d78eb1083f98958 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "deep_privacy2"] + path = deep_privacy2 + url = https://github.com/hukkelas/deep_privacy2 diff --git a/app.py b/app.py index c7ea0d4fcb0831179e845e297f21092a4a9ce3b5..f2e73970b9ce61a7170e08ea50d05e1d6f54e9af 100644 --- a/app.py +++ b/app.py @@ -1,78 +1,37 @@ +import gradio +import sys import os +from pathlib import Path +from tops.config import instantiate +import gradio.inputs os.system("pip install --upgrade pip") os.system("pip install ftfy regex tqdm") -os.system("pip install git+https://github.com/openai/CLIP.git") +os.system("pip install --no-deps git+https://github.com/openai/CLIP.git") os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose") -os.system("pip install git+https://github.com/hukkelas/DSFD-Pytorch-Inference") -import gradio -import numpy as np -import torch -from PIL import Image +os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference") +sys.path.insert(0, Path(os.getcwd(), "deep_privacy2")) +os.environ["TORCH_HOME"] = "torch_home" from dp2 import utils -from tops.config import instantiate -import tops -import gradio.inputs - - -cfg_body = utils.load_config("configs/anonymizers/FB_cse.py") -anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False) -anonymizer_body.initialize_tracker(fps=1) -cfg_face = utils.load_config("configs/anonymizers/face.py") -anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False) -anonymizer_face.initialize_tracker(fps=1) - - -class ExampleDemo: +from gradio_demos.modules import ExampleDemo, WebcamDemo - def __init__(self, anonymizer, multi_modal_truncation=False) -> None: - self.multi_modal_truncation = multi_modal_truncation - self.anonymizer = anonymizer - with gradio.Row(): - input_image = gradio.Image(type="pil", label="Upload your image or try the example below!") - output_image = gradio.Image(type="numpy", label="Output") - with gradio.Row(): - update_btn = gradio.Button("Update Anonymization").style(full_width=True) - visualize_det = gradio.Checkbox(value=False, label="Show Detections") - visualize_det.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image]) - gradio.Examples( - ["erling.jpg", "g7-summit-leaders-distraction.jpg"], inputs=[input_image] - ) - update_btn.click(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image]) - input_image.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image]) - self.track = False +cfg_face = utils.load_config("deep_privacy2/configs/anonymizers/face.py") +for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]: + 
if key in cfg_face.anonymizer: + cfg_face.anonymizer[key] = Path("deep_privacy2", cfg_face.anonymizer[key]) - def anonymize(self, img: Image, visualize_detection: bool): - img, cache_id = pil2torch(img) - img = tops.to_cuda(img) - if visualize_detection: - img = self.anonymizer.visualize_detection(img, cache_id=cache_id) - else: - img = self.anonymizer( - img, truncation_value=0 if self.multi_modal_truncation else 1, multi_modal_truncation=self.multi_modal_truncation, amp=True, - cache_id=cache_id, track=self.track) - img = utils.im2numpy(img) - return img +anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False) -def pil2torch(img: Image.Image): - img = img.convert("RGB") - img = np.array(img) - img = np.rollaxis(img, 2) - return torch.from_numpy(img), None +anonymizer_face.initialize_tracker(fps=1) with gradio.Blocks() as demo: gradio.Markdown("#
DeepPrivacy2 - Realistic Image Anonymization") gradio.Markdown("### Håkon Hukkelås, Rudolf Mester, Frank Lindseth") - gradio.Markdown("DeepPrivacy2 is a toolbox for realistic anonymization of humans, including a face and a full-body anonymizer.") gradio.Markdown("See more information at: https://github.com/hukkelas/deep_privacy2
") + with gradio.Tab("Face Anonymization"): + ExampleDemo(anonymizer_face) + with gradio.Tab("Live Webcam"): + WebcamDemo(anonymizer_face) - - - with gradio.Tab("Full-Body Anonymization"): - ExampleDemo(anonymizer_body, multi_modal_truncation=True) - with gradio.Tab("Face Anonymization"): - ExampleDemo(anonymizer_face, multi_modal_truncation=False) - - -demo.launch() \ No newline at end of file +demo.launch() diff --git a/configs/anonymizers/FB_cse.py b/configs/anonymizers/FB_cse.py deleted file mode 100644 index ff44a8ef9da980d545b09609de82e071edc912ac..0000000000000000000000000000000000000000 --- a/configs/anonymizers/FB_cse.py +++ /dev/null @@ -1,28 +0,0 @@ -from dp2.anonymizer import Anonymizer -from dp2.detection.person_detector import CSEPersonDetector -from ..defaults import common -from tops.config import LazyCall as L -from dp2.generator.dummy_generators import MaskOutGenerator - - -maskout_G = L(MaskOutGenerator)(noise="constant") - -detector = L(CSEPersonDetector)( - mask_rcnn_cfg=dict(), - cse_cfg=dict(), - cse_post_process_cfg=dict( - target_imsize=(288, 160), - exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1), - exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]), - iou_combine_threshold=0.4, - dilation_percentage=0.02, - normalize_embedding=False - ), - score_threshold=0.3, - cache_directory=common.output_dir.joinpath("cse_person_detection_cache") -) - -anonymizer = L(Anonymizer)( - detector="${detector}", - cse_person_G_cfg="configs/fdh/styleganL.py", -) diff --git a/configs/anonymizers/FB_cse_mask.py b/configs/anonymizers/FB_cse_mask.py deleted file mode 100644 index ff5e3bfbefad8e1d6e480fa22256aff0f9647b35..0000000000000000000000000000000000000000 --- a/configs/anonymizers/FB_cse_mask.py +++ /dev/null @@ -1,29 +0,0 @@ -from dp2.anonymizer import Anonymizer -from dp2.detection.person_detector import CSEPersonDetector -from ..defaults import common -from tops.config import LazyCall as L -from dp2.generator.dummy_generators import MaskOutGenerator - - -maskout_G = L(MaskOutGenerator)(noise="constant") - -detector = L(CSEPersonDetector)( - mask_rcnn_cfg=dict(), - cse_cfg=dict(), - cse_post_process_cfg=dict( - target_imsize=(288, 160), - exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1), - exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]), - iou_combine_threshold=0.4, - dilation_percentage=0.02, - normalize_embedding=False - ), - score_threshold=0.3, - cache_directory=common.output_dir.joinpath("cse_person_detection_cache") -) - -anonymizer = L(Anonymizer)( - detector="${detector}", - person_G_cfg="configs/fdh/styleganL_nocse.py", - cse_person_G_cfg="configs/fdh/styleganL.py", -) diff --git a/configs/anonymizers/FB_cse_mask_face.py b/configs/anonymizers/FB_cse_mask_face.py deleted file mode 100644 index 8d6e4eef72111c0d8f427631926a7465a9cb174d..0000000000000000000000000000000000000000 --- a/configs/anonymizers/FB_cse_mask_face.py +++ /dev/null @@ -1,29 +0,0 @@ -from dp2.anonymizer import Anonymizer -from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector -from ..defaults import common -from tops.config import LazyCall as L - -detector = L(CSeMaskFaceDetector)( - mask_rcnn_cfg=dict(), - face_detector_cfg=dict(), - face_post_process_cfg=dict(target_imsize=(256, 256)), - cse_cfg=dict(), - cse_post_process_cfg=dict( - target_imsize=(288, 160), - exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1), - 
exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]), - iou_combine_threshold=0.4, - dilation_percentage=0.02, - normalize_embedding=False - ), - score_threshold=0.3, - cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache") -) - -anonymizer = L(Anonymizer)( - detector="${detector}", - face_G_cfg="configs/fdf/stylegan.py", - person_G_cfg="configs/fdh/styleganL_nocse.py", - cse_person_G_cfg="configs/fdh/styleganL.py", - car_G_cfg="configs/generators/dummy/pixelation8.py" -) diff --git a/configs/anonymizers/face.py b/configs/anonymizers/face.py deleted file mode 100644 index 8c39560aad3c8c3f40592bb58850065994860b46..0000000000000000000000000000000000000000 --- a/configs/anonymizers/face.py +++ /dev/null @@ -1,18 +0,0 @@ -from dp2.anonymizer import Anonymizer -from dp2.detection.face_detector import FaceDetector -from ..defaults import common -from tops.config import LazyCall as L - - -detector = L(FaceDetector)( - face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True), - face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False), - score_threshold=0.3, - cache_directory=common.output_dir.joinpath("face_detection_cache") -) - - -anonymizer = L(Anonymizer)( - detector="${detector}", - face_G_cfg="configs/fdf/stylegan.py", -) diff --git a/configs/anonymizers/market1501/blackout.py b/configs/anonymizers/market1501/blackout.py deleted file mode 100644 index 14da21e3c4b367a942f9a99796a1d9996b773522..0000000000000000000000000000000000000000 --- a/configs/anonymizers/market1501/blackout.py +++ /dev/null @@ -1,8 +0,0 @@ -from ..FB_cse_mask_face import anonymizer, detector, common - -detector.score_threshold = .1 -detector.face_detector_cfg.confidence_threshold = .5 -detector.cse_cfg.score_thres = 0.3 -anonymizer.generators.face_G_cfg = None -anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py" -anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py" \ No newline at end of file diff --git a/configs/anonymizers/market1501/person.py b/configs/anonymizers/market1501/person.py deleted file mode 100644 index 51fa99b21f068ce68f796fd32c85d37d9a22bec1..0000000000000000000000000000000000000000 --- a/configs/anonymizers/market1501/person.py +++ /dev/null @@ -1,6 +0,0 @@ -from ..FB_cse_mask_face import anonymizer, detector, common - -detector.score_threshold = .1 -detector.face_detector_cfg.confidence_threshold = .5 -detector.cse_cfg.score_thres = 0.3 -anonymizer.generators.face_G_cfg = None \ No newline at end of file diff --git a/configs/anonymizers/market1501/pixelation16.py b/configs/anonymizers/market1501/pixelation16.py deleted file mode 100644 index 2569fc2abb91919f91dd12546c06a86624d235fc..0000000000000000000000000000000000000000 --- a/configs/anonymizers/market1501/pixelation16.py +++ /dev/null @@ -1,8 +0,0 @@ -from ..FB_cse_mask_face import anonymizer, detector, common - -detector.score_threshold = .1 -detector.face_detector_cfg.confidence_threshold = .5 -detector.cse_cfg.score_thres = 0.3 -anonymizer.generators.face_G_cfg = None -anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py" -anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py" \ No newline at end of file diff --git a/configs/anonymizers/market1501/pixelation8.py b/configs/anonymizers/market1501/pixelation8.py deleted file mode 100644 index ef49cb613d09e972adf7b8136b632eb210420686..0000000000000000000000000000000000000000 --- 
a/configs/anonymizers/market1501/pixelation8.py +++ /dev/null @@ -1,8 +0,0 @@ -from ..FB_cse_mask_face import anonymizer, detector, common - -detector.score_threshold = .1 -detector.face_detector_cfg.confidence_threshold = .5 -detector.cse_cfg.score_thres = 0.3 -anonymizer.generators.face_G_cfg = None -anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py" -anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py" \ No newline at end of file diff --git a/configs/datasets/coco_cse.py b/configs/datasets/coco_cse.py deleted file mode 100644 index 00582b39cb473c7cad1ec95ce7361ba9c25939b4..0000000000000000000000000000000000000000 --- a/configs/datasets/coco_cse.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -from pathlib import Path -from tops.config import LazyCall as L -import torch -import functools -from dp2.data.datasets import CocoCSE -from dp2.data.build import get_dataloader -from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip -from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe -from dp2.metrics.torch_metrics import compute_metrics_iteratively -from .utils import final_eval_fn - - -dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data" -metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache" -data_dir = Path(dataset_base_dir, "coco_cse") -data = dict( - imsize=(288, 160), - im_channels=3, - semantic_nc=26, - cse_nc=16, - train=dict( - dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False), - loader=L(get_dataloader)( - shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2, - batch_size="${train.batch_size}", - dataset="${..dataset}", - infinite=True, - gpu_transform=L(torch.nn.Sequential)(*[ - L(ToFloat)(), - L(StyleGANAugmentPipe)( - rotate=0.5, rotate_max=.05, - xint=.5, xint_max=0.05, - scale=.5, scale_std=.05, - aniso=0.5, aniso_std=.05, - xfrac=.5, xfrac_std=.05, - brightness=.5, brightness_std=.05, - contrast=.5, contrast_std=.1, - hue=.5, hue_max=.05, - saturation=.5, saturation_std=.5, - imgfilter=.5, imgfilter_std=.1), - L(RandomHorizontalFlip)(p=0.5), - L(CreateEmbedding)(), - L(Resize)(size="${data.imsize}"), - L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), - L(CreateCondition)(), - ]) - ) - ), - val=dict( - dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False), - loader=L(get_dataloader)( - shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2, - batch_size="${train.batch_size}", - dataset="${..dataset}", - infinite=False, - gpu_transform=L(torch.nn.Sequential)(*[ - L(ToFloat)(), - L(CreateEmbedding)(), - L(Resize)(size="${data.imsize}"), - L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), - L(CreateCondition)(), - ]) - ) - ), - # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP. 
- train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False), - evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True) -) diff --git a/configs/datasets/fdf128.py b/configs/datasets/fdf128.py deleted file mode 100644 index 8740ebd4738d7487cc9f1c6fbbcc2a2695240163..0000000000000000000000000000000000000000 --- a/configs/datasets/fdf128.py +++ /dev/null @@ -1,24 +0,0 @@ -from pathlib import Path -from functools import partial -from dp2.data.datasets.fdf import FDFDataset -from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn - -data_dir = Path(dataset_base_dir, "fdf") -data.train.dataset.dirpath = data_dir.joinpath("train") -data.val.dataset.dirpath = data_dir.joinpath("val") -data.imsize = (128, 128) - - -data.train_evaluation_fn = partial( - final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train")) -data.evaluation_fn = partial( - final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final")) - -data.train.dataset.update( - _target_ = FDFDataset, - imsize="${data.imsize}" -) -data.val.dataset.update( - _target_ = FDFDataset, - imsize="${data.imsize}" -) \ No newline at end of file diff --git a/configs/datasets/fdf256.py b/configs/datasets/fdf256.py deleted file mode 100644 index 3767634c3d8f4b2253b5ad3c28bd324a59cde0ae..0000000000000000000000000000000000000000 --- a/configs/datasets/fdf256.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -from pathlib import Path -from tops.config import LazyCall as L -import torch -import functools -from dp2.data.datasets.fdf import FDF256Dataset -from dp2.data.build import get_dataloader -from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip -from dp2.metrics.torch_metrics import compute_metrics_iteratively -from dp2.metrics.fid_clip import compute_fid_clip -from dp2.metrics.ppl import calculate_ppl -from .utils import final_eval_fn - - -def final_eval_fn(*args, **kwargs): - result = compute_metrics_iteratively(*args, **kwargs) - result2 = compute_fid_clip(*args, **kwargs) - assert all(key not in result for key in result2) - result.update(result2) - result3 = calculate_ppl(*args, **kwargs,) - assert all(key not in result for key in result3) - result.update(result3) - return result - - -dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data" -metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache" -data_dir = Path(dataset_base_dir, "fdf256") -data = dict( - imsize=(256, 256), - im_channels=3, - semantic_nc=None, - cse_nc=None, - n_keypoints=None, - train=dict( - dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False), - loader=L(get_dataloader)( - shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2, - batch_size="${train.batch_size}", - dataset="${..dataset}", - infinite=True, - gpu_transform=L(torch.nn.Sequential)(*[ - L(ToFloat)(), - L(RandomHorizontalFlip)(p=0.5), - L(Resize)(size="${data.imsize}"), - L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), - L(CreateCondition)(), - ]) - ) - ), - val=dict( - dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False), - loader=L(get_dataloader)( - shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2, - batch_size="${train.batch_size}", - dataset="${..dataset}", - 
infinite=False, - gpu_transform=L(torch.nn.Sequential)(*[ - L(ToFloat)(), - L(Resize)(size="${data.imsize}"), - L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), - L(CreateCondition)(), - ]) - ) - ), - # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP. - train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "fdf_val_train")), - evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val")) -) \ No newline at end of file diff --git a/configs/datasets/fdh.py b/configs/datasets/fdh.py deleted file mode 100644 index 298687d85302a029af2c44decf8dcb248b024031..0000000000000000000000000000000000000000 --- a/configs/datasets/fdh.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -from pathlib import Path -from tops.config import LazyCall as L -import torch -import functools -from dp2.data.datasets.fdh import get_dataloader_fdh_wds -from dp2.data.utils import get_coco_flipmap -from dp2.data.transforms.transforms import ( - Normalize, - ToFloat, - CreateCondition, - RandomHorizontalFlip, - CreateEmbedding, -) -from dp2.metrics.torch_metrics import compute_metrics_iteratively -from dp2.metrics.fid_clip import compute_fid_clip -from .utils import final_eval_fn - - -def train_eval_fn(*args, **kwargs): - result = compute_metrics_iteratively(*args, **kwargs) - result2 = compute_fid_clip(*args, **kwargs) - assert all(key not in result for key in result2) - result.update(result2) - return result - - -dataset_base_dir = ( - os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data" -) -metrics_cache = ( - os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache" -) -data_dir = Path(dataset_base_dir, "fdh") -data = dict( - imsize=(288, 160), - im_channels=3, - cse_nc=16, - n_keypoints=17, - train=dict( - loader=L(get_dataloader_fdh_wds)( - path=data_dir.joinpath("train", "out-{000000..001423}.tar"), - batch_size="${train.batch_size}", - num_workers=6, - transform=L(torch.nn.Sequential)( - L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()), - ), - gpu_transform=L(torch.nn.Sequential)( - L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]), - L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")), - L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True), - L(CreateCondition)(), - ), - infinite=True, - shuffle=True, - partial_batches=False, - load_embedding=True, - ) - ), - val=dict( - loader=L(get_dataloader_fdh_wds)( - path=data_dir.joinpath("val", "out-{000000..000023}.tar"), - batch_size="${train.batch_size}", - num_workers=6, - transform=None, - gpu_transform=L(torch.nn.Sequential)( - L(ToFloat)(keys=["img", "mask", "E_mask", "maskrcnn_mask"], norm=False), - L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")), - L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True), - L(CreateCondition)(), - ), - infinite=False, - shuffle=False, - partial_batches=True, - load_embedding=True, - ) - ), - # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP. 
- train_evaluation_fn=functools.partial( - train_eval_fn, - cache_directory=Path(metrics_cache, "fdh_v7_train"), - data_len=int(30e3), - ), - evaluation_fn=functools.partial( - final_eval_fn, - cache_directory=Path(metrics_cache, "fdh_v6_val"), - data_len=int(30e3), - ), -) diff --git a/configs/datasets/utils.py b/configs/datasets/utils.py deleted file mode 100644 index ccc72dc4c403ffc83db9e79e6d8ee97d9b181622..0000000000000000000000000000000000000000 --- a/configs/datasets/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from dp2.metrics.ppl import calculate_ppl -from dp2.metrics.torch_metrics import compute_metrics_iteratively -from dp2.metrics.fid_clip import compute_fid_clip - - -def final_eval_fn(*args, **kwargs): - result = compute_metrics_iteratively(*args, **kwargs) - result2 = calculate_ppl(*args, **kwargs,) - result2 = compute_fid_clip(*args, **kwargs) - assert all(key not in result for key in result2) - result.update(result2) - return result diff --git a/configs/defaults.py b/configs/defaults.py deleted file mode 100644 index 9767dfa920d6261b749ce7064d4881a64ac73798..0000000000000000000000000000000000000000 --- a/configs/defaults.py +++ /dev/null @@ -1,45 +0,0 @@ -import pathlib -import os -import torch -from tops.config import LazyCall as L - -if "PRETRAINED_CHECKPOINTS_PATH" in os.environ: - PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"]) -else: - PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints") -if "BASE_OUTPUT_DIR" in os.environ: - BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"]) -else: - BASE_OUTPUT_DIR = pathlib.Path("outputs") - - - -common = dict( - logger_backend=["wandb", "stdout", "json", "image_dumper"], - wandb_project="fba_test", - output_dir=BASE_OUTPUT_DIR, - experiment_name=None, # Optional experiment name to show on wandb -) - -train = dict( - batch_size=32, - seed=0, - ims_per_log=1024, - ims_per_val=int(200e3), - max_images_to_train=int(12e6), - amp=dict( - enabled=True, - scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"), - scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"), - ), - fp16_ddp_accumulate=False, # All gather gradients in fp16? 
- broadcast_buffers=False, - bias_act_plugin_enabled=True, - grid_sample_gradfix_enabled=True, - conv2d_gradfix_enabled=False, - channels_last=False, -) - -# exponential moving average -EMA = dict(rampup=0.05) - diff --git a/configs/discriminators/sg2_discriminator.py b/configs/discriminators/sg2_discriminator.py deleted file mode 100644 index e692450da87e04afe62c150dabe8eea3208b1382..0000000000000000000000000000000000000000 --- a/configs/discriminators/sg2_discriminator.py +++ /dev/null @@ -1,42 +0,0 @@ -from tops.config import LazyCall as L -from dp2.discriminator import SG2Discriminator -import torch -from dp2.loss import StyleGAN2Loss - - -discriminator = L(SG2Discriminator)( - imsize="${data.imsize}", - im_channels="${data.im_channels}", - min_fmap_resolution=4, - max_cnum_mul=8, - cnum=80, - input_condition=True, - conv_clamp=256, - input_cse=False, - cse_nc="${data.cse_nc}" -) - - -loss_fnc = L(StyleGAN2Loss)( - lazy_regularization=True, - lazy_reg_interval=16, - r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False), - EP_lambd=0.001, - pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01) -) - -def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs): - if lazy_regularization: - # From Analyzing and improving the image quality of stylegan, CVPR 2020 - c = lazy_reg_interval / (lazy_reg_interval + 1) - betas = [beta ** c for beta in betas] - lr *= c - print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}") - return type(lr=lr, betas=betas, **kwargs) - - -D_optim = L(build_D_optim)( - type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99), - lazy_regularization="${loss_fnc.lazy_regularization}", - lazy_reg_interval="${loss_fnc.lazy_reg_interval}") -G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99)) diff --git a/configs/fdf/stylegan.py b/configs/fdf/stylegan.py deleted file mode 100644 index a4da2c3ad76d3d1fb6e1d91e832cde5c735bf32a..0000000000000000000000000000000000000000 --- a/configs/fdf/stylegan.py +++ /dev/null @@ -1,14 +0,0 @@ -from ..generators.stylegan_unet import generator -from ..datasets.fdf256 import data -from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc -from ..defaults import train, common, EMA - -train.max_images_to_train = int(35e6) -G_optim.lr = 0.002 -D_optim.lr = 0.002 -generator.input_cse = False -loss_fnc.r1_opts.lambd = 1 -train.ims_per_val = int(2e6) - -common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e" -common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07" \ No newline at end of file diff --git a/configs/fdf/stylegan_fdf128.py b/configs/fdf/stylegan_fdf128.py deleted file mode 100644 index ff5c4065732d155f92a8a551efff546899589a4f..0000000000000000000000000000000000000000 --- a/configs/fdf/stylegan_fdf128.py +++ /dev/null @@ -1,13 +0,0 @@ -from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc -from ..datasets.fdf128 import data -from ..generators.stylegan_unet import generator -from ..defaults import train, common, EMA -from tops.config import LazyCall as L - -train.max_images_to_train = int(25e6) -G_optim.lr = 0.002 -D_optim.lr = 0.002 -generator.cnum = 128 -generator.max_cnum_mul = 4 -generator.input_cse = False -loss_fnc.r1_opts.lambd = .1 diff --git a/configs/fdh/styleganL.py b/configs/fdh/styleganL.py deleted file mode 100644 index 
48fcf09b43a7141a270fbe5c69bd7932414270fe..0000000000000000000000000000000000000000 --- a/configs/fdh/styleganL.py +++ /dev/null @@ -1,16 +0,0 @@ -from tops.config import LazyCall as L -from ..generators.stylegan_unet import generator -from ..datasets.fdh import data -from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc -from ..defaults import train, common, EMA - -train.max_images_to_train = int(50e6) -train.batch_size = 64 -G_optim.lr = 0.002 -D_optim.lr = 0.002 -data.train.loader.num_workers = 4 -train.ims_per_val = int(1e6) -loss_fnc.r1_opts.lambd = .1 - -common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c" -common.model_md5sum = "3411478b5ec600a4219cccf4499732bd" \ No newline at end of file diff --git a/configs/fdh/styleganL_nocse.py b/configs/fdh/styleganL_nocse.py deleted file mode 100644 index 210fd68743f0b872f89f4407dfaac7c9bf5f0e32..0000000000000000000000000000000000000000 --- a/configs/fdh/styleganL_nocse.py +++ /dev/null @@ -1,14 +0,0 @@ -from tops.config import LazyCall as L -from ..generators.stylegan_unet import generator -from ..datasets.fdh import data -from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc -from ..defaults import train, common, EMA - -train.max_images_to_train = int(50e6) -G_optim.lr = 0.002 -D_optim.lr = 0.002 -generator.input_cse = False -data.load_embeddings = False -common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt" -common.model_md5sum = "fda0d809741bc67487abada793975c37" -generator.fix_errors = False \ No newline at end of file diff --git a/configs/generators/stylegan_unet.py b/configs/generators/stylegan_unet.py deleted file mode 100644 index 638859263a1cb549f533b75b2b19609665b3443e..0000000000000000000000000000000000000000 --- a/configs/generators/stylegan_unet.py +++ /dev/null @@ -1,22 +0,0 @@ -from dp2.generator.stylegan_unet import StyleGANUnet -from tops.config import LazyCall as L - -generator = L(StyleGANUnet)( - imsize="${data.imsize}", - im_channels="${data.im_channels}", - min_fmap_resolution=8, - cnum=64, - max_cnum_mul=8, - n_middle_blocks=0, - z_channels=512, - mask_output=True, - conv_clamp=256, - input_cse=True, - scale_grad=True, - cse_nc="${data.cse_nc}", - w_dim=512, - n_keypoints="${data.n_keypoints}", - input_keypoints=False, - input_keypoint_indices=[], - fix_errors=True -) \ No newline at end of file diff --git a/deep_privacy2 b/deep_privacy2 new file mode 160000 index 0000000000000000000000000000000000000000..37dcbeb23a1f51121d53bcd80d32d086d6822b7b --- /dev/null +++ b/deep_privacy2 @@ -0,0 +1 @@ +Subproject commit 37dcbeb23a1f51121d53bcd80d32d086d6822b7b diff --git a/dp2/__init__.py b/dp2/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dp2/anonymizer/__init__.py b/dp2/anonymizer/__init__.py deleted file mode 100644 index 3fb33d7e6ad3b247938dc20ab2311728f286eb14..0000000000000000000000000000000000000000 --- a/dp2/anonymizer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .anonymizer import Anonymizer \ No newline at end of file diff --git a/dp2/anonymizer/anonymizer.py b/dp2/anonymizer/anonymizer.py deleted file mode 100644 index d850384b3fa33b08b5d5b6770b584c5b64adce44..0000000000000000000000000000000000000000 --- a/dp2/anonymizer/anonymizer.py +++ /dev/null @@ 
-1,159 +0,0 @@ -from pathlib import Path -from typing import Union, Optional -import numpy as np -import torch -import tops -import torchvision.transforms.functional as F -from motpy import Detection, MultiObjectTracker -from dp2.utils import load_config -from dp2.infer import build_trained_generator -from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection - - -def load_generator_from_cfg_path(cfg_path: Union[str, Path]): - cfg = load_config(cfg_path) - G = build_trained_generator(cfg) - tops.logger.log(f"Loaded generator from: {cfg_path}") - return G - - -def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs): - img = F.resize(img, imsize, antialias=True) - mask = (F.resize(mask, imsize, antialias=True) > 0.99).float() - maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float() - - condition = img * mask - return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition) - - -class Anonymizer: - - def __init__( - self, - detector, - load_cache: bool, - person_G_cfg: Optional[Union[str, Path]] = None, - cse_person_G_cfg: Optional[Union[str, Path]] = None, - face_G_cfg: Optional[Union[str, Path]] = None, - car_G_cfg: Optional[Union[str, Path]] = None, - ) -> None: - self.detector = detector - self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]} - self.load_cache = load_cache - if cse_person_G_cfg is not None: - self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg) - if person_G_cfg is not None: - self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg) - if face_G_cfg is not None: - self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg) - if car_G_cfg is not None: - self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg) - - def initialize_tracker(self, fps: float): - self.tracker = MultiObjectTracker(dt=1/fps) - self.track_to_z_idx = dict() - self.cur_z_idx = 0 - - @torch.no_grad() - def anonymize_detections(self, - im, detection, truncation_value: float, - multi_modal_truncation: bool, amp: bool, z_idx, - all_styles=None, - update_identity=None, - ): - G = self.generators[type(detection)] - if G is None: - return im - C, H, W = im.shape - orig_im = im.clone() - if update_identity is None: - update_identity = [True for i in range(len(detection))] - for idx in range(len(detection)): - if not update_identity[idx]: - continue - batch = detection.get_crop(idx, im) - x0, y0, x1, y1 = batch.pop("boxes")[0] - batch = {k: tops.to_cuda(v) for k, v in batch.items()} - batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255]) - batch["img"] = batch["img"].float() - batch["condition"] = batch["mask"] * batch["img"] - orig_shape = None - if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]: - orig_shape = batch["img"].shape[-2:] - batch = resize_batch(**batch, imsize=G.imsize) - with torch.cuda.amp.autocast(amp): - if all_styles is not None: - anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"] - elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"): - w_indices = None - if z_idx is not None: - w_indices = [z_idx[idx] % len(G.style_net.w_centers)] - anonymized_im = G.multi_modal_truncate( - **batch, truncation_value=truncation_value, - w_indices=w_indices)["img"] - else: - z = None - if z_idx is not None: - 
state = np.random.RandomState(seed=z_idx[idx]) - z = state.normal(size=(1, G.z_channels)) - z = tops.to_cuda(torch.from_numpy(z)) - anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"] - if orig_shape is not None: - anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True) - anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte() - - # Resize and denormalize image - gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True) - mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0) - # Remove padding - pad = [max(-x0,0), max(-y0,0)] - pad = [*pad, max(x1-W,0), max(y1-H,0)] - remove_pad = lambda x: x[...,pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]] - gim = remove_pad(gim) - mask = remove_pad(mask) - x0, y0 = max(x0, 0), max(y0, 0) - x1, y1 = min(x1, W), min(y1, H) - mask = mask.logical_not()[None].repeat(3, 1, 1) - im[:, y0:y1, x0:x1][mask] = gim[mask] - - return im - - def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor: - all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache) - for det in all_detections: - im = det.visualize(im) - return im - - @torch.no_grad() - def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor: - assert im.dtype == torch.uint8 - im = tops.to_cuda(im) - all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache) - if hasattr(self, "tracker") and track: - [_.pre_process() for _ in all_detections] - import numpy as np - boxes = np.concatenate([_.boxes for _ in all_detections]) - boxes = [Detection(box) for box in boxes] - self.tracker.step(boxes) - track_ids = self.tracker.detections_matched_ids - z_idx = [] - for track_id in track_ids: - if track_id not in self.track_to_z_idx: - self.track_to_z_idx[track_id] = self.cur_z_idx - self.cur_z_idx += 1 - z_idx.append(self.track_to_z_idx[track_id]) - z_idx = np.array(z_idx) - idx_offset = 0 - - for detection in all_detections: - zs = None - if hasattr(self, "tracker") and track: - zs = z_idx[idx_offset:idx_offset+len(detection)] - idx_offset += len(detection) - im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs) - - return im.cpu() - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - diff --git a/dp2/data/__init__.py b/dp2/data/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dp2/data/build.py b/dp2/data/build.py deleted file mode 100644 index 07f2a4e3630b405bf5a84f6926a733044345d741..0000000000000000000000000000000000000000 --- a/dp2/data/build.py +++ /dev/null @@ -1,148 +0,0 @@ -import io -import torch -import tops -from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder - -def get_dataloader( - dataset, gpu_transform: torch.nn.Module, - num_workers, - batch_size, - infinite: bool, - drop_last: bool, - prefetch_factor: int, - shuffle, - channels_last=False - ): - sampler = None - dl_kwargs = dict( - pin_memory=True, - ) - if infinite: - sampler = tops.InfiniteSampler( - dataset, rank=tops.rank(), - num_replicas=tops.world_size(), - shuffle=shuffle - ) - elif tops.world_size() > 1: - sampler = torch.utils.data.DistributedSampler( - dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank()) - dl_kwargs["drop_last"] = drop_last - else: - dl_kwargs["shuffle"] = shuffle - 
dl_kwargs["drop_last"] = drop_last - dataloader = torch.utils.data.DataLoader( - dataset, sampler=sampler, collate_fn=collate_fn, - batch_size=batch_size, - num_workers=num_workers, prefetch_factor=prefetch_factor, - **dl_kwargs - ) - dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last) - return dataloader - - -def get_dataloader_places2_wds( - path, - batch_size: int, - num_workers: int, - transform: torch.nn.Module, - gpu_transform: torch.nn.Module, - infinite: bool, - shuffle: bool, - partial_batches: bool, - sample_shuffle=10_000, - tar_shuffle=100, - channels_last=False, - ): - import webdataset as wds - import os - os.environ["RANK"] = str(tops.rank()) - os.environ["WORLD_SIZE"] = str(tops.world_size()) - - if infinite: - pipeline = [wds.ResampledShards(str(path))] - else: - pipeline = [wds.SimpleShardList(str(path))] - if shuffle: - pipeline.append(wds.shuffle(tar_shuffle)) - pipeline.extend([ - wds.split_by_node, - wds.split_by_worker, - ]) - if shuffle: - pipeline.append(wds.shuffle(sample_shuffle)) - - pipeline.extend([ - wds.tarfile_to_samples(), - wds.decode("torchrgb8"), - wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]), - ]) - if transform is not None: - pipeline.append(wds.map(transform)) - pipeline.extend([ - wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), - ]) - pipeline = wds.DataPipeline(*pipeline) - if infinite: - pipeline = pipeline.repeat(nepochs=1000000) - loader = wds.WebLoader( - pipeline, batch_size=None, shuffle=False, - num_workers=get_num_workers(num_workers), - persistent_workers=True, - ) - loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) - return loader - - - - -def get_dataloader_celebAHQ_wds( - path, - batch_size: int, - num_workers: int, - transform: torch.nn.Module, - gpu_transform: torch.nn.Module, - infinite: bool, - shuffle: bool, - partial_batches: bool, - sample_shuffle=10_000, - tar_shuffle=100, - channels_last=False, - ): - import webdataset as wds - import os - os.environ["RANK"] = str(tops.rank()) - os.environ["WORLD_SIZE"] = str(tops.world_size()) - - if infinite: - pipeline = [wds.ResampledShards(str(path))] - else: - pipeline = [wds.SimpleShardList(str(path))] - if shuffle: - pipeline.append(wds.shuffle(tar_shuffle)) - pipeline.extend([ - wds.split_by_node, - wds.split_by_worker, - ]) - if shuffle: - pipeline.append(wds.shuffle(sample_shuffle)) - - pipeline.extend([ - wds.tarfile_to_samples(), - wds.decode(wds.handle_extension(".png", png_decoder)), - wds.rename_keys(["img", "png"], ["__key__", "__key__"]), - ]) - if transform is not None: - pipeline.append(wds.map(transform)) - pipeline.extend([ - wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), - ]) - pipeline = wds.DataPipeline(*pipeline) - if infinite: - pipeline = pipeline.repeat(nepochs=1000000) - loader = wds.WebLoader( - pipeline, batch_size=None, shuffle=False, - num_workers=get_num_workers(num_workers), - persistent_workers=True, - ) - loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last) - return loader diff --git a/dp2/data/datasets/__init__.py b/dp2/data/datasets/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dp2/data/datasets/coco_cse.py b/dp2/data/datasets/coco_cse.py deleted file mode 100644 index 27fa6dfb94f118b939788e0134e6ac42c613b297..0000000000000000000000000000000000000000 --- a/dp2/data/datasets/coco_cse.py +++ 
/dev/null @@ -1,148 +0,0 @@ -import pickle -import torchvision -import torch -import pathlib -import numpy as np -from typing import Callable, Optional, Union -from torch.hub import get_dir as get_hub_dir - - -def cache_embed_stats(embed_map: torch.Tensor): - mean = embed_map.mean(dim=0, keepdim=True) - rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() - - cache = dict(mean=mean, rstd=rstd, embed_map=embed_map) - path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch") - path.parent.mkdir(exist_ok=True, parents=True) - torch.save(cache, path) - - -class CocoCSE(torch.utils.data.Dataset): - - def __init__(self, - dirpath: Union[str, pathlib.Path], - transform: Optional[Callable], - normalize_E: bool,): - dirpath = pathlib.Path(dirpath) - self.dirpath = dirpath - - self.transform = transform - assert self.dirpath.is_dir(),\ - f"Did not find dataset at: {dirpath}" - self.image_paths, self.embedding_paths = self._load_impaths() - self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy"))) - mean = self.embed_map.mean(dim=0, keepdim=True) - rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() - self.embed_map = (self.embed_map - mean) * rstd - cache_embed_stats(self.embed_map) - - def _load_impaths(self): - image_dir = self.dirpath.joinpath("images") - image_paths = list(image_dir.glob("*.png")) - image_paths.sort() - embedding_paths = [ - self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths - ] - return image_paths, embedding_paths - - def __len__(self): - return len(self.image_paths) - - def __getitem__(self, idx): - im = torchvision.io.read_image(str(self.image_paths[idx])) - vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1) - vertices = torch.from_numpy(vertices.squeeze()).long() - mask = torch.from_numpy(mask.squeeze()).float() - border = torch.from_numpy(border.squeeze()).float() - E_mask = 1 - mask - border - batch = { - "img": im, - "vertices": vertices[None], - "mask": mask[None], - "embed_map": self.embed_map, - "border": border[None], - "E_mask": E_mask[None] - } - if self.transform is None: - return batch - return self.transform(batch) - - -class CocoCSEWithFace(CocoCSE): - - def __init__(self, - dirpath: Union[str, pathlib.Path], - transform: Optional[Callable], - **kwargs): - super().__init__(dirpath, transform, **kwargs) - with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp: - self.face_boxes = pickle.load(fp) - - def __getitem__(self, idx): - item = super().__getitem__(idx) - item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name] - return item - - -class CocoCSESemantic(torch.utils.data.Dataset): - - def __init__(self, - dirpath: Union[str, pathlib.Path], - transform: Optional[Callable], - **kwargs): - dirpath = pathlib.Path(dirpath) - self.dirpath = dirpath - - self.transform = transform - assert self.dirpath.is_dir(),\ - f"Did not find dataset at: {dirpath}" - self.image_paths, self.embedding_paths = self._load_impaths() - self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy"))) - self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy"))) - - def _load_impaths(self): - image_dir = self.dirpath.joinpath("images") - image_paths = list(image_dir.glob("*.png")) - image_paths.sort() - embedding_paths = [ - self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths - ] - return image_paths, embedding_paths - - def __len__(self): - return 
len(self.image_paths) - - def __getitem__(self, idx): - im = torchvision.io.read_image(str(self.image_paths[idx])) - vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1) - vertices = torch.from_numpy(vertices.squeeze()).long() - mask = torch.from_numpy(mask.squeeze()).float() - border = torch.from_numpy(border.squeeze()).float() - E_mask = 1 - mask - border - batch = { - "img": im, - "vertices": vertices[None], - "mask": mask[None], - "border": border[None], - "vertx2cat": self.vertx2cat, - "embed_map": self.embed_map, - } - if self.transform is None: - return batch - return self.transform(batch) - - -class CocoCSESemanticWithFace(CocoCSESemantic): - - def __init__(self, - dirpath: Union[str, pathlib.Path], - transform: Optional[Callable], - **kwargs): - super().__init__(dirpath, transform, **kwargs) - with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp: - self.face_boxes = pickle.load(fp) - - def __getitem__(self, idx): - item = super().__getitem__(idx) - item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name] - return item diff --git a/dp2/data/datasets/fdf.py b/dp2/data/datasets/fdf.py deleted file mode 100644 index b05c75692515c5e294af72417371af9fcefbfad8..0000000000000000000000000000000000000000 --- a/dp2/data/datasets/fdf.py +++ /dev/null @@ -1,129 +0,0 @@ -import pathlib -from typing import Tuple -import numpy as np -import torch -import pathlib -try: - import pyspng - PYSPNG_IMPORTED = True -except ImportError: - PYSPNG_IMPORTED = False - print("Could not load pyspng. Defaulting to pillow image backend.") - from PIL import Image -from tops import logger - - -class FDFDataset: - - def __init__(self, - dirpath, - imsize: Tuple[int], - load_keypoints: bool, - transform): - dirpath = pathlib.Path(dirpath) - self.dirpath = dirpath - self.transform = transform - self.imsize = imsize[0] - self.load_keypoints = load_keypoints - assert self.dirpath.is_dir(),\ - f"Did not find dataset at: {dirpath}" - image_dir = self.dirpath.joinpath("images", str(self.imsize)) - self.image_paths = list(image_dir.glob("*.png")) - assert len(self.image_paths) > 0,\ - f"Did not find images in: {image_dir}" - self.image_paths.sort(key=lambda x: int(x.stem)) - self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) - - self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch")) - assert len(self.image_paths) == len(self.bounding_boxes) - assert len(self.image_paths) == len(self.landmarks) - logger.log( - f"Dataset loaded from: {dirpath}. 
Number of samples:{len(self)}, imsize={imsize}") - - def get_mask(self, idx): - mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool) - bounding_box = self.bounding_boxes[idx] - x0, y0, x1, y1 = bounding_box - mask[:, y0:y1, x0:x1] = 0 - return mask - - def __len__(self): - return len(self.image_paths) - - def __getitem__(self, index): - impath = self.image_paths[index] - if PYSPNG_IMPORTED: - with open(impath, "rb") as fp: - im = pyspng.load(fp.read()) - else: - with Image.open(impath) as fp: - im = np.array(fp) - im = torch.from_numpy(np.rollaxis(im, -1, 0)) - masks = self.get_mask(index) - landmark = self.landmarks[index] - batch = { - "img": im, - "mask": masks, - } - if self.load_keypoints: - batch["keypoints"] = landmark - if self.transform is None: - return batch - return self.transform(batch) - - -class FDF256Dataset: - - def __init__(self, - dirpath, - load_keypoints: bool, - transform): - dirpath = pathlib.Path(dirpath) - self.dirpath = dirpath - self.transform = transform - self.load_keypoints = load_keypoints - assert self.dirpath.is_dir(),\ - f"Did not find dataset at: {dirpath}" - image_dir = self.dirpath.joinpath("images") - self.image_paths = list(image_dir.glob("*.png")) - assert len(self.image_paths) > 0,\ - f"Did not find images in: {image_dir}" - self.image_paths.sort(key=lambda x: int(x.stem)) - self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) - self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy"))) - assert len(self.image_paths) == len(self.bounding_boxes) - assert len(self.image_paths) == len(self.landmarks) - logger.log( - f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}") - - def get_mask(self, idx): - mask = torch.ones((1, 256, 256), dtype=torch.bool) - bounding_box = self.bounding_boxes[idx] - x0, y0, x1, y1 = bounding_box - mask[:, y0:y1, x0:x1] = 0 - return mask - - def __len__(self): - return len(self.image_paths) - - def __getitem__(self, index): - impath = self.image_paths[index] - if PYSPNG_IMPORTED: - with open(impath, "rb") as fp: - im = pyspng.load(fp.read()) - else: - with Image.open(impath) as fp: - im = np.array(fp) - im = torch.from_numpy(np.rollaxis(im, -1, 0)) - masks = self.get_mask(index) - landmark = self.landmarks[index] - batch = { - "img": im, - "mask": masks, - } - if self.load_keypoints: - batch["keypoints"] = landmark - if self.transform is None: - return batch - return self.transform(batch) - diff --git a/dp2/data/datasets/fdh.py b/dp2/data/datasets/fdh.py deleted file mode 100644 index 5eb654ba71604f1b088df586e6683418033b4b16..0000000000000000000000000000000000000000 --- a/dp2/data/datasets/fdh.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -import tops -import numpy as np -import io -import webdataset as wds -import os -from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn - - -def kp_decoder(x): - # Keypoints are between [0, 1] for webdataset - keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float() - keypoints[:, 0] /= 160 - keypoints[:, 1] /= 288 - check_outside = lambda x: (x < 0).logical_or(x > 1) - is_outside = check_outside(keypoints[:, 0]).logical_or( - check_outside(keypoints[:, 1]) - ) - keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not()) - return keypoints - - -def vertices_decoder(x): - vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32)) - return vertices.squeeze()[None] - - -def get_dataloader_fdh_wds( - path, - batch_size: int, - 
num_workers: int, - transform: torch.nn.Module, - gpu_transform: torch.nn.Module, - infinite: bool, - shuffle: bool, - partial_batches: bool, - load_embedding: bool, - sample_shuffle=10_000, - tar_shuffle=100, - read_condition=False, - channels_last=False, - ): - # Need to set this for split_by_node to work. - os.environ["RANK"] = str(tops.rank()) - os.environ["WORLD_SIZE"] = str(tops.world_size()) - if infinite: - pipeline = [wds.ResampledShards(str(path))] - else: - pipeline = [wds.SimpleShardList(str(path))] - if shuffle: - pipeline.append(wds.shuffle(tar_shuffle)) - pipeline.extend([ - wds.split_by_node, - wds.split_by_worker, - ]) - if shuffle: - pipeline.append(wds.shuffle(sample_shuffle)) - - decoder = [ - wds.handle_extension("image.png", png_decoder), - wds.handle_extension("mask.png", mask_decoder), - wds.handle_extension("maskrcnn_mask.png", mask_decoder), - wds.handle_extension("keypoints.npy", kp_decoder), - ] - - rename_keys = [ - ["img", "image.png"], ["mask", "mask.png"], - ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"] - ] - if load_embedding: - decoder.extend([ - wds.handle_extension("vertices.npy", vertices_decoder), - wds.handle_extension("E_mask.png", mask_decoder) - ]) - rename_keys.extend([ - ["vertices", "vertices.npy"], - ["E_mask", "e_mask.png"] - ]) - - if read_condition: - decoder.append( - wds.handle_extension("condition.png", png_decoder) - ) - rename_keys.append(["condition", "condition.png"]) - - pipeline.extend([ - wds.tarfile_to_samples(), - wds.decode(*decoder), - wds.rename_keys(*rename_keys), - wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), - ]) - if transform is not None: - pipeline.append(wds.map(transform)) - pipeline = wds.DataPipeline(*pipeline) - if infinite: - pipeline = pipeline.repeat(nepochs=1000000) - - loader = wds.WebLoader( - pipeline, batch_size=None, shuffle=False, - num_workers=get_num_workers(num_workers), - persistent_workers=True, - ) - loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) - return loader diff --git a/dp2/data/transforms/__init__.py b/dp2/data/transforms/__init__.py deleted file mode 100644 index 66a9e160b6513cc81a00b0a62606721f37b70c61..0000000000000000000000000000000000000000 --- a/dp2/data/transforms/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize -from .stylegan2_transform import StyleGANAugmentPipe \ No newline at end of file diff --git a/dp2/data/transforms/functional.py b/dp2/data/transforms/functional.py deleted file mode 100644 index 6d5c695944a47d67cac7e03a8a7f5b400c94417b..0000000000000000000000000000000000000000 --- a/dp2/data/transforms/functional.py +++ /dev/null @@ -1,61 +0,0 @@ -import torchvision.transforms.functional as F -import torch -import pickle -from tops import download_file, assert_shape -from typing import Dict -from functools import lru_cache - -global symmetry_transform - -@lru_cache(maxsize=1) -def get_symmetry_transform(symmetry_url): - file_name = download_file(symmetry_url) - with open(file_name, "rb") as fp: - symmetry = pickle.load(fp) - return torch.from_numpy(symmetry["vertex_transforms"]).long() - - -hflip_handled_cases = set([ - "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition", - "embedding", "vertx2cat", "maskrcnn_mask", "__key__", - "img_hr", "condition_hr", "mask_hr"]) - -def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> 
Dict[str, torch.Tensor]: - container["img"] = F.hflip(container["img"]) - if "condition" in container: - container["condition"] = F.hflip(container["condition"]) - if "embedding" in container: - container["embedding"] = F.hflip(container["embedding"]) - assert all([key in hflip_handled_cases for key in container]), container.keys() - if "keypoints" in container: - assert flip_map is not None - if container["keypoints"].ndim == 3: - keypoints = container["keypoints"][:, flip_map, :] - keypoints[:, :, 0] = 1 - keypoints[:, :, 0] - else: - assert_shape(container["keypoints"], (None, 3)) - keypoints = container["keypoints"][flip_map, :] - keypoints[:, 0] = 1 - keypoints[:, 0] - container["keypoints"] = keypoints - if "mask" in container: - container["mask"] = F.hflip(container["mask"]) - if "border" in container: - container["border"] = F.hflip(container["border"]) - if "semantic_mask" in container: - container["semantic_mask"] = F.hflip(container["semantic_mask"]) - if "vertices" in container: - symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl") - container["vertices"] = F.hflip(container["vertices"]) - symmetry_transform_ = symmetry_transform.to(container["vertices"].device) - container["vertices"] = symmetry_transform_[container["vertices"].long()] - if "E_mask" in container: - container["E_mask"] = F.hflip(container["E_mask"]) - if "maskrcnn_mask" in container: - container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"]) - if "img_hr" in container: - container["img_hr"] = F.hflip(container["img_hr"]) - if "condition_hr" in container: - container["condition_hr"] = F.hflip(container["condition_hr"]) - if "mask_hr" in container: - container["mask_hr"] = F.hflip(container["mask_hr"]) - return container diff --git a/dp2/data/transforms/stylegan2_transform.py b/dp2/data/transforms/stylegan2_transform.py deleted file mode 100644 index 49a143cddf9673d079b87ac7d725c433713e54c5..0000000000000000000000000000000000000000 --- a/dp2/data/transforms/stylegan2_transform.py +++ /dev/null @@ -1,394 +0,0 @@ -import numpy as np -import scipy.signal -import torch -try: - from sg3_torch_utils import misc - from sg3_torch_utils.ops import upfirdn2d - from sg3_torch_utils.ops import grid_sample_gradfix - from sg3_torch_utils.ops import conv2d_gradfix -except: - pass -#---------------------------------------------------------------------------- -# Coefficients of various wavelet decomposition low-pass filters. 
- -wavelets = { - 'haar': [0.7071067811865476, 0.7071067811865476], - 'db1': [0.7071067811865476, 0.7071067811865476], - 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], - 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], - 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], - 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], - 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], - 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], - 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], - 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], - 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], - 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], - 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], - 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], - 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], - 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], -} - -#---------------------------------------------------------------------------- -# Helpers for constructing transformation matrices. 
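The wavelets table above stores 1-D low-pass analysis coefficients; the augmentation pipe later turns the 'sym2' entry into a small band-pass filter bank by forming the matching high-pass filter and half-band products. A rough sketch of that construction (coefficients rounded for brevity):

import numpy as np

hz_lo = np.asarray([-0.1294, 0.2241, 0.8365, 0.4830])   # ~'sym2' low-pass H(z), rounded
hz_hi = hz_lo * ((-1.0) ** np.arange(hz_lo.size))        # high-pass H(-z) via sign alternation
hz_lo2 = np.convolve(hz_lo, hz_lo[::-1]) / 2             # H(z) * H(z^-1) / 2
hz_hi2 = np.convolve(hz_hi, hz_hi[::-1]) / 2             # H(-z) * H(-z^-1) / 2
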
- - -def matrix(*rows, device=None): - assert all(len(row) == len(rows[0]) for row in rows) - elems = [x for row in rows for x in row] - ref = [x for x in elems if isinstance(x, torch.Tensor)] - if len(ref) == 0: - return misc.constant(np.asarray(rows), device=device) - assert device is None or device == ref[0].device - elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] - return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) - - -def translate2d(tx, ty, **kwargs): - return matrix( - [1, 0, tx], - [0, 1, ty], - [0, 0, 1], - **kwargs) - - -def translate3d(tx, ty, tz, **kwargs): - return matrix( - [1, 0, 0, tx], - [0, 1, 0, ty], - [0, 0, 1, tz], - [0, 0, 0, 1], - **kwargs) - - -def scale2d(sx, sy, **kwargs): - return matrix( - [sx, 0, 0], - [0, sy, 0], - [0, 0, 1], - **kwargs) - - -def scale3d(sx, sy, sz, **kwargs): - return matrix( - [sx, 0, 0, 0], - [0, sy, 0, 0], - [0, 0, sz, 0], - [0, 0, 0, 1], - **kwargs) - - -def rotate2d(theta, **kwargs): - return matrix( - [torch.cos(theta), torch.sin(-theta), 0], - [torch.sin(theta), torch.cos(theta), 0], - [0, 0, 1], - **kwargs) - - -def rotate3d(v, theta, **kwargs): - vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] - s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c - return matrix( - [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], - [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], - [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], - [0, 0, 0, 1], - **kwargs) - - -def translate2d_inv(tx, ty, **kwargs): - return translate2d(-tx, -ty, **kwargs) - - -def scale2d_inv(sx, sy, **kwargs): - return scale2d(1 / sx, 1 / sy, **kwargs) - - -def rotate2d_inv(theta, **kwargs): - return rotate2d(-theta, **kwargs) - - -class StyleGANAugmentPipe(torch.nn.Module): - def __init__(self, - rotate90=0, xint=0, xint_max=0.125, - scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, - brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, - hue_max=1, saturation_std=1, - imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, - ): - super().__init__() - self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. - - # Pixel blitting. - self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. - self.xint = float(xint) # Probability multiplier for integer translation. - self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. - - # General geometric transformations. - self.scale = float(scale) # Probability multiplier for isotropic scaling. - self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. - self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. - self.xfrac = float(xfrac) # Probability multiplier for fractional translation. - self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. - self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. - self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. - self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. - - # Color transformations. - self.brightness = float(brightness) # Probability multiplier for brightness. - self.contrast = float(contrast) # Probability multiplier for contrast. - self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. 
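The matrix helpers above build homogeneous transforms either as constants or as batched tensors (through sg3_torch_utils.misc). A self-contained, scalar-only sketch of the same composition pattern, without that dependency:

import math
import torch

def translate2d(tx, ty):
    return torch.tensor([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])

def scale2d(sx, sy):
    return torch.tensor([[sx, 0.0, 0.0], [0.0, sy, 0.0], [0.0, 0.0, 1.0]])

def rotate2d(theta):
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

# Inverse transform (output pixel -> input pixel), composed as in the pipe:
G_inv = scale2d(1 / 1.2, 1 / 1.2) @ rotate2d(-0.1) @ translate2d(-4.0, 2.0)
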
- self.hue = float(hue) # Probability multiplier for hue rotation. - self.saturation = float(saturation) # Probability multiplier for saturation. - self.brightness_std = float(brightness_std) # Standard deviation of brightness. - self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. - self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. - self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. - - # Image-space filtering. - self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. - self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. - self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. - - # Setup orthogonal lowpass filter for geometric augmentations. - self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) - - # Construct filter bank for image-space filtering. - Hz_lo = np.asarray(wavelets['sym2']) # H(z) - Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) - Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 - Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 - Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) - for i in range(1, Hz_fbank.shape[0]): - Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] - Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) - Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 - self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) - - def forward(self, batch, debug_percentile=None): - images = batch["img"] - batch["vertices"] = batch["vertices"].float() - assert isinstance(images, torch.Tensor) and images.ndim == 4 - batch_size, num_channels, height, width = images.shape - device = images.device - self.Hz_fbank = self.Hz_fbank.to(device) - self.Hz_geom = self.Hz_geom.to(device) - if debug_percentile is not None: - debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) - - # ------------------------------------- - # Select parameters for pixel blitting. - # ------------------------------------- - - # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in - I_3 = torch.eye(3, device=device) - G_inv = I_3 - - # Apply integer translation with probability (xint * strength). - if self.xint > 0: - t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max - t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) - if debug_percentile is not None: - t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) - G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) - - # -------------------------------------------------------- - # Select parameters for general geometric transformations. - # -------------------------------------------------------- - - # Apply isotropic scaling with probability (scale * strength). - if self.scale > 0: - s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) - s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) - if debug_percentile is not None: - s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) - G_inv = G_inv @ scale2d_inv(s, s) - - # Apply pre-rotation with probability p_rot. 
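Each augmentation above follows the same gating idiom: sample a parameter for every element in the batch, then keep it only with probability (multiplier * p) and fall back to the identity value otherwise. A small sketch of that idiom for a log-normal scale factor:

import torch

def gated_lognormal_scale(batch_size: int, prob: float, std: float) -> torch.Tensor:
    s = torch.exp2(torch.randn(batch_size) * std)     # candidate scale per sample
    keep = torch.rand(batch_size) < prob              # apply with probability `prob`
    return torch.where(keep, s, torch.ones_like(s))   # identity (scale 1) otherwise
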
- p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p - if self.rotate > 0: - theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max - theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) - if debug_percentile is not None: - theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) - G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. - - # Apply anisotropic scaling with probability (aniso * strength). - if self.aniso > 0: - s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) - s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) - if debug_percentile is not None: - s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) - G_inv = G_inv @ scale2d_inv(s, 1 / s) - - # Apply post-rotation with probability p_rot. - if self.rotate > 0: - theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max - theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) - if debug_percentile is not None: - theta = torch.zeros_like(theta) - G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. - - # Apply fractional translation with probability (xfrac * strength). - if self.xfrac > 0: - t = torch.randn([batch_size, 2], device=device) * self.xfrac_std - t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) - if debug_percentile is not None: - t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) - G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) - - # ---------------------------------- - # Execute geometric transformations. - # ---------------------------------- - - # Execute if the transform is not identity. - if G_inv is not I_3: - # Calculate padding. - cx = (width - 1) / 2 - cy = (height - 1) / 2 - cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] - cp = G_inv @ cp.t() # [batch, xyz, idx] - Hz_pad = self.Hz_geom.shape[0] // 4 - margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] - margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] - margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) - margin = margin.max(misc.constant([0, 0] * 2, device=device)) - margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) - mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) - - # Pad image and adjust origin. - images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') - batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0) - batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0) - batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0) - G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv - - # Upsample. 
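Before warping, the pipe pads each tensor so the transformed content stays inside the frame: images are reflect-padded, the mask is padded with ones so padded pixels count as known, and E_mask/vertices are padded with zeros. A minimal sketch of that step, with made-up margins:

import torch
import torch.nn.functional as F

images = torch.rand(2, 3, 64, 64)
mask = torch.ones(2, 1, 64, 64)
mx0, mx1, my0, my1 = 8, 8, 4, 4   # illustrative margins (left, right, top, bottom)
images = F.pad(images, [mx0, mx1, my0, my1], mode="reflect")
mask = F.pad(mask, [mx0, mx1, my0, my1], mode="constant", value=1.0)
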
- images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) - batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest") - batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest") - batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest") - G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) - G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) - - # Execute transformation. - shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] - G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) - grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) - images = grid_sample_gradfix.grid_sample(images, grid) - - batch["mask"] = torch.nn.functional.grid_sample( - input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) - batch["E_mask"] = torch.nn.functional.grid_sample( - input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) - batch["vertices"] = torch.nn.functional.grid_sample( - input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) - - - # Downsample and crop. - images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) - batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) - batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) - batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) - # -------------------------------------------- - # Select parameters for color transformations. - # -------------------------------------------- - - # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out - I_4 = torch.eye(4, device=device) - C = I_4 - - # Apply brightness with probability (brightness * strength). - if self.brightness > 0: - b = torch.randn([batch_size], device=device) * self.brightness_std - b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) - if debug_percentile is not None: - b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) - C = translate3d(b, b, b) @ C - - # Apply contrast with probability (contrast * strength). - if self.contrast > 0: - c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) - c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) - if debug_percentile is not None: - c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) - C = scale3d(c, c, c) @ C - - # Apply luma flip with probability (lumaflip * strength). - v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. - - # Apply hue rotation with probability (hue * strength). 
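The warp itself is executed by turning the accumulated inverse transform into a sampling grid: filtered (bilinear) sampling for images, nearest sampling for masks and vertex maps. A sketch with an identity transform standing in for G_inv:

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 64, 64
images = torch.rand(B, C, H, W)
mask = torch.ones(B, 1, H, W)
G_inv = torch.eye(3).expand(B, 3, 3)   # identity transform, for illustration only
grid = F.affine_grid(G_inv[:, :2, :], size=(B, C, H, W), align_corners=False)
images = F.grid_sample(images, grid, mode="bilinear", align_corners=False)
mask = F.grid_sample(mask, grid, mode="nearest", padding_mode="border", align_corners=False)
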
- if self.hue > 0 and num_channels > 1: - theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max - theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) - if debug_percentile is not None: - theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) - C = rotate3d(v, theta) @ C # Rotate around v. - - # Apply saturation with probability (saturation * strength). - if self.saturation > 0 and num_channels > 1: - s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) - s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) - if debug_percentile is not None: - s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) - C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C - - # ------------------------------ - # Execute color transformations. - # ------------------------------ - - # Execute if the transform is not identity. - if C is not I_4: - images = images.reshape([batch_size, num_channels, height * width]) - if num_channels == 3: - images = C[:, :3, :3] @ images + C[:, :3, 3:] - elif num_channels == 1: - C = C[:, :3, :].mean(dim=1, keepdims=True) - images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] - else: - raise ValueError('Image must be RGB (3 channels) or L (1 channel)') - images = images.reshape([batch_size, num_channels, height, width]) - - # ---------------------- - # Image-space filtering. - # ---------------------- - - if self.imgfilter > 0: - num_bands = self.Hz_fbank.shape[0] - assert len(self.imgfilter_bands) == num_bands - expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). - - # Apply amplification for each band with probability (imgfilter * strength * band_strength). - g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). - for i, band_strength in enumerate(self.imgfilter_bands): - t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) - t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) - if debug_percentile is not None: - t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) - t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. - t[:, i] = t_i # Replace i'th element. - t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. - g = g * t # Accumulate into global gain. - - # Construct combined amplification filter. - Hz_prime = g @ self.Hz_fbank # [batch, tap] - Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] - Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] - - # Apply filter. - p = self.Hz_fbank.shape[1] // 2 - images = images.reshape([1, batch_size * num_channels, height, width]) - images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') - images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) - images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) - images = images.reshape([batch_size, num_channels, height, width]) - - # ------------------------ - # Image-space corruptions. 
- # ------------------------ - batch["img"] = images - batch["vertices"] = batch["vertices"].long() - batch["border"] = 1 - batch["E_mask"] - batch["mask"] - return batch diff --git a/dp2/data/transforms/transforms.py b/dp2/data/transforms/transforms.py deleted file mode 100644 index 1221a9121d7f59b1ac33c28e4189339e8df6dadf..0000000000000000000000000000000000000000 --- a/dp2/data/transforms/transforms.py +++ /dev/null @@ -1,247 +0,0 @@ -from pathlib import Path -from typing import Dict, List -import torchvision -import torch -import tops -import torchvision.transforms.functional as F -from .functional import hflip - - -class RandomHorizontalFlip(torch.nn.Module): - - def __init__(self, p: float, flip_map=None,**kwargs): - super().__init__() - self.flip_ratio = p - self.flip_map = flip_map - if self.flip_ratio is None: - self.flip_ratio = 0.5 - assert 0 <= self.flip_ratio <= 1 - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - if torch.rand(1) > self.flip_ratio: - return container - return hflip(container, self.flip_map) - - -class CenterCrop(torch.nn.Module): - """ - Performs the transform on the image. - NOTE: Does not transform the mask to improve runtime. - """ - - def __init__(self, size: List[int]): - super().__init__() - self.size = tuple(size) - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - min_size = min(container["img"].shape[1], container["img"].shape[2]) - if min_size < self.size[0]: - container["img"] = F.center_crop(container["img"], min_size) - container["img"] = F.resize(container["img"], self.size) - return container - container["img"] = F.center_crop(container["img"], self.size) - return container - - -class Resize(torch.nn.Module): - """ - Performs the transform on the image. - NOTE: Does not transform the mask to improve runtime. 
- """ - - def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR): - super().__init__() - self.size = tuple(size) - self.interpolation = interpolation - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) - if "semantic_mask" in container: - container["semantic_mask"] = F.resize( - container["semantic_mask"], self.size, F.InterpolationMode.NEAREST) - if "embedding" in container: - container["embedding"] = F.resize( - container["embedding"], self.size, self.interpolation) - if "mask" in container: - container["mask"] = F.resize( - container["mask"], self.size, F.InterpolationMode.NEAREST) - if "E_mask" in container: - container["E_mask"] = F.resize( - container["E_mask"], self.size, F.InterpolationMode.NEAREST) - if "maskrcnn_mask" in container: - container["maskrcnn_mask"] = F.resize( - container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST) - if "vertices" in container: - container["vertices"] = F.resize( - container["vertices"], self.size, F.InterpolationMode.NEAREST) - return container - - def __repr__(self): - repr = super().__repr__() - vars_ = dict(size=self.size, interpolation=self.interpolation) - return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) - - -class InsertHRImage(torch.nn.Module): - """ - Resizes mask by maxpool and assumes condition is already created - """ - def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR): - super().__init__() - self.size = tuple(size) - self.interpolation = interpolation - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - assert container["img"].dtype == torch.float32 - container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) - container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True) - mask = container["mask"] > 0 - container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float() - container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"] - return container - - def __repr__(self): - repr = super().__repr__() - vars_ = dict(size=self.size, interpolation=self.interpolation) - return repr + " " - - -class CopyHRImage(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - container["img_hr"] = container["img"] - container["condition_hr"] = container["condition"] - container["mask_hr"] = container["mask"] - return container - - -class Resize2(torch.nn.Module): - """ - Resizes mask by maxpool and assumes condition is already created - """ - def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True): - super().__init__() - self.size = tuple(size) - self.interpolation = interpolation - self.downsample_condition = downsample_condition - self.mask_condition = mask_condition - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: -# assert container["img"].dtype == torch.float32 - container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) - mask = container["mask"] > 0 - container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), 
output_size=self.size) > 0).logical_not().float() - - if self.downsample_condition: - container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True) - if self.mask_condition: - container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"] - return container - - def __repr__(self): - repr = super().__repr__() - vars_ = dict(size=self.size, interpolation=self.interpolation) - return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) - - - -class Normalize(torch.nn.Module): - """ - Performs the transform on the image. - NOTE: Does not transform the mask to improve runtime. - """ - - def __init__(self, mean, std, inplace, keys=["img"]): - super().__init__() - self.mean = mean - self.std = std - self.inplace = inplace - self.keys = keys - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - for key in self.keys: - container[key] = F.normalize(container[key], self.mean, self.std, self.inplace) - return container - - def __repr__(self): - repr = super().__repr__() - vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace) - return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) - - -class ToFloat(torch.nn.Module): - - def __init__(self, keys=["img"], norm=True) -> None: - super().__init__() - self.keys = keys - self.gain = 255 if norm else 1 - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - for key in self.keys: - container[key] = container[key].float() / self.gain - return container - - -class RandomCrop(torchvision.transforms.RandomCrop): - """ - Performs the transform on the image. - NOTE: Does not transform the mask to improve runtime. - """ - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - container["img"] = super().forward(container["img"]) - return container - - -class CreateCondition(torch.nn.Module): - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - if container["img"].dtype == torch.uint8: - container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127 - return container - container["condition"] = container["img"] * container["mask"] - return container - - -class CreateEmbedding(torch.nn.Module): - - def __init__(self, embed_path: Path, cuda=True) -> None: - super().__init__() - self.embed_map = torch.load(embed_path, map_location=torch.device("cpu")) - if cuda: - self.embed_map = tops.to_cuda(self.embed_map) - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - vertices = container["vertices"] - if vertices.ndim == 3: - embedding = self.embed_map[vertices.long()].squeeze(dim=0) - embedding = embedding.permute(2, 0, 1) * container["E_mask"] - pass - else: - assert vertices.ndim == 4 - embedding = self.embed_map[vertices.long()].squeeze(dim=1) - embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"] - container["embedding"] = embedding - container["embed_map"] = self.embed_map.clone() - return container - - -class UpdateMask(torch.nn.Module): - - def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float() - return container - - -class LoadClassEmbedding(torch.nn.Module): - - def __init__(self, embedding_path: Path) -> None: - super().__init__() - self.embedding = torch.load(embedding_path, map_location="cpu") - 
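CreateCondition above builds the generator's conditional input by keeping pixels where the mask is 1 and blanking the region to be synthesized; for uint8 images the blanked area becomes mid-gray. A stand-alone sketch of that rule:

import torch

def create_condition(img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask == 1 marks known pixels that are kept; mask == 0 marks the area to inpaint.
    if img.dtype == torch.uint8:
        return img * mask.byte() + (1 - mask.byte()) * 127
    return img * mask
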
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1]) - container["class_embedding"] = self.embedding[key].view(-1) - return container diff --git a/dp2/data/utils.py b/dp2/data/utils.py deleted file mode 100644 index 1ec03f0ec0091e3263f5aa9b2962ad97a1c36a0f..0000000000000000000000000000000000000000 --- a/dp2/data/utils.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -from PIL import Image -import numpy as np -import multiprocessing -import io -from tops import logger -from torch.utils.data._utils.collate import default_collate - -try: - import pyspng - - PYSPNG_IMPORTED = True -except ImportError: - PYSPNG_IMPORTED = False - print("Could not load pyspng. Defaulting to pillow image backend.") - from PIL import Image - - -def get_coco_keypoints(): - return [ - "nose", - "left_eye", - "right_eye", - "left_ear", - "right_ear", - "left_shoulder", - "right_shoulder", - "left_elbow", - "right_elbow", - "left_wrist", - "right_wrist", - "left_hip", - "right_hip", - "left_knee", - "right_knee", - "left_ankle", - "right_ankle", - ] - - -def get_coco_flipmap(): - keypoints = get_coco_keypoints() - keypoint_flip_map = { - "left_eye": "right_eye", - "left_ear": "right_ear", - "left_shoulder": "right_shoulder", - "left_elbow": "right_elbow", - "left_wrist": "right_wrist", - "left_hip": "right_hip", - "left_knee": "right_knee", - "left_ankle": "right_ankle", - } - for key, value in list(keypoint_flip_map.items()): - keypoint_flip_map[value] = key - keypoint_flip_map["nose"] = "nose" - keypoint_flip_map_idx = [] - for source in keypoints: - keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source])) - return keypoint_flip_map_idx - - -def mask_decoder(x): - mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None] - mask = mask > 0 # This fixes bug causing maskf.loat().max() == 255. 
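get_coco_flipmap above turns the left/right joint naming into an index permutation that the horizontal-flip transform can apply directly to the keypoint tensor. The idea in miniature, with only three joints:

keypoints = ["nose", "left_eye", "right_eye"]
flip_pairs = {"nose": "nose", "left_eye": "right_eye", "right_eye": "left_eye"}
flip_map = [keypoints.index(flip_pairs[name]) for name in keypoints]
assert flip_map == [0, 2, 1]   # keypoints[flip_map] swaps the eye rows
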
- return mask - - -def png_decoder(x): - if PYSPNG_IMPORTED: - return torch.from_numpy(np.rollaxis(pyspng.load(x), 2)) - with Image.open(io.BytesIO(x)) as im: - im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2)) - return im - - -def jpg_decoder(x): - with Image.open(io.BytesIO(x)) as im: - im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2)) - return im - - -def get_num_workers(num_workers: int): - n_cpus = multiprocessing.cpu_count() - if num_workers > n_cpus: - logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}") - return n_cpus - return num_workers - - -def collate_fn(batch): - elem = batch[0] - ignore_keys = set(["embed_map", "vertx2cat"]) - batch_ = { - key: default_collate([d[key] for d in batch]) - for key in elem - if key not in ignore_keys - } - if "embed_map" in elem: - batch_["embed_map"] = elem["embed_map"] - if "vertx2cat" in elem: - batch_["vertx2cat"] = elem["vertx2cat"] - return batch_ diff --git a/dp2/detection/__init__.py b/dp2/detection/__init__.py deleted file mode 100644 index 613969b28384cd1c64fc8db685e7622f4cc02615..0000000000000000000000000000000000000000 --- a/dp2/detection/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .cse_mask_face_detector import CSeMaskFaceDetector -from .person_detector import CSEPersonDetector -from .structures import PersonDetection, VehicleDetection, FaceDetection diff --git a/dp2/detection/base.py b/dp2/detection/base.py deleted file mode 100644 index 32ebba893878e95ddc1da451a9f2027799ffa044..0000000000000000000000000000000000000000 --- a/dp2/detection/base.py +++ /dev/null @@ -1,45 +0,0 @@ -import pickle -import torch -import lzma -from pathlib import Path -from tops import logger - - -class BaseDetector: - - - def __init__(self, cache_directory: str) -> None: - if cache_directory is not None: - self.cache_directory = Path(cache_directory, str(self.__class__.__name__)) - self.cache_directory.mkdir(exist_ok=True, parents=True) - - def save_to_cache(self, detection, cache_path: Path, after_preprocess=True): - logger.log(f"Caching detection to: {cache_path}") - with lzma.open(cache_path, "wb") as fp: - torch.save( - [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp, - pickle_protocol=pickle.HIGHEST_PROTOCOL) - - def load_from_cache(self, cache_path: Path): - logger.log(f"Loading detection from cache path: {cache_path}") - with lzma.open(cache_path, "rb") as fp: - state_dict = torch.load(fp) - return [ - state["cls"].from_state_dict(state_dict=state) for state in state_dict - ] - - def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool): - if cache_id is None: - return self.forward(im) - cache_path = self.cache_directory.joinpath(cache_id + ".torch") - if cache_path.is_file() and load_cache: - try: - return self.load_from_cache(cache_path) - except Exception as e: - logger.warn(f"The cache file was corrupted: {cache_path}") - exit() - detections = self.forward(im) - self.save_to_cache(detections, cache_path) - return detections - - \ No newline at end of file diff --git a/dp2/detection/box_utils.py b/dp2/detection/box_utils.py deleted file mode 100644 index 6091b122a1e72d05b9cad2a25f4111b425eabb93..0000000000000000000000000000000000000000 --- a/dp2/detection/box_utils.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np - - -def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio): - x0, y0, x1, y1 = [int(_) for _ in bbox] - h, w = y1 - y0, x1 - x0 - cur_ratio = h / w - - if cur_ratio == target_aspect_ratio: - return [x0, y0, x1, y1] 
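BaseDetector above caches detections per frame as lzma-compressed torch files so repeated runs on the same input can skip the detector networks. The storage scheme in a few lines (hypothetical helper names):

import lzma
import pickle
from pathlib import Path
import torch

def save_detections(states: list, cache_path: Path) -> None:
    # Each entry is a detection state dict, compressed on disk.
    with lzma.open(cache_path, "wb") as fp:
        torch.save(states, fp, pickle_protocol=pickle.HIGHEST_PROTOCOL)

def load_detections(cache_path: Path) -> list:
    with lzma.open(cache_path, "rb") as fp:
        return torch.load(fp)
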
- if cur_ratio < target_aspect_ratio: - target_height = int(w*target_aspect_ratio) - y0, y1 = expand_axis(y0, y1, target_height, imshape[0]) - else: - target_width = int(h/target_aspect_ratio) - x0, x1 = expand_axis(x0, x1, target_width, imshape[1]) - return x0, y0, x1, y1 - - -def expand_axis(start, end, target_width, limit): - # Can return a bbox outside of limit - cur_width = end - start - start = start - (target_width-cur_width)//2 - end = end + (target_width-cur_width)//2 - if end - start != target_width: - end += 1 - assert end - start == target_width - if start < 0 and end > limit: - return start, end - if start < 0 and end < limit: - to_shift = min(0 - start, limit - end) - start += to_shift - end += to_shift - if end > limit and start > 0: - to_shift = min(end - limit, start) - end -= to_shift - start -= to_shift - assert end - start == target_width - return start, end - - -def expand_box(bbox, imshape, mask, percentage_background: float): - assert isinstance(bbox[0], int) - assert 0 < percentage_background < 1 - # Percentage in S - mask_pixels = mask.long().sum().cpu() - total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - percentage_mask = mask_pixels / total_pixels - if (1 - percentage_mask) > percentage_background: - return bbox - target_pixels = mask_pixels / (1 - percentage_background) - x0, y0, x1, y1 = bbox - H = y1 - y0 - W = x1 - x0 - p = np.sqrt(target_pixels/(H*W)) - target_width = int(np.ceil(p * W)) - target_height = int(np.ceil(p * H)) - x0, x1 = expand_axis(x0, x1, target_width, imshape[1]) - y0, y1 = expand_axis(y0, y1, target_height, imshape[0]) - return [x0, y0, x1, y1] - - -def expand_axises_by_percentage(bbox_XYXY, imshape, percentage): - x0, y0, x1, y1 = bbox_XYXY - H = y1 - y0 - W = x1 - x0 - expansion = int(((H*W)**0.5) * percentage) - new_width = W + expansion - new_height = H + expansion - x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1]) - y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0]) - return [x0, y0, x1, y1] - - -def get_expanded_bbox( - bbox_XYXY, - imshape, - mask, - percentage_background: float, - axis_minimum_expansion: float, - target_aspect_ratio: float): - bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist() - # Expand each axis of the bounding box by a minimum percentage - bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion) - # Find the minimum bbox with the aspect ratio. 
Can be outside of imshape - bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio) - # Expands square box such that X% of the bbox is background - bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background) - assert isinstance(bbox_XYXY[0], (int, np.int64)) - return bbox_XYXY - - -def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape): - def area_inside_ratio(bbox, imshape): - area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - area_inside = (min(bbox[2], imshape[1]) - max(0,bbox[0])) * (min(imshape[0],bbox[3]) - max(0,bbox[1])) - return area_inside / area - ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0]) - area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) - if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside: - return False - if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area: - return False - return True diff --git a/dp2/detection/box_utils_fdf.py b/dp2/detection/box_utils_fdf.py deleted file mode 100644 index 48e4e8c6ef067eb495ff8a021d2d236606e2e906..0000000000000000000000000000000000000000 --- a/dp2/detection/box_utils_fdf.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -The FDF dataset expands bound boxes differently from what is used for CSE. -""" - -import numpy as np - - -def quadratic_bounding_box(x0, y0, width, height, imshape): - # We assume that we can create a image that is quadratic without - # minimizing any of the sides - assert width <= min(imshape[:2]) - assert height <= min(imshape[:2]) - min_side = min(height, width) - if height != width: - side_diff = abs(height - width) - # Want to extend the shortest side - if min_side == height: - # Vertical side - height += side_diff - if height > imshape[0]: - # Take full frame, and shrink width - y0 = 0 - height = imshape[0] - - side_diff = abs(height - width) - width -= side_diff - x0 += side_diff // 2 - else: - y0 -= side_diff // 2 - y0 = max(0, y0) - else: - # Horizontal side - width += side_diff - if width > imshape[1]: - # Take full frame width, and shrink height - x0 = 0 - width = imshape[1] - - side_diff = abs(height - width) - height -= side_diff - y0 += side_diff // 2 - else: - x0 -= side_diff // 2 - x0 = max(0, x0) - # Check that bbox goes outside image - x1 = x0 + width - y1 = y0 + height - if imshape[1] < x1: - diff = x1 - imshape[1] - x0 -= diff - if imshape[0] < y1: - diff = y1 - imshape[0] - y0 -= diff - assert x0 >= 0, "Bounding box outside image." - assert y0 >= 0, "Bounding box outside image." - assert x0 + width <= imshape[1], "Bounding box outside image." - assert y0 + height <= imshape[0], "Bounding box outside image." 
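include_box above decides whether a detection is worth anonymizing: it must be mostly inside the frame, have a plausible aspect ratio, and exceed a minimum area. A compact restatement of the same test:

def include_box(bbox, minimum_area, aspect_ratio_range, min_ratio_inside, imshape) -> bool:
    x0, y0, x1, y1 = bbox
    area = (x1 - x0) * (y1 - y0)
    area_inside = (min(x1, imshape[1]) - max(0, x0)) * (min(y1, imshape[0]) - max(0, y0))
    ratio = (y1 - y0) / (x1 - x0)
    if area_inside / area < min_ratio_inside:
        return False
    return aspect_ratio_range[0] < ratio < aspect_ratio_range[1] and area >= minimum_area
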
- return x0, y0, width, height - - -def expand_bounding_box(bbox, percentage, imshape): - orig_bbox = bbox.copy() - x0, y0, x1, y1 = bbox - width = x1 - x0 - height = y1 - y0 - x0, y0, width, height = quadratic_bounding_box( - x0, y0, width, height, imshape) - expanding_factor = int(max(height, width) * percentage) - - possible_max_expansion = [(imshape[0] - width) // 2, - (imshape[1] - height) // 2, - expanding_factor] - - expanding_factor = min(possible_max_expansion) - # Expand height - - if expanding_factor > 0: - - y0 = y0 - expanding_factor - y0 = max(0, y0) - - height += expanding_factor * 2 - if height > imshape[0]: - y0 -= (imshape[0] - height) - height = imshape[0] - - if height + y0 > imshape[0]: - y0 -= (height + y0 - imshape[0]) - - # Expand width - x0 = x0 - expanding_factor - x0 = max(0, x0) - - width += expanding_factor * 2 - if width > imshape[1]: - x0 -= (imshape[1] - width) - width = imshape[1] - - if width + x0 > imshape[1]: - x0 -= (width + x0 - imshape[1]) - y1 = y0 + height - x1 = x0 + width - assert y0 >= 0, "Y0 is minus" - assert height <= imshape[0], "Height is larger than image." - assert x0 + width <= imshape[1] - assert y0 + height <= imshape[0] - assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!" - assert x0 >= 0, "Y0 is minus" - assert width <= imshape[1], "Height is larger than image." - # Check that original bbox is within new - x0_o, y0_o, x1_o, y1_o = orig_bbox - assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}" - assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}" - assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}" - assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}" - - x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]] - x1 = x0 + width - y1 = y0 + height - return np.array([x0, y0, x1, y1]) - - -def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint): - keypoint = keypoint[:, :3] # only nose + eyes are relevant - kp_X = keypoint[0, :] - kp_Y = keypoint[1, :] - within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1) - within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1) - return within_X and within_Y - - -def expand_bbox_simple(bbox, percentage): - x0, y0, x1, y1 = bbox.astype(float) - width = x1 - x0 - height = y1 - y0 - x_c = int(x0) + width // 2 - y_c = int(y0) + height // 2 - avg_size = max(width, height) - new_width = avg_size * (1 + percentage) - x0 = x_c - new_width // 2 - y0 = y_c - new_width // 2 - x1 = x_c + new_width // 2 - y1 = y_c + new_width // 2 - return np.array([x0, y0, x1, y1]).astype(int) - - -def pad_image(im, bbox, pad_value): - x0, y0, x1, y1 = bbox - if x0 < 0: - pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]), - dtype=np.uint8) + pad_value - im = np.concatenate((pad_im, im), axis=1) - x1 += abs(x0) - x0 = 0 - if y0 < 0: - pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]), - dtype=np.uint8) + pad_value - im = np.concatenate((pad_im, im), axis=0) - y1 += abs(y0) - y0 = 0 - if x1 >= im.shape[1]: - pad_im = np.zeros( - (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]), - dtype=np.uint8) + pad_value - im = np.concatenate((im, pad_im), axis=1) - if y1 >= im.shape[0]: - pad_im = np.zeros( - (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]), - dtype=np.uint8) + pad_value - im = np.concatenate((im, pad_im), axis=0) - return im[y0:y1, x0:x1] - - -def clip_box(bbox, im): - bbox[0] = max(0, bbox[0]) - bbox[1] = max(0, bbox[1]) - bbox[2] = min(im.shape[1] - 1, bbox[2]) - bbox[3] = min(im.shape[0] - 1, bbox[3]) - return bbox - - -def 
cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True): - outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0] - if simple_expand or (outside_im and pad_im): - return pad_image(im, bbox, pad_value) - bbox = clip_box(bbox, im) - x0, y0, x1, y1 = bbox - return im[y0:y1, x0:x1] - - -def expand_bbox( - bbox_ltrb, imshape, simple_expand, default_to_simple=False, - expansion_factor=0.35): - assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox.shape}" - bbox = bbox_ltrb.astype(float) - # FDF256 uses simple expand with ratio 0.4 - if simple_expand: - return expand_bbox_simple(bbox, 0.4) - try: - return expand_bounding_box(bbox, expansion_factor, imshape) - except AssertionError: - return expand_bbox_simple(bbox, expansion_factor * 2) - diff --git a/dp2/detection/cse_mask_face_detector.py b/dp2/detection/cse_mask_face_detector.py deleted file mode 100644 index 74a8cf43eb35516e5e2c2c3354e15fc50ea88016..0000000000000000000000000000000000000000 --- a/dp2/detection/cse_mask_face_detector.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -import lzma -import tops -from pathlib import Path -from dp2.detection.base import BaseDetector -from .utils import combine_cse_maskrcnn_dets -from face_detection import build_detector as build_face_detector -from .models.cse import CSEDetector -from .models.mask_rcnn import MaskRCNNDetector -from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection -from tops import logger - - -def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): - assert len(box1.shape) == 2 - assert len(box2.shape) == 2 - box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) - # This can be batched - for i, box in enumerate(box1): - is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) - is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) - is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) - box1_inside[i] = is_outside.logical_not().any() - return box1_inside - - -class CSeMaskFaceDetector(BaseDetector): - - def __init__( - self, - mask_rcnn_cfg, - face_detector_cfg: dict, - cse_cfg: dict, - face_post_process_cfg: dict, - cse_post_process_cfg, - score_threshold: float, - **kwargs - ) -> None: - super().__init__(**kwargs) - self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) - if "confidence_threshold" not in face_detector_cfg: - face_detector_cfg["confidence_threshold"] = score_threshold - if "score_thres" not in cse_cfg: - cse_cfg["score_thres"] = score_threshold - self.cse_detector = CSEDetector(**cse_cfg) - self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True) - self.cse_post_process_cfg = cse_post_process_cfg - self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) - self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold") - self.face_post_process_cfg = face_post_process_cfg - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def _detect_faces(self, im: torch.Tensor): - H, W = im.shape[1:] - im = im.float() - self.face_mean - im = self.face_detector.resize(im[None], 1.0) - boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score - boxes_XYXY[:, [0, 2]] *= W - boxes_XYXY[:, [1, 3]] *= H - return boxes_XYXY.round().long() - - def load_from_cache(self, cache_path: Path): - logger.log(f"Loading detection from cache path: {cache_path}",) - with lzma.open(cache_path, "rb") 
as fp: - state_dict = torch.load(fp, map_location="cpu") - kwargs = dict( - post_process_cfg=self.cse_post_process_cfg, - embed_map=self.cse_detector.embed_map, - **self.face_post_process_cfg - ) - return [ - state["cls"].from_state_dict(**kwargs, state_dict=state) - for state in state_dict - ] - - @torch.no_grad() - def forward(self, im: torch.Tensor): - maskrcnn_dets = self.mask_rcnn(im) - cse_dets = self.cse_detector(im) - embed_map = self.cse_detector.embed_map - print("Calling face detector.") - face_boxes = self._detect_faces(im).cpu() - maskrcnn_person = { - k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items() - } - maskrcnn_other = { - k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items() - } - maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"]) - combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets( - maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold) - - persons_with_cse = CSEPersonDetection( - combined_segmentation, cse_dets, **self.cse_post_process_cfg, - embed_map=embed_map,orig_imshape_CHW=im.shape - ) - persons_with_cse.pre_process() - not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]] - persons_without_cse = PersonDetection( - maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg, - orig_imshape_CHW=im.shape - ) - persons_without_cse.pre_process() - - face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or( - box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes) - ) - face_boxes = face_boxes[face_boxes_covered.logical_not()] - face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg) - - # Order matters. The anonymizer will anonymize FIFO. - # Later detections will overwrite. 
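Face detections already covered by a person detection are dropped above, using a strict containment test between face boxes and the (dilated) person boxes. The test, essentially as removed:

import torch

def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
    # Returns one bool per box in box1: True when it lies strictly inside some box in box2.
    inside = torch.zeros(box1.shape[0], dtype=torch.bool)
    for i, box in enumerate(box1):
        outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
        outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
        inside[i] = outside_lefttop.logical_or(outside_rightbot).logical_not().any()
    return inside
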
- all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse] - return all_detections diff --git a/dp2/detection/face_detector.py b/dp2/detection/face_detector.py deleted file mode 100644 index b05565bc3bc095edf1760c24c4238fe20b9962dc..0000000000000000000000000000000000000000 --- a/dp2/detection/face_detector.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -import lzma -import tops -from pathlib import Path -from dp2.detection.base import BaseDetector -from face_detection import build_detector as build_face_detector -from .structures import FaceDetection -from tops import logger - - -def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): - assert len(box1.shape) == 2 - assert len(box2.shape) == 2 - box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) - # This can be batched - for i, box in enumerate(box1): - is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) - is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) - is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) - box1_inside[i] = is_outside.logical_not().any() - return box1_inside - - -class FaceDetector(BaseDetector): - - def __init__( - self, - face_detector_cfg: dict, - score_threshold: float, - face_post_process_cfg: dict, - **kwargs - ) -> None: - super().__init__(**kwargs) - self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold) - self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) - self.face_post_process_cfg = face_post_process_cfg - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def _detect_faces(self, im: torch.Tensor): - H, W = im.shape[1:] - im = im.float() - self.face_mean - im = self.face_detector.resize(im[None], 1.0) - boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score - boxes_XYXY[:, [0, 2]] *= W - boxes_XYXY[:, [1, 3]] *= H - return boxes_XYXY.round().long().cpu() - - @torch.no_grad() - def forward(self, im: torch.Tensor): - face_boxes = self._detect_faces(im) - face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg) - return [face_boxes] - - def load_from_cache(self, cache_path: Path): - logger.log(f"Loading detection from cache path: {cache_path}") - with lzma.open(cache_path, "rb") as fp: - state_dict = torch.load(fp) - return [ - state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict - ] diff --git a/dp2/detection/models/__init__.py b/dp2/detection/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dp2/detection/models/cse.py b/dp2/detection/models/cse.py deleted file mode 100644 index fe6b0f6c876ed86d542604ea0f4d7274afe31584..0000000000000000000000000000000000000000 --- a/dp2/detection/models/cse.py +++ /dev/null @@ -1,135 +0,0 @@ -import torch -from typing import List -import tops -from torchvision.transforms.functional import InterpolationMode, resize -from densepose.data.utils import get_class_to_mesh_name_mapping -from densepose import add_densepose_config -from densepose.structures import DensePoseEmbeddingPredictorOutput -from densepose.vis.extractor import DensePoseOutputsExtractor -from densepose.modeling import build_densepose_embedder -from detectron2.config import get_cfg -from detectron2.data.transforms import ResizeShortestEdge -from detectron2.checkpoint.detection_checkpoint import 
DetectionCheckpointer -from detectron2.modeling import build_model - - -model_urls = { - "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl", - "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl", -} - - -def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape): - assert len(S.shape) == 3 - H, W = imshape - N = len(boxes_XYXY) - segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device) - boxes_XYXY = boxes_XYXY.long() - for i in range(N): - x0, y0, x1, y1 = boxes_XYXY[i] - assert x0 >= 0 and y0 >= 0 - assert x1 <= imshape[1] - assert y1 <= imshape[0] - h = y1 - y0 - w = x1 - x0 - segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0 - return segmentation - - -class CSEDetector: - - def __init__( - self, - cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml", - cfg_2_download: List[str] = [ - "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml", - "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml", - "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"], - score_thres: float = 0.9, - nms_thresh: float = None, - ) -> None: - with tops.logger.capture_log_stdout(): - cfg = get_cfg() - self.device = tops.get_device() - add_densepose_config(cfg) - cfg_path = tops.download_file(cfg_url) - for p in cfg_2_download: - tops.download_file(p) - with tops.logger.capture_log_stdout(): - cfg.merge_from_file(cfg_path) - assert cfg_url in model_urls, cfg_url - model_path = tops.download_file(model_urls[cfg_url]) - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres - if nms_thresh is not None: - cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh - cfg.MODEL.WEIGHTS = str(model_path) - cfg.MODEL.DEVICE = str(self.device) - cfg.freeze() - with tops.logger.capture_log_stdout(): - self.model = build_model(cfg) - self.model.eval() - DetectionCheckpointer(self.model).load(str(model_path)) - self.input_format = cfg.INPUT.FORMAT - self.densepose_extractor = DensePoseOutputsExtractor() - self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg) - - self.embedder = build_densepose_embedder(cfg) - self.mesh_vertex_embeddings = { - mesh_name: self.embedder(mesh_name).to(self.device) - for mesh_name in self.class_to_mesh_name.values() - if self.embedder.has_embeddings(mesh_name) - } - self.cfg = cfg - self.embed_map = self.mesh_vertex_embeddings["smpl_27554"] - tops.logger.log("CSEDetector built.") - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def resize_im(self, im): - H, W = im.shape[1:] - newH, newW = ResizeShortestEdge.get_output_shape( - H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST) - return resize( - im, (newH, newW), InterpolationMode.BILINEAR, antialias=True) - - @torch.no_grad() - def forward(self, im): - assert im.dtype == torch.uint8 - if 
self.input_format == "BGR": - im = im.flip(0) - H, W = im.shape[1:] - im = self.resize_im(im) - output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] - scores = output.get("scores") - if len(scores) == 0: - return dict( - instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device), - instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device), - embed_map=self.mesh_vertex_embeddings["smpl_27554"], - bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device), - im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device), - scores=torch.empty((0), dtype=torch.float, device=im.device) - ) - pred_densepose, boxes_xywh, classes = self.densepose_extractor(output) - assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose - S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes) - E = pred_densepose.embedding - mesh_name = self.class_to_mesh_name[classes[0]] - assert mesh_name == "smpl_27554" - x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)] - boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1) - boxes_XYXY = boxes_XYXY.round_().long() - - non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not() - S = S[non_empty_boxes] - E = E[non_empty_boxes] - boxes_XYXY = boxes_XYXY[non_empty_boxes] - scores = scores[non_empty_boxes] - im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W]) - return dict( - instance_segmentation=S, instance_embedding=E, - bbox_XYXY=boxes_XYXY, - im_segmentation=im_segmentation, - scores=scores.view(-1)) - diff --git a/dp2/detection/models/keypoint_maskrcnn.py b/dp2/detection/models/keypoint_maskrcnn.py deleted file mode 100644 index 4fc3fd9e19aa8a023ad8135f6e6997135049c3db..0000000000000000000000000000000000000000 --- a/dp2/detection/models/keypoint_maskrcnn.py +++ /dev/null @@ -1,111 +0,0 @@ -import numpy as np -import torch -from detectron2.checkpoint import DetectionCheckpointer -from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads -from detectron2.data.transforms import ResizeShortestEdge -from detectron2.structures import Instances -from detectron2 import model_zoo -from detectron2.config import instantiate -from detectron2.config import LazyCall as L -from PIL import Image -import tops -import functools -from torchvision.transforms.functional import resize - - -def get_rn50_fpn_keypoint_rcnn(weight_path: str): - from detectron2.modeling.poolers import ROIPooler - from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead - from detectron2.layers import ShapeSpec - model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model - model.roi_heads.update( - num_classes=1, - keypoint_in_features=["p2", "p3", "p4", "p5"], - keypoint_pooler=L(ROIPooler)( - output_size=14, - scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), - sampling_ratio=0, - pooler_type="ROIAlignV2", - ), - keypoint_head=L(KRCNNConvDeconvUpsampleHead)( - input_shape=ShapeSpec(channels=256, width=14, height=14), - num_keypoints=17, - conv_dims=[512] * 8, - loss_normalizer="visible", - ), - ) - - # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. - # 1000 proposals per-image is found to hurt box AP. - # Therefore we increase it to 1500 per-image. 
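The CSE head returns boxes in XYWH; in the forward pass above they are converted to integer XYXY coordinates and boxes that collapse to zero width or height are discarded before cropping. A small sketch, assuming float box inputs as returned by the detector:

import torch

def xywh_to_xyxy_nonempty(boxes_xywh: torch.Tensor):
    x0, y0, w, h = boxes_xywh.unbind(dim=-1)
    boxes = torch.stack((x0, y0, x0 + w, y0 + h), dim=-1).round().long()
    keep = (boxes[:, :2] == boxes[:, 2:]).any(dim=1).logical_not()
    return boxes[keep], keep
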
- model.proposal_generator.post_nms_topk = (1500, 1000) - - # Keypoint AP degrades (though box AP improves) when using plain L1 loss - model.roi_heads.box_predictor.smooth_l1_beta = 0.5 - model = instantiate(model) - - dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader - test_transform = instantiate(dataloader.test.mapper.augmentations) - DetectionCheckpointer(model).load(weight_path) - return model, test_transform - - -models = { - "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth") -} - - - - -class KeypointMaskRCNN: - - def __init__(self, model_name: str, score_threshold: float) -> None: - assert model_name in models, f"Did not find {model_name} in models" - model, test_transform = models[model_name]() - self.model = model.eval().to(tops.get_device()) - if isinstance(self.model.roi_heads, CascadeROIHeads): - for head in self.model.roi_heads.box_predictors: - assert hasattr(head, "test_score_thresh") - head.test_score_thresh = score_threshold - else: - assert isinstance(self.model.roi_heads, StandardROIHeads) - assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh") - self.model.roi_heads.box_predictor.test_score_thresh = score_threshold - - self.test_transform = test_transform - assert len(self.test_transform) == 1 - self.test_transform = self.test_transform[0] - assert isinstance(self.test_transform, ResizeShortestEdge) - assert self.test_transform.interp == Image.BILINEAR - self.image_format = self.model.input_format - - def resize_im(self, im): - H, W = im.shape[-2:] - if self.test_transform.is_range: - size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1) - else: - size = np.random.choice(self.test_transform.short_edge_length) - newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size) - return resize( - im, (newH, newW), antialias=True) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - @torch.no_grad() - def forward(self, im: torch.Tensor) -> Instances: - assert im.ndim == 3 - if self.image_format == "BGR": - im = im.flip(0) - H, W = im.shape[-2:] - im = self.resize_im(im) - im = im.float() - inputs = dict(image=im, height=H, width=W) - # instances contains - # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps']) - instances = self.model([inputs])[0]["instances"] - return dict( - scores=instances.get("scores").cpu(), - segmentation=instances.get("pred_masks").cpu(), - keypoints=instances.get("pred_keypoints").cpu() - ) diff --git a/dp2/detection/models/mask_rcnn.py b/dp2/detection/models/mask_rcnn.py deleted file mode 100644 index 1f87d151709c8adede45aa00de0c4bce0287114e..0000000000000000000000000000000000000000 --- a/dp2/detection/models/mask_rcnn.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -import tops -from detectron2.modeling import build_model -from detectron2.checkpoint import DetectionCheckpointer -from detectron2.structures import Boxes -from detectron2.data import MetadataCatalog -from detectron2 import model_zoo -from typing import Dict -from detectron2.data.transforms import ResizeShortestEdge -from torchvision.transforms.functional import resize - - - -model_urls = { - "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": 
"https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl", - -} -class MaskRCNNDetector: - - def __init__( - self, - cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml", - score_thres: float = 0.9, - class_filter=["person"], #["car", "bicycle","truck", "bus", "backpack"] - fp16_inference: bool = False - ) -> None: - cfg = model_zoo.get_config(cfg_name) - cfg.MODEL.DEVICE = str(tops.get_device()) - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres - cfg.freeze() - self.cfg = cfg - with tops.logger.capture_log_stdout(): - self.model = build_model(cfg) - DetectionCheckpointer(self.model).load(model_urls[cfg_name]) - self.model.eval() - self.input_format = cfg.INPUT.FORMAT - self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes - self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter]) - self.person_class = self.class_names.index("person") - self.fp16_inference = fp16_inference - tops.logger.log("Mask R-CNN built.") - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def resize_im(self, im): - H, W = im.shape[1:] - newH, newW = ResizeShortestEdge.get_output_shape( - H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST) - return resize( - im, (newH, newW), antialias=True) - - @torch.no_grad() - def forward(self, im: torch.Tensor): - if self.input_format == "BGR": - im = im.flip(0) - else: - assert self.input_format == "RGB" - H, W = im.shape[-2:] - im = self.resize_im(im) - with torch.cuda.amp.autocast(enabled=self.fp16_inference): - output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] - scores = output.get("scores") - N = len(scores) - classes = output.get("pred_classes") - idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep] - classes = classes[idx2keep] - assert isinstance(output.get("pred_boxes"), Boxes) - segmentation = output.get("pred_masks")[idx2keep] - assert segmentation.dtype == torch.bool - is_person = classes == self.person_class - return { - "scores": output.get("scores")[idx2keep], - "segmentation": segmentation, - "classes": output.get("pred_classes")[idx2keep], - "is_person": is_person - } - diff --git a/dp2/detection/person_detector.py b/dp2/detection/person_detector.py deleted file mode 100644 index 1bbd0df8c2aa44839a5de8bd9a6aeede054ff2ee..0000000000000000000000000000000000000000 --- a/dp2/detection/person_detector.py +++ /dev/null @@ -1,135 +0,0 @@ -import torch -import lzma -from dp2.detection.base import BaseDetector -from .utils import combine_cse_maskrcnn_dets -from .models.cse import CSEDetector -from .models.mask_rcnn import MaskRCNNDetector -from .models.keypoint_maskrcnn import KeypointMaskRCNN -from .structures import CSEPersonDetection, PersonDetection -from pathlib import Path - - -class CSEPersonDetector(BaseDetector): - def __init__( - self, - score_threshold: float, - mask_rcnn_cfg: dict, - cse_cfg: dict, - cse_post_process_cfg: dict, - **kwargs - ) -> None: - super().__init__(**kwargs) - self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) - self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold) - self.post_process_cfg = cse_post_process_cfg - self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold") - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def load_from_cache(self, cache_path: Path): - with lzma.open(cache_path, 
"rb") as fp: - state_dict = torch.load(fp) - kwargs = dict( - post_process_cfg=self.post_process_cfg, - embed_map=self.cse_detector.embed_map, - ) - return [ - state["cls"].from_state_dict(**kwargs, state_dict=state) - for state in state_dict - ] - - @torch.no_grad() - def forward(self, im: torch.Tensor, cse_dets=None): - mask_dets = self.mask_rcnn(im) - if cse_dets is None: - cse_dets = self.cse_detector(im) - segmentation = mask_dets["segmentation"] - segmentation, cse_dets, _ = combine_cse_maskrcnn_dets( - segmentation, cse_dets, self.iou_combine_threshold - ) - det = CSEPersonDetection( - segmentation=segmentation, - cse_dets=cse_dets, - embed_map=self.cse_detector.embed_map, - orig_imshape_CHW=im.shape, - **self.post_process_cfg - ) - return [det] - - -class MaskRCNNPersonDetector(BaseDetector): - def __init__( - self, - score_threshold: float, - mask_rcnn_cfg: dict, - cse_post_process_cfg: dict, - **kwargs - ) -> None: - super().__init__(**kwargs) - self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) - self.post_process_cfg = cse_post_process_cfg - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def load_from_cache(self, cache_path: Path): - with lzma.open(cache_path, "rb") as fp: - state_dict = torch.load(fp) - kwargs = dict( - post_process_cfg=self.post_process_cfg, - ) - return [ - state["cls"].from_state_dict(**kwargs, state_dict=state) - for state in state_dict - ] - - @torch.no_grad() - def forward(self, im: torch.Tensor): - mask_dets = self.mask_rcnn(im) - segmentation = mask_dets["segmentation"] - det = PersonDetection( - segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape - ) - return [det] - - -class KeypointMaskRCNNPersonDetector(BaseDetector): - def __init__( - self, - score_threshold: float, - mask_rcnn_cfg: dict, - cse_post_process_cfg: dict, - **kwargs - ) -> None: - super().__init__(**kwargs) - self.mask_rcnn = KeypointMaskRCNN( - **mask_rcnn_cfg, score_threshold=score_threshold - ) - self.post_process_cfg = cse_post_process_cfg - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def load_from_cache(self, cache_path: Path): - with lzma.open(cache_path, "rb") as fp: - state_dict = torch.load(fp) - kwargs = dict( - post_process_cfg=self.post_process_cfg, - ) - return [ - state["cls"].from_state_dict(**kwargs, state_dict=state) - for state in state_dict - ] - - @torch.no_grad() - def forward(self, im: torch.Tensor): - mask_dets = self.mask_rcnn(im) - segmentation = mask_dets["segmentation"] - det = PersonDetection( - segmentation, - **self.post_process_cfg, - orig_imshape_CHW=im.shape, - keypoints=mask_dets["keypoints"] - ) - return [det] diff --git a/dp2/detection/structures.py b/dp2/detection/structures.py deleted file mode 100644 index 3c78de781612b118b2d12e318f78100301439a95..0000000000000000000000000000000000000000 --- a/dp2/detection/structures.py +++ /dev/null @@ -1,464 +0,0 @@ -import torch -import numpy as np -from dp2 import utils -from dp2.utils import vis_utils, crop_box -from .utils import ( - cut_pad_resize, masks_to_boxes, - get_kernel, transform_embedding, initialize_cse_boxes - ) -from .box_utils import get_expanded_bbox, include_box -import torchvision -import tops -from .box_utils_fdf import expand_bbox as expand_bbox_fdf - - -class VehicleDetection: - - def __init__(self, segmentation: torch.BoolTensor) -> None: - self.segmentation = segmentation - self.boxes = masks_to_boxes(segmentation) - assert self.boxes.shape[1] == 4, self.boxes.shape - 
self.n_detections = self.segmentation.shape[0] - area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0]) - - sorted_idx = torch.argsort(area, descending=True) - self.segmentation = self.segmentation[sorted_idx] - self.boxes = self.boxes[sorted_idx].cpu() - - def pre_process(self): - pass - - def get_crop(self, idx: int, im): - assert idx < len(self) - box = self.boxes[idx] - im = crop_box(self.im, box) - mask = crop_box(self.segmentation[idx]) - mask = mask == 0 - return dict(img=im, mask=mask.float(), boxes=box) - - def visualize(self, im): - if len(self) == 0: - return im - im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not()) - return im - - def __len__(self): - return self.n_detections - - @staticmethod - def from_state_dict(state_dict, **kwargs): - numel = np.prod(state_dict["shape"]) - arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel) - segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"]) - return VehicleDetection(segmentation) - - def state_dict(self, **kwargs): - segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy())) - return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape) - - -class FaceDetection: - - def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None: - self.boxes = boxes_ltrb.cpu() - assert self.boxes.shape[1] == 4, self.boxes.shape - self.target_imsize = tuple(target_imsize) - # Sory by area to paste in largest faces last - area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1) - idx = area.argsort(descending=False) - self.boxes = self.boxes[idx] - self.fdf128_expand = fdf128_expand - - def visualize(self, im): - if len(self) == 0: - return im - orig_device = im.device - for box in self.boxes: - simple_expand = False if self.fdf128_expand else True - e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand)) - im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2) - im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2) - - return im.to(device=orig_device) - - def get_crop(self, idx: int, im): - assert idx < len(self) - box = self.boxes[idx].numpy() - simple_expand = False if self.fdf128_expand else True - expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand=simple_expand) - im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True) - area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1) - - # Find the square mask corresponding to box. 
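# Worked example (hypothetical numbers): for box = (30, 40, 70, 100) inside an
# expanded box (20, 30, 120, 130) with target_imsize[0] = 128, the shift below
# gives (10, 10, 50, 70), the resize factor is 128 / 100 = 1.28, and box_mask
# becomes (12, 12, 64, 89) in the resized crop -- the region that is zeroed out
# of the all-ones mask.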
- box_mask = box.copy().astype(float) - box_mask[[0, 2]] -= expanded_boxes[0] - box_mask[[1, 3]] -= expanded_boxes[1] - - width = expanded_boxes[2] - expanded_boxes[0] - resize_factor = self.target_imsize[0] / width - box_mask = (box_mask * resize_factor).astype(int) - mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32) - crop_box(mask, box_mask).fill_(0) - return dict( - img=im[None], mask=mask[None], - boxes=torch.from_numpy(expanded_boxes).view(1, -1)) - - def __len__(self): - return len(self.boxes) - - @staticmethod - def from_state_dict(state_dict, **kwargs): - return FaceDetection(state_dict["boxes"].cpu(), **kwargs) - - def state_dict(self, **kwargs): - return dict(boxes=self.boxes, cls=self.__class__) - - def pre_process(self): - pass - - -def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape): - """ - Dilation happens after padding, which could place dilation in the padded area. - Remove this. - """ - x0, y0, x1, y1 = exp_box - H, W = orig_imshape - # Padding in original image space - p_y0 = max(0, -y0) - p_y1 = max(y1 - H, 0) - p_x0 = max(0, -x0) - p_x1 = max(x1 - W, 0) - resize_ratio = mask.shape[-2] / (y1-y0) - p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]] - mask[..., :p_y0, :] = 0 - mask[..., :p_x0] = 0 - mask[..., mask.shape[-2] - p_y1:, :] = 0 - mask[..., mask.shape[-1] - p_x1:] = 0 - - -class CSEPersonDetection: - - def __init__(self, - segmentation, cse_dets, - target_imsize, - exp_bbox_cfg, exp_bbox_filter, - dilation_percentage: float, - embed_map: torch.Tensor, - orig_imshape_CHW, - normalize_embedding: bool) -> None: - self.segmentation = segmentation - self.cse_dets = cse_dets - self.target_imsize = list(target_imsize) - self.pre_processed = False - self.exp_bbox_cfg = exp_bbox_cfg - self.exp_bbox_filter = exp_bbox_filter - self.dilation_percentage = dilation_percentage - self.embed_map = embed_map - self.normalize_embedding = normalize_embedding - if self.normalize_embedding: - embed_map_mean = self.embed_map.mean(dim=0, keepdim=True) - embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() - self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd - self.orig_imshape_CHW = orig_imshape_CHW - - @torch.no_grad() - def pre_process(self): - if self.pre_processed: - return - boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu() - expanded_boxes = [] - included_boxes = [] - for i in range(len(boxes)): - exp_box = get_expanded_bbox( - boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, - target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) - if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): - continue - included_boxes.append(i) - expanded_boxes.append(exp_box) - expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) - self.segmentation = self.segmentation[included_boxes] - self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()} - - self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) - area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)) - for i, box in enumerate(expanded_boxes): - self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] - - dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) - self.maskrcnn_mask = 
self.mask.clone().logical_not()[:, None] - self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) - [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))] - self.boxes = expanded_boxes.cpu() - self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) - - self.pre_processed = True - self.n_detections = len(self.boxes) - self.mask = self.mask.logical_not() - - E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool) - self.vertices = torch.zeros_like(E_mask, dtype=torch.long) - for i in range(self.n_detections): - E_, E_mask[i] = transform_embedding( - self.cse_dets["instance_embedding"][i], - self.cse_dets["instance_segmentation"][i], - self.boxes[i], - self.cse_dets["bbox_XYXY"][i].cpu(), - self.target_imsize - ) - self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None] - self.E_mask = E_mask - - sorted_idx = torch.argsort(area, descending=False) - self.mask = self.mask[sorted_idx] - self.boxes = self.boxes[sorted_idx.cpu()] - self.vertices = self.vertices[sorted_idx] - self.E_mask = self.E_mask[sorted_idx] - self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] - - def get_crop(self, idx: int, im): - self.pre_process() - assert idx < len(self) - box = self.boxes[idx] - mask = self.mask[idx] - im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) - - vertices_ = self.vertices[idx] - E_mask_ = self.E_mask[idx].float() - if self.normalize_embedding: - embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ - else: - embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ - - return dict( - img=im, - mask=mask.float()[None], - boxes=box.reshape(1, -1), - E_mask=E_mask_[None], - vertices=vertices_[None], - embed_map=self.embed_map, - embedding=embedding[None], - maskrcnn_mask=self.maskrcnn_mask[idx].float()[None] - ) - - def __len__(self): - self.pre_process() - return self.n_detections - - def state_dict(self, after_preprocess=False): - """ - The processed annotations occupy more space than the original detections. 
- """ - if not after_preprocess: - return { - "combined_segmentation": self.segmentation.bool(), - "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(), - "cse_instance_embedding": self.cse_dets["instance_embedding"], - "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(), - "cls": self.__class__, - "orig_imshape_CHW": self.orig_imshape_CHW - } - self.pre_process() - return dict( - E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())), - mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())), - maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())), - vertices=self.vertices.to(torch.int16).cpu(), - cls=self.__class__, - boxes=self.boxes, - orig_imshape_CHW=self.orig_imshape_CHW, - ) - - @staticmethod - def from_state_dict( - state_dict, embed_map, - post_process_cfg, **kwargs): - after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict - if after_preprocess: - detection = CSEPersonDetection( - segmentation=None, cse_dets=None, embed_map=embed_map, - orig_imshape_CHW=state_dict["orig_imshape_CHW"], - **post_process_cfg) - detection.vertices = tops.to_cuda(state_dict["vertices"].long()) - numel = np.prod(detection.vertices.shape) - detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape) - detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape) - detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape) - detection.n_detections = len(detection.mask) - detection.pre_processed = True - - if isinstance(state_dict["boxes"], np.ndarray): - state_dict["boxes"] = torch.from_numpy(state_dict["boxes"]) - detection.boxes = state_dict["boxes"] - return detection - - cse_dets = dict( - instance_segmentation=state_dict["cse_instance_segmentation"], - instance_embedding=state_dict["cse_instance_embedding"], - embed_map=embed_map, - bbox_XYXY=state_dict["cse_bbox_XYXY"]) - cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()} - - segmentation = state_dict["combined_segmentation"] - return CSEPersonDetection( - segmentation, cse_dets, embed_map=embed_map, - orig_imshape_CHW=state_dict["orig_imshape_CHW"], - **post_process_cfg) - - def visualize(self, im): - self.pre_process() - if len(self) == 0: - return im - im = vis_utils.draw_cropped_masks( - im.clone(), self.mask, self.boxes, visualize_instances=False) - E = self.embed_map[self.vertices.long()].squeeze(1).permute(0,3, 1, 2) - im = im.to(E.device) - im = vis_utils.draw_cse_all( - E, self.E_mask.squeeze(1).bool(), im, - self.boxes, self.embed_map) - im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2) - return im - - -def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes): - keypoints = keypoints.clone() - N = boxes.shape[0] - tops.assert_shape(keypoints, (N, None, 3)) - tops.assert_shape(boxes, (N, 4)) - x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T] - - w = x1 - x0 - h = y1 - y0 - keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w - keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h - check_outside = lambda x: (x < 0).logical_or(x > 1) - is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1])) - keypoints[:, :, 2] = keypoints[:, :, 2] >= 0 - keypoints[:, :, 2] = (keypoints[:, :, 2] > 
0).logical_and(is_outside.logical_not()) - return keypoints - - -class PersonDetection: - - def __init__( - self, - segmentation, - target_imsize, - exp_bbox_cfg, exp_bbox_filter, - dilation_percentage: float, - orig_imshape_CHW, - keypoints=None, - **kwargs) -> None: - self.segmentation = segmentation - self.target_imsize = list(target_imsize) - self.pre_processed = False - self.exp_bbox_cfg = exp_bbox_cfg - self.exp_bbox_filter = exp_bbox_filter - self.dilation_percentage = dilation_percentage - self.orig_imshape_CHW = orig_imshape_CHW - self.keypoints = keypoints - - @torch.no_grad() - def pre_process(self): - if self.pre_processed: - return - boxes = masks_to_boxes(self.segmentation).cpu() - expanded_boxes = [] - included_boxes = [] - for i in range(len(boxes)): - exp_box = get_expanded_bbox( - boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, - target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) - if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): - continue - included_boxes.append(i) - expanded_boxes.append(exp_box) - expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) - self.segmentation = self.segmentation[included_boxes] - if self.keypoints is not None: - self.keypoints = self.keypoints[included_boxes] - area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)) - self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) - for i, box in enumerate(expanded_boxes): - self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] - if self.keypoints is not None: - self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes) - dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) - self.maskrcnn_mask = self.mask.clone().logical_not()[:, None] - self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) - - [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))] - self.boxes = expanded_boxes - self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) - - self.pre_processed = True - self.n_detections = len(self.boxes) - self.mask = self.mask.logical_not() - - sorted_idx = torch.argsort(area, descending=False) - self.mask = self.mask[sorted_idx] - self.boxes = self.boxes[sorted_idx.cpu()] - self.segmentation = self.segmentation[sorted_idx] - self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] - if self.keypoints is not None: - self.keypoints = self.keypoints[sorted_idx] - - def get_crop(self, idx: int, im: torch.Tensor): - assert idx < len(self) - self.pre_process() - box = self.boxes[idx] - mask = self.mask[idx][None].float() - im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) - batch = dict( - img=im, mask=mask, boxes=box.reshape(1, -1), - maskrcnn_mask=self.maskrcnn_mask[idx][None].float()) - if self.keypoints is not None: - batch["keypoints"] = self.keypoints[idx:idx+1] - return batch - - def __len__(self): - self.pre_process() - return self.n_detections - - def state_dict(self, **kwargs): - return dict( - segmentation=self.segmentation.bool(), - cls=self.__class__, - orig_imshape_CHW=self.orig_imshape_CHW, - keypoints=self.keypoints - ) - - @staticmethod - def from_state_dict( - state_dict, - post_process_cfg, **kwargs): - return PersonDetection( - state_dict["segmentation"], - orig_imshape_CHW=state_dict["orig_imshape_CHW"], - **post_process_cfg, - 
keypoints=state_dict["keypoints"]) - - def visualize(self, im): - self.pre_process() - im = im.cpu() - if len(self) == 0: - return im - im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False) - im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes) - return im - - -def get_dilated_boxes(exp_bbox: torch.LongTensor, mask): - """ - mask: resized mask - """ - assert exp_bbox.shape[0] == mask.shape[0] - boxes = masks_to_boxes(mask.squeeze(1)).cpu() - H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0] - boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long() - boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long() - boxes[:, [0, 2]] += exp_bbox[:, 0:1] - boxes[:, [1, 3]] += exp_bbox[:, 1:2] - return boxes - diff --git a/dp2/detection/utils.py b/dp2/detection/utils.py deleted file mode 100644 index 85dbd29c1832d2f48b20ed14c3d0357c958732f6..0000000000000000000000000000000000000000 --- a/dp2/detection/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import cv2 -import numpy as np -import torch -import tops -from skimage.morphology import disk -from torchvision.transforms.functional import resize, InterpolationMode -from functools import lru_cache - - -@lru_cache(maxsize=200) -def get_kernel(n: int): - kernel = disk(n, dtype=bool) - return tops.to_cuda(torch.from_numpy(kernel).bool()) - - -def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape): - """ - Transforms the detected embedding/mask directly to the target image shape - """ - - C, HE, WE = E.shape - assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox) - assert E_bbox[2] >= exp_bbox[0] - assert E_bbox[1] >= exp_bbox[1] - assert E_bbox[3] >= exp_bbox[1] - assert E_bbox[2] <= exp_bbox[2] - assert E_bbox[3] <= exp_bbox[3] - - x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1])) - x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1])) - y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0])) - y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0])) - new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32) - new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool) - - E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR) - new_E[:, y0:y1, x0:x1] = E - S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0 - new_S[y0:y1, x0:x1] = S - return new_E, new_S - - -def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor): - """ - mask: shape [N, H, W] - """ - assert len(mask1.shape) == 3 - assert len(mask2.shape) == 3 - assert mask1.device == mask2.device, (mask1.device, mask2.device) - assert mask2.dtype == mask2.dtype - assert mask1.dtype == torch.bool - assert mask1.shape[1:] == mask2.shape[1:] - N1, H1, W1 = mask1.shape - N2, H2, W2 = mask2.shape - iou = torch.zeros((N1, N2), dtype=torch.float32) - for i in range(N1): - cur = mask1[i:i+1] - inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu() - union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu() - iou[i] = inter / union - return iou - - -def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float): - N1 = mask1.shape[0] - N2 = mask2.shape[0] - ious = pairwise_mask_iou(mask1, 
mask2).cpu().numpy() - indices = np.array([idx for idx, iou in np.ndenumerate(ious)]) - ious = ious.flatten() - mask = ious >= iou_threshold - ious = ious[mask] - indices = indices[mask] - - # do not sort by iou to keep ordering of mask rcnn / cse sorting. - taken1 = np.zeros((N1), dtype=bool) - taken2 = np.zeros((N2), dtype=bool) - matches = [] - for i, j in indices: - if taken1[i].any() or taken2[j].any(): - continue - matches.append((i, j)) - taken1[i] = True - taken2[j] = True - return matches - - -def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float): - assert 0 < iou_threshold <= 1 - matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold) - H, W = segmentation.shape[1:] - new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device) - cse_im_seg = cse_dets["im_segmentation"] - for idx, (i, j) in enumerate(matches): - new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j]) - cse_dets = dict( - instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]], - instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]], - bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]], - scores=cse_dets["scores"][[j for (i, j) in matches]], - ) - return new_seg, cse_dets, np.array(matches).reshape(-1, 2) - - -def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor): - """ - cse_boxes can be outside of segmentation. - """ - boxes = masks_to_boxes(segmentation) - - assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape) - combined = torch.stack((boxes, cse_boxes), dim=-1) - boxes = torch.cat(( - combined[:, :2].min(dim=2).values, - combined[:, 2:].max(dim=2).values, - ), dim=1) - return boxes - - -def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False): - """ - Crops or pads x to fit in the bbox and resize to target shape. - """ - C, H, W = x.shape - x0, y0, x1, y1 = bbox - - if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H: - new_x = x[:, y0:y1, x0:x1] - else: - new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device) - y0_t = max(0, -y0) - y1_t = min(y1-y0, (y1-y0)-(y1-H)) - x0_t = max(0, -x0) - x1_t = min(x1-x0, (x1-x0)-(x1-W)) - x0 = max(0, x0) - y0 = max(0, y0) - x1 = min(x1, W) - y1 = min(y1, H) - new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1] - if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]: - return new_x - if x.dtype == torch.bool: - new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5 - elif x.dtype == torch.float32: - new_x = resize(new_x, target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True) - elif x.dtype == torch.uint8: - if fdf_resize: # FDF dataset is created with cv2 INTER_AREA. - # Incorrect resizing generates noticeable poorer inpaintings. 
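# Illustrative note: cv2.INTER_AREA averages over the source-pixel footprint
# when shrinking, matching how the FDF crops were generated, and torchvision's
# bilinear/bicubic kernels only approximate this; hence the cv2 round trip for
# downscaling below (cv2.resize expects (width, height), i.e. target_shape[::-1]),
# while upsampling falls back to bicubic.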
- upsampling = ((y1-y0) *(x1-x0)) < (target_shape[0] * target_shape[1]) - if upsampling: - new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC, antialias=True).round().clamp(0, 255).byte() - else: - device = new_x.device - new_x = new_x.permute(1, 2, 0).cpu().numpy() - new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA) - new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device) - else: - new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True).round().clamp(0, 255).byte() - else: - raise ValueError(f"Not supported dtype: {x.dtype}") - return new_x - - - -def masks_to_boxes(segmentation: torch.Tensor): - assert len(segmentation.shape) == 3 - x = segmentation.any(dim=1).byte() # Compress rows - x0 = x.argmax(dim=1) - - x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1) - y = segmentation.any(dim=2).byte() - y0 = y.argmax(dim=1) - y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1) - return torch.stack([x0, y0, x1, y1], dim=1) - diff --git a/dp2/discriminator/__init__.py b/dp2/discriminator/__init__.py deleted file mode 100644 index 77a4a773eafecde739caf086660063a19cf2160f..0000000000000000000000000000000000000000 --- a/dp2/discriminator/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sg2_discriminator import SG2Discriminator \ No newline at end of file diff --git a/dp2/discriminator/sg2_discriminator.py b/dp2/discriminator/sg2_discriminator.py deleted file mode 100644 index 5f44e8f9d6cfb46c0763cc8ab5d96037fcd2c4e2..0000000000000000000000000000000000000000 --- a/dp2/discriminator/sg2_discriminator.py +++ /dev/null @@ -1,76 +0,0 @@ -from sg3_torch_utils.ops import upfirdn2d -import torch -import numpy as np -import torch.nn as nn -from .. 
import layers -from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block - - -class SG2Discriminator(layers.Module): - - def __init__( - self, - cnum: int, - max_cnum_mul: int, - imsize, - min_fmap_resolution: int, - im_channels: int, - input_condition: bool, - conv_clamp: int, - input_cse: bool, - cse_nc: int): - super().__init__() - - cse_nc = 0 if cse_nc is None else cse_nc - self._max_imsize = max(imsize) - self._cnum = cnum - self._max_cnum_mul = max_cnum_mul - self._min_fmap_resolution = min_fmap_resolution - self._input_condition = input_condition - self.input_cse = input_cse - self.layers = nn.ModuleList() - - out_ch = self.get_chsize(self._max_imsize) - self.from_rgb = Block( - im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1), - out_ch, conv_clamp=conv_clamp - ) - n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1 - - for i in range(n_levels): - resolution = [x//2**i for x in imsize] - in_ch = self.get_chsize(max(resolution)) - out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution)) - - down = 2 - if i == 0: - down = 1 - block = ResidualBlock( - in_ch, out_ch, down=down, conv_clamp=conv_clamp - ) - self.layers.append(block) - self.output_layer = DiscriminatorEpilogue( - out_ch, resolution, conv_clamp=conv_clamp) - - self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) - - def forward(self, img, condition, mask, embedding=None, E_mask=None,**kwargs): - to_cat = [img] - if self._input_condition: - to_cat.extend([condition, mask,]) - if self.input_cse: - to_cat.extend([embedding, E_mask]) - x = torch.cat(to_cat, dim=1) - x = self.from_rgb(x) - - for i, layer in enumerate(self.layers): - x = layer(x) - - x = self.output_layer(x) - return dict(score=x) - - def get_chsize(self, imsize): - n = int(np.log2(self._max_imsize) - np.log2(imsize)) - mul = min(2 ** n, self._max_cnum_mul) - ch = self._cnum * mul - return int(ch) diff --git a/dp2/gan_trainer.py b/dp2/gan_trainer.py deleted file mode 100644 index e9cdbcd86f596c7be984b5463e486aafe6a0890e..0000000000000000000000000000000000000000 --- a/dp2/gan_trainer.py +++ /dev/null @@ -1,324 +0,0 @@ -import atexit -from collections import defaultdict -import logging -import typing -import torch -import time -from dp2.utils import vis_utils -from dp2 import utils -from tops import logger, checkpointer -import tops -from easydict import EasyDict - - -def accumulate_gradients(params, fp16_ddp_accumulate): - if len(params) == 0: - return - params = [param for param in params if param.grad is not None] - flat = torch.cat([param.grad.flatten() for param in params]) - orig_dtype = flat.dtype - if tops.world_size() > 1: - if fp16_ddp_accumulate: - flat = flat.half() / tops.world_size() - else: - flat /= tops.world_size() - torch.distributed.all_reduce(flat) - flat = flat.to(orig_dtype) - grads = flat.split([param.numel() for param in params]) - for param, grad in zip(params, grads): - param.grad = grad.reshape(param.shape) - - -def accumulate_buffers(module: torch.nn.Module): - buffers = [buf for buf in module.buffers()] - if len(buffers) == 0: - return - flat = torch.cat([buf.flatten() for buf in buffers]) - if tops.world_size() > 1: - torch.distributed.all_reduce(flat) - flat /= tops.world_size() - bufs = flat.split([buf.numel() for buf in buffers]) - for old, new in zip(buffers, bufs): - old.copy_(new.reshape(old.shape), non_blocking=True) - - -def check_ddp_consistency(module): - if tops.world_size() == 1: - return - assert isinstance(module, 
torch.nn.Module) - assert isinstance(module, torch.nn.Module) - params_buffs = list(module.named_parameters()) + list(module.named_buffers()) - for name, tensor in params_buffs: - fullname = type(module).__name__ + '.' + name - tensor = tensor.detach() - if tensor.is_floating_point(): - tensor = torch.nan_to_num(tensor) - other = tensor.clone() - torch.distributed.broadcast(tensor=other, src=0) - assert (tensor == other).all(), fullname - -class AverageMeter(): - def __init__(self) -> None: - self.to_log = dict() - self.n = defaultdict(int) - pass - - @torch.no_grad() - def update(self, values: dict): - for key, value in values.items(): - self.n[key] += 1 - if key in self.to_log: - self.to_log[key] += value.mean().detach() - else: - self.to_log[key] = value.mean().detach() - - def get_average(self): - return {key: value / self.n[key] for key, value in self.to_log.items()} - - -class GANTrainer: - - def __init__( - self, - G: torch.nn.Module, - D: torch.nn.Module, - G_EMA: torch.nn.Module, - D_optim: torch.optim.Optimizer, - G_optim: torch.optim.Optimizer, - dl_train: typing.Iterator, - dl_val: typing.Iterable, - scaler_D: torch.cuda.amp.GradScaler, - scaler_G: torch.cuda.amp.GradScaler, - ims_per_log: int, - max_images_to_train: int, - loss_handler, - ims_per_val: int, - evaluate_fn, - batch_size: int, - broadcast_buffers: bool, - fp16_ddp_accumulate: bool, - save_state: bool, - *args, **kwargs): - super().__init__(*args, **kwargs) - - self.G = G - self.D = D - self.G_EMA = G_EMA - self.D_optim = D_optim - self.G_optim = G_optim - self.dl_train = dl_train - self.dl_val = dl_val - self.scaler_D = scaler_D - self.scaler_G = scaler_G - self.loss_handler = loss_handler - self.max_images_to_train = max_images_to_train - self.images_per_val = ims_per_val - self.images_per_log = ims_per_log - self.evaluate_fn = evaluate_fn - self.batch_size = batch_size - self.broadcast_buffers = broadcast_buffers - self.fp16_ddp_accumulate = fp16_ddp_accumulate - - self.train_state = EasyDict( - next_log_step=0, - next_val_step=ims_per_val, - total_time=0 - ) - - checkpointer.register_models(dict( - generator=G, discriminator=D, EMA_generator=G_EMA, - D_optimizer=D_optim, - G_optimizer=G_optim, - train_state=self.train_state, - scaler_D=self.scaler_D, - scaler_G=self.scaler_G - )) - if checkpointer.has_checkpoint(): - checkpointer.load_registered_models() - logger.log(f"Resuming training from: global step: {logger.global_step()}") - else: - logger.add_dict({ - "stats/discriminator_parameters": tops.num_parameters(self.D), - "stats/generator_parameters": tops.num_parameters(self.G), - }, commit=False) - if save_state: - # If the job is unexpectedly killed, there could be a mismatch between previously saved checkpoint and the current checkpoint. 
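# Illustrative note: the atexit hook below writes a final checkpoint on normal
# interpreter shutdown (including sys.exit), but not on a hard kill such as
# SIGKILL, which is the mismatch scenario described above.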
- atexit.register(checkpointer.save_registered_models) - - self._ims_per_log = ims_per_log - - self.to_log = AverageMeter() - self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad] - self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad] - logger.add_dict({ - "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D), - "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G), - }, commit=False, level=logging.INFO) - check_ddp_consistency(self.D) - check_ddp_consistency(self.G) - check_ddp_consistency(self.G_EMA.generator) - - def train_loop(self): - self.log_time() - while logger.global_step() <= self.max_images_to_train: - batch = next(self.dl_train) - self.G_EMA.update_beta() - self.to_log.update(self.step_D(batch)) - self.to_log.update(self.step_G(batch)) - self.G_EMA.update(self.G) - - if logger.global_step() >= self.train_state.next_log_step: - to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()} - to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()}) - to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()}) - self.to_log = AverageMeter() - logger.add_dict(to_log, commit=True) - self.train_state.next_log_step += self.images_per_log - if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8: - print("Stopping training as gradient scale < 1e-8") - logger.log("Stopping training as gradient scale < 1e-8") - break - - if logger.global_step() >= self.train_state.next_val_step: - self.evaluate() - self.log_time() - self.save_images() - self.train_state.next_val_step += self.images_per_val - logger.step(self.batch_size*tops.world_size()) - logger.log(f"Reached end of training at step {logger.global_step()}.") - checkpointer.save_registered_models() - - def estimate_ims_per_hour(self): - batch = next(self.dl_train) - n_ims = int(100e3) - n_steps = int(n_ims / (self.batch_size * tops.world_size())) - n_ims = n_steps * self.batch_size * tops.world_size() - for i in range(10): # Warmup - self.G_EMA.update_beta() - self.step_D(batch) - self.step_G(batch) - self.G_EMA.update(self.G) - start_time = time.time() - for i in utils.tqdm_(list(range(n_steps))): - self.G_EMA.update_beta() - self.step_D(batch) - self.step_G(batch) - self.G_EMA.update(self.G) - total_time = time.time() - start_time - ims_per_sec = n_ims / total_time - ims_per_hour = ims_per_sec * 60*60 - ims_per_day = ims_per_hour * 24 - logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M") - logger.log(f"Images per day: {ims_per_day/1e6:.3f}M") - import math - ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4)) - logger.log(f"Images per 4 days: {ims_per_4_day}") - logger.add_dict({ - "stats/ims_per_day": ims_per_day, - "stats/ims_per_4_day": ims_per_4_day - }) - - def log_time(self): - if not hasattr(self, "start_time"): - self.start_time = time.time() - self.last_time_step = logger.global_step() - return - n_images = logger.global_step() - self.last_time_step - if n_images == 0: - return - n_secs = time.time() - self.start_time - n_ims_per_sec = n_images / n_secs - training_time_hours = n_secs / 60/ 60 - self.train_state.total_time += training_time_hours - remaining_images = self.max_images_to_train - logger.global_step() - remaining_time = remaining_images / n_ims_per_sec / 60 / 60 - logger.add_dict({ - "stats/n_ims_per_sec": n_ims_per_sec, - "stats/total_traing_time_hours": self.train_state.total_time, - 
"stats/remaining_time_hours": remaining_time - }) - self.last_time_step = logger.global_step() - self.start_time = time.time() - - def save_images(self): - dl_val = iter(self.dl_val) - batch = next(dl_val) - # TRUNCATED visualization - ims_to_log = 8 - self.G_EMA.eval() - z = self.G.get_z(batch["img"]) - fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"] - fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu() - if "__key__" in batch: - batch.pop("__key__") - real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log] - to_vis = torch.cat((real, fakes_truncated)) - logger.add_images("images/truncated", to_vis, nrow=2) - - # Diverse images - ims_diverse = 3 - batch = next(dl_val) - to_vis = [] - - for i in range(ims_diverse): - z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1) - fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu() - to_vis.append(fakes) - if "__key__" in batch: - batch.pop("__key__") - reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log] - to_vis.insert(0, reals) - to_vis = torch.cat(to_vis) - logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1) - - self.G_EMA.train() - pass - - def evaluate(self): - logger.log("Stating evaluation.") - self.G_EMA.eval() - try: - checkpointer.save_registered_models(max_keep=3) - except Exception: - logger.log("Could not save checkpoint.") - if self.broadcast_buffers: - check_ddp_consistency(self.G) - check_ddp_consistency(self.D) - metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val) - metrics = {f"metrics/{k}": v for k,v in metrics.items()} - logger.add_dict(metrics, level=logger.logger.INFO) - - def step_D(self, batch): - utils.set_requires_grad(self.trainable_params_D, True) - utils.set_requires_grad(self.trainable_params_G, False) - tops.zero_grad(self.D) - loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D) - with torch.autograd.profiler.record_function("D_step"): - self.scaler_D.scale(loss).backward() - accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate) - if self.broadcast_buffers: - accumulate_buffers(self.D) - accumulate_buffers(self.G) - # Step will not unscale if unscale is called previously. 
- self.scaler_D.step(self.D_optim) - self.scaler_D.update() - utils.set_requires_grad(self.trainable_params_D, False) - utils.set_requires_grad(self.trainable_params_G, False) - return to_log - - def step_G(self, batch): - utils.set_requires_grad(self.trainable_params_D, False) - utils.set_requires_grad(self.trainable_params_G, True) - tops.zero_grad(self.G) - loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G) - with torch.autograd.profiler.record_function("G_step"): - self.scaler_G.scale(loss).backward() - accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate) - if self.broadcast_buffers: - accumulate_buffers(self.G) - accumulate_buffers(self.D) - self.scaler_G.step(self.G_optim) - self.scaler_G.update() - utils.set_requires_grad(self.trainable_params_D, False) - utils.set_requires_grad(self.trainable_params_G, False) - return to_log \ No newline at end of file diff --git a/dp2/generator/__init__.py b/dp2/generator/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dp2/generator/base.py b/dp2/generator/base.py deleted file mode 100644 index 851f00ee2fed4f8e0601405ffd635338280ec993..0000000000000000000000000000000000000000 --- a/dp2/generator/base.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -import numpy as np -import tqdm -import tops -from ..layers import Module -from ..layers.sg2_layers import FullyConnectedLayer -from dp2 import utils - - -class BaseGenerator(Module): - - def __init__(self, z_channels: int): - super().__init__() - self.z_channels = z_channels - self.latent_space = "Z" - - @torch.no_grad() - def get_z( - self, - x: torch.Tensor = None, - z: torch.Tensor = None, - truncation_value: float = None, - batch_size: int = None, - dtype=None, device=None) -> torch.Tensor: - """Generates a latent variable for generator. - """ - if z is not None: - return z - if x is not None: - batch_size = x.shape[0] - dtype = x.dtype - device = x.device - if device is None: - device = utils.get_device() - if truncation_value == 0: - return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype) - z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype) - if truncation_value is None: - return z - while z.abs().max() > truncation_value: - m = z.abs() > truncation_value - z[m] = torch.rand_like(z)[m] - return z - - def sample(self, truncation_value, z=None, **kwargs): - """ - Samples via interpolating to the mean (0). - """ - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - if z is None: - z = self.get_z(kwargs["condition"]) - z = z * truncation_value - return self.forward(**kwargs, z=z) - - - -class SG2StyleNet(torch.nn.Module): - def __init__(self, - z_dim, # Input latent (Z) dimensionality. - w_dim, # Intermediate latent (W) dimensionality. - num_layers = 2, # Number of mapping layers. - lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. - w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. - ): - super().__init__() - self.z_dim = z_dim - self.w_dim = w_dim - self.num_layers = num_layers - self.w_avg_beta = w_avg_beta - # Construct layers. 
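# Illustrative note: with the default num_layers = 2 the feature list below is
# [z_dim, w_dim, w_dim], i.e. fc0 maps z_dim -> w_dim and fc1 maps w_dim -> w_dim.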
- features = [self.z_dim] + [self.w_dim] * self.num_layers - for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): - layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) - setattr(self, f'fc{idx}', layer) - self.register_buffer('w_avg', torch.zeros([w_dim])) - - def forward(self, z, update_emas=False, y=None): - tops.assert_shape(z, [None, self.z_dim]) - - # Embed, normalize, and concatenate inputs. - x = z.to(torch.float32) - x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() - # Execute layers. - for idx in range(self.num_layers): - x = getattr(self, f'fc{idx}')(x) - # Update moving average of W. - if update_emas: - self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) - - return x - - def extra_repr(self): - return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}' - - def update_w(self, n=int(10e3), batch_size=32): - """ - Calculate w_ema over n iterations. - Useful in cases where w_ema is calculated incorrectly during training. - """ - n = n // batch_size - for i in tqdm.trange(n, desc="Updating w"): - z = torch.randn((batch_size, self.z_dim), device=tops.get_device()) - self(z, update_emas=True) - - -class BaseStyleGAN(BaseGenerator): - - def __init__(self, z_channels: int, w_dim: int): - super().__init__(z_channels) - self.style_net = SG2StyleNet(z_channels, w_dim) - self.latent_space = "W" - - def get_w(self, z, update_emas): - return self.style_net(z, update_emas=update_emas) - - @torch.no_grad() - def sample(self, truncation_value, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) - - def update_w(self, *args, **kwargs): - self.style_net.update_w(*args, **kwargs) - - - @torch.no_grad() - def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - if w_indices is None: - w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) - w_centers = self.style_net.w_centers[w_indices].to(w.device) - w = w_centers.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) diff --git a/dp2/generator/dummy_generators.py b/dp2/generator/dummy_generators.py deleted file mode 100644 index c319e65c8191bf3f77d9d348698bff10c44e0b21..0000000000000000000000000000000000000000 --- a/dp2/generator/dummy_generators.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import torch.nn as nn -from .base import BaseGenerator - - -class PixelationGenerator(BaseGenerator): - - def __init__(self, pixelation_size, **kwargs): - super().__init__(z_channels=0) - self.pixelation_size = pixelation_size - self.z_channels = 0 - self.latent_space=None - - def forward(self, img, condition, mask, **kwargs): - old_shape = img.shape[-2:] - img = nn.functional.interpolate(img, size=(self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True) - img = nn.functional.interpolate(img, size=old_shape, mode="bilinear", align_corners=True) - out = img*(1-mask) + condition*mask - return {"img": out} - - -class MaskOutGenerator(BaseGenerator): - - def __init__(self, 
noise: str, **kwargs): - super().__init__(z_channels=0) - self.noise = noise - self.z_channels = 0 - assert self.noise in ["rand", "constant"] - self.latent_space = None - - def forward(self, img, condition, mask, **kwargs): - - if self.noise == "constant": - img = torch.zeros_like(img) - elif self.noise == "rand": - img = torch.rand_like(img) - out = img*(1-mask) + condition*mask - return {"img": out} - - -class IdentityGenerator(BaseGenerator): - - def __init__(self): - super().__init__(z_channels=0) - - def forward(self, img, condition, mask, **kwargs): - return dict(img=img) \ No newline at end of file diff --git a/dp2/generator/imagen3_old.py b/dp2/generator/imagen3_old.py deleted file mode 100644 index 87f35d7595dd018f1d0db2baddcfeb3789596ed0..0000000000000000000000000000000000000000 --- a/dp2/generator/imagen3_old.py +++ /dev/null @@ -1,1210 +0,0 @@ -# What is missing from this implementation -# 1. Global context in res block -# 2. Cross attention of conditional information in resnet block -# -from functools import partial -import tops -from tops.config import instantiate -import warnings -from typing import Iterable, List, Tuple -import numpy as np -import torch -import torch.nn as nn -from torch import einsum -from einops import rearrange -from dp2 import infer, utils -from .base import BaseGenerator -from sg3_torch_utils.ops import bias_act -from dp2.layers import Sequential -import torch.nn.functional as F -from torchvision.transforms.functional import resize, InterpolationMode -from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d - - - - -class Upfirdn2d(torch.nn.Module): - - - def __init__(self, down=1, up=1, fix_gain=True): - super().__init__() - self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1])) - fw, fh = upfirdn2d._get_filter_size(self.resample_filter) - px0, px1, py0, py1 = upfirdn2d._parse_padding(0) - self.down = down - self.up = up - if up > 1: - px0 += (fw + up - 1) // 2 - px1 += (fw - up) // 2 - py0 += (fh + up - 1) // 2 - py1 += (fh - up) // 2 - if down > 1: - px0 += (fw - down + 1) // 2 - px1 += (fw - down) // 2 - py0 += (fh - down + 1) // 2 - py1 += (fh - down) // 2 - self.padding = [px0,px1,py0,py1] - self.gain = up**2 if fix_gain else 1 - - def forward(self, x, *args): - if isinstance(x, dict): - x = {k: v for k, v in x.items()} - x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain) - return x - x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain) - if len(args) == 0: - return x - return (x, *args) -@torch.no_grad() -def spatial_embed_keypoints(keypoints: torch.Tensor, x): - tops.assert_shape(keypoints, (None, None, 3)) - B, N_K, _ = keypoints.shape - H, W = x.shape[-2:] - keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32) - x, y, visible = keypoints.chunk(3, dim=2) - x = (x * W).round().long().clamp(0, W-1) - y = (y * H).round().long().clamp(0, H-1) - kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1) - pos = (kp_idx*(H*W) + y*W + x + 1) - # Offset all by 1 to index invisible keypoints as 0 - pos = (pos * visible.round().long()).squeeze(dim=-1) - keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32) - keypoint_spatial.scatter_(1, pos, 1) - keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W) - return keypoint_spatial - 
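# A minimal usage sketch (illustrative only, with made-up shapes), assuming the
# spatial_embed_keypoints defined above: keypoints normalized to [0, 1] with a
# visibility flag in the last column are scattered into one-hot spatial maps
# matching the feature tensor's resolution.
import torch

kps = torch.tensor([[[0.5, 0.5, 1.0],       # visible keypoint near the centre
                     [0.1, 0.9, 0.0]]])     # invisible keypoint -> empty map
feat = torch.zeros(1, 8, 16, 16)            # only the spatial size (16x16) is used
maps = spatial_embed_keypoints(kps, feat)   # -> shape (1, 2, 16, 16)
assert maps[0, 0, 8, 8] == 1                # round(0.5 * 16) = 8 for both x and y
assert maps[0, 1].sum() == 0                # invisible keypoints contribute nothing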
- -def modulated_conv2d( - x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. - weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. - styles, # Modulation coefficients of shape [batch_size, in_channels]. - noise = None, # Optional noise tensor to add to the output activations. - up = 1, # Integer upsampling factor. - down = 1, # Integer downsampling factor. - padding = 0, # Padding with respect to the upsampled image. - resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). - demodulate = True, # Apply weight demodulation? - flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). - fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? -): - batch_size = x.shape[0] - out_channels, in_channels, kh, kw = weight.shape - tops.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] - tops.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] - tops.assert_shape(styles, [batch_size, in_channels]) # [NI] - - # Pre-normalize inputs to avoid FP16 overflow. - if x.dtype == torch.float16 and demodulate: - weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk - styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I - - # Calculate per-sample weights and demodulation coefficients. - w = None - dcoefs = None - if demodulate or fused_modconv: - w = weight.unsqueeze(0) # [NOIkk] - w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] - if demodulate: - dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] - if demodulate and fused_modconv: - w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] - - # Execute by scaling the activations before and after the convolution. - if not fused_modconv: - x = x * styles.reshape(batch_size, -1, 1, 1) - x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) - if demodulate and noise is not None: - x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) - elif demodulate: - x = x * dcoefs.reshape(batch_size, -1, 1, 1) - elif noise is not None: - x = x.add_(noise.to(x.dtype)) - return x - - with tops.suppress_tracer_warnings(): # this value will be treated as a constant - batch_size = int(batch_size) - # Execute as one fused op using grouped convolution. 
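    # The grouped-convolution trick below gives every sample its own modulated kernel in a
    # single conv call: x is folded to [1, N*C_in, H, W], the per-sample weights w to
    # [N*C_out, C_in, kh, kw], and groups=batch_size routes each sample through its own
    # weight slice before the result is unfolded back to [N, C_out, H, W].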
- tops.assert_shape(x, [batch_size, in_channels, None, None]) - x = x.reshape(1, -1, *x.shape[2:]) - w = w.reshape(-1, in_channels, kh, kw) - x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) - x = x.reshape(batch_size, -1, *x.shape[2:]) - if noise is not None: - x = x.add_(noise) - return x - - -class Identity(nn.Module): - - def __init__(self) -> None: - super().__init__() - - def forward(self, x, *args, **kwargs): - return x - - -class LayerNorm(nn.Module): - def __init__(self, dim, stable=False): - super().__init__() - self.stable = stable - self.g = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - if self.stable: - x = x / x.amax(dim=-1, keepdim=True).detach() - - eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - var = torch.var(x, dim=-1, unbiased=False, keepdim=True) - mean = torch.mean(x, dim=-1, keepdim=True) - return (x - mean) * (var + eps).rsqrt() * self.g - - -class FullyConnectedLayer(torch.nn.Module): - def __init__(self, - in_features, # Number of input features. - out_features, # Number of output features. - bias = True, # Apply additive bias before the activation function? - activation = 'linear', # Activation function: 'relu', 'lrelu', etc. - lr_multiplier = 1, # Learning rate multiplier. - bias_init = 0, # Initial value for the additive bias. - ): - super().__init__() - self.repr = dict( - in_features=in_features, out_features=out_features, bias=bias, - activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init) - self.activation = activation - self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) - self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None - self.weight_gain = lr_multiplier / np.sqrt(in_features) - self.bias_gain = lr_multiplier - self.in_features = in_features - self.out_features = out_features - - def forward(self, x): - w = self.weight * self.weight_gain - b = self.bias - if b is not None: - if self.bias_gain != 1: - b = b * self.bias_gain - x = F.linear(x, w) - x = bias_act.bias_act(x, b, act=self.activation) - return x - - def extra_repr(self) -> str: - return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) - - - -def checkpoint_fn(fn, *args, **kwargs): - warnings.simplefilter("ignore") - return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs) - -class Conv2d(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size=3, - activation='lrelu', - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
- bias=True, - norm=None, - lr_multiplier=1, - bias_init=0, - w_dim=None, - gradient_checkpoint_norm=False, - gain=1, - ): - super().__init__() - self.fused_modconv = False - if norm == torch.nn.InstanceNorm2d: - self.norm = torch.nn.InstanceNorm2d(None) - elif isinstance(norm, torch.nn.Module): - self.norm = norm - elif norm == "fused_modconv": - self.fused_modconv = True - elif norm: - self.norm = torch.nn.InstanceNorm2d(None) - elif norm is not None: - raise ValueError(f"norm not supported: {norm}") - self.activation = activation - self.conv_clamp = conv_clamp - self.out_channels = out_channels - self.in_channels = in_channels - self.padding = kernel_size // 2 - self.repr = dict( - in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, - activation=activation, conv_clamp=conv_clamp, bias=bias, - fused_modconv=self.fused_modconv - ) - self.act_gain = bias_act.activation_funcs[activation].def_gain * gain - self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2)) - self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])) - self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None - self.bias_gain = lr_multiplier - if w_dim is not None: - if self.fused_modconv: - self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) - else: - self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) - self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0) - self.gradient_checkpoint_norm = gradient_checkpoint_norm - - def forward(self, x, w=None, gain=1, **kwargs): - if self.fused_modconv: - styles = self.affine(w) - with torch.cuda.amp.autocast(enabled=False): - x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None, - padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype) - else: - if hasattr(self, "affine"): - gamma = self.affine(w).view(-1, self.in_channels, 1, 1) - beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1) - x = fma.fma(x, gamma ,beta) - w = self.weight * self.weight_gain - x = F.conv2d(input=x, weight=w, padding=self.padding,) - - if hasattr(self, "norm"): - if self.gradient_checkpoint_norm: - x = checkpoint_fn(self.norm, x) - else: - x = self.norm(x) - act_gain = self.act_gain * gain - act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None - b = self.bias * self.bias_gain if self.bias is not None else None - x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) - return x - - def extra_repr(self) -> str: - return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) - - -class CrossAttention(nn.Module): - def __init__( - self, - dim, - context_dim, - dim_head=64, - heads=8, - norm_context=False, - ): - super().__init__() - self.scale = dim_head ** -0.5 - - self.heads = heads - inner_dim = dim_head * heads - - self.norm = nn.InstanceNorm1d(dim) - self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity() - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim, bias=False), - nn.InstanceNorm1d(None) - ) - - def forward(self, x, w): - x = self.norm(x) - w = self.norm_context(w) - - q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim = -1)) - q = rearrange(q, "b n (h d) -> b h n d", h = self.heads) - k = rearrange(k, "b n (h d) -> b h n d", h = self.heads) - v = rearrange(v, "b n (h 
d) -> b h n d", h = self.heads) - q = q * self.scale - # similarities - sim = einsum('b h i d, b h j d -> b h i j', q, k) - attn = sim.softmax(dim = -1, dtype = torch.float32) - - out = einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - return self.to_out(out) - - -class SG2ResidualBlock(torch.nn.Module): - def __init__( - self, - in_channels, # Number of input channels, 0 = first block. - out_channels, # Number of output channels. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. - skip_gain=np.sqrt(.5), - cross_attention: bool = False, - cross_attention_len: int = None, - use_adain: bool = True, - **layer_kwargs, # Arguments for conv layer. - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None - if use_adain: - layer_kwargs["w_dim"] = w_dim - - self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs) - self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain) - - self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain) - if cross_attention and w_dim is not None: - self.cross_attention_len = cross_attention_len - self.cross_attn = CrossAttention( - dim=out_channels, context_dim=w_dim//self.cross_attention_len, - gain=skip_gain) - - def forward(self, x, w=None, **layer_kwargs): - y = self.skip(x) - x = self.conv0(x, w, **layer_kwargs) - x = self.conv1(x, w, **layer_kwargs) - if hasattr(self, "cross_attn"): - h = x.shape[-2] - x = rearrange(x, "b c h w -> b (h w) c") - w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len) - x = self.cross_attn(x, w=w) + x - x = rearrange(x, "b (h w) c -> b c h w", h=h) - return y + x - - -def default(val, d): - if val is not None: - return val - return d() if callable(d) else d - - -def cast_tuple(val, length=None): - if isinstance(val, Iterable) and not isinstance(val, str): - val = tuple(val) - output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) - if length is not None: - assert len(output) == length, (output, length) - return output - - -class Attention(nn.Module): - # This is a version of Multi-Query Attention () - # Fast Transformer Decoding: One Write-Head is All You Need - # Ablated in: https://arxiv.org/pdf/2203.07814.pdf - # and https://arxiv.org/pdf/2204.02311.pdf - def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None): - super().__init__() - self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0 - self.cosine_sim_attn = cosine_sim_attn - self.cosine_sim_scale = 16 if cosine_sim_attn else 1 - self.gradient_checkpoint = gradient_checkpoint - self.heads = heads - self.dim = dim - self.fix_attention_again = fix_attention_again - inner_dim = dim_head * heads - if norm == "LN": - self.norm = LayerNorm(dim) - elif norm == "IN": - self.norm = nn.InstanceNorm1d(dim) - elif norm is None: - self.norm = nn.Identity() - else: - raise ValueError(f"Norm not supported: {norm}") - - self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False) - self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False) - - self.to_out = nn.Sequential( - FullyConnectedLayer(inner_dim, dim, bias=False), - LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim) - ) - if fix_attention_again: - assert gain is not None - self.gain = gain - else: - 
self.gain = np.sqrt(.5) if attn_fix_gain else 1 - - def run_function(self, x, attn_bias): - b, c, h, w = x.shape - x = rearrange(x, "b c h w -> b (h w) c") - in_ = x - b, n, device = *x.shape[:2], x.device - x = self.norm(x) - q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) - - q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) - q = q * self.scale - - # calculate query / key similarities - sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale - - if attn_bias is not None: - attn_bias = attn_bias - attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)") - sim = sim + attn_bias - - attn = sim.softmax(dim=-1) - - out = einsum("b h i j, b j d -> b h i d", attn, v) - - out = rearrange(out, "b h n d -> b n (h d)") - if self.fix_attention_again: - out = self.to_out(out)*self.gain + in_ - else: - out = (self.to_out(out) + in_) * self.gain - out = rearrange(out, "b (h w) c -> b c h w", h=h) - return out - - def forward(self, x, *args, attn_bias=None, **kwargs): - if self.gradient_checkpoint: - return checkpoint_fn(self.run_function, x, attn_bias) - return self.run_function(x, attn_bias) - - def get_attention(self, x, attn_bias=None): - b, c, h, w = x.shape - x = rearrange(x, "b c h w -> b (h w) c") - in_ = x - b, n, device = *x.shape[:2], x.device - x = self.norm(x) - q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) - - q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) - q = q * self.scale - - # calculate query / key similarities - sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale - - if attn_bias is not None: - attn_bias = attn_bias - attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)") - sim = sim + attn_bias - - attn = sim.softmax(dim=-1) - return attn, None - - -class BiasedAttention(Attention): - - def __init__(self, *args, head_wise: bool=True, **kwargs): - super().__init__(*args, **kwargs) - out_ch = self.heads if head_wise else 1 - self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0) - nn.init.zeros_(self.conv.weight.data) - - def forward(self, x, mask): - mask = resize(mask, size=x.shape[-2:]) - bias = self.conv(torch.cat((x, mask, 1-mask), dim=1)) - return super().forward(x=x, attn_bias=bias) - - def get_attention(self, x, mask): - mask = resize(mask, size=x.shape[-2:]) - bias = self.conv(torch.cat((x, mask, 1-mask), dim=1)) - return super().get_attention(x, bias)[0], bias - -class UNet(BaseGenerator): - - def __init__( - self, - im_channels: int, - dim: int, - dim_mults: tuple, - num_resnet_blocks, # Number of resnet blocks per resolution - n_middle_blocks: int, - z_channels: int, - conv_clamp: int, - layer_attn, - w_dim: int, - norm_enc: bool, - norm_dec: str, - stylenet: nn.Module, - enc_style: bool, # Toggle style injection in encoder - use_maskrcnn_mask: bool, - skip_all_unets: bool, - fix_resize:bool, - comodulate: bool, - comod_net: nn.Module, - lr_comod: float, - dec_style: bool, - input_keypoints: bool, - n_keypoints: int, - input_keypoint_indices: Tuple[int], - use_adain: bool, - cross_attention: bool, - cross_attention_len: int, - gradient_checkpoint_norm: bool, - attn_cls: partial, - mask_out_train: bool, - fix_gain_again: bool, - ) -> None: - super().__init__(z_channels) - self.enc_style = enc_style - self.n_keypoints = n_keypoints - self.input_keypoint_indices = list(input_keypoint_indices) - self.input_keypoints = input_keypoints - self.mask_out_train = mask_out_train - n_layers = len(dim_mults) - self.n_layers = n_layers - layer_attn = cast_tuple(layer_attn, 
n_layers) - num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers) - self._cnum = dim - self._image_channels = im_channels - self._z_channels = z_channels - encoder_layers = [] - condition_ch = im_channels - self.from_rgb = Conv2d( - condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices) - , dim, 7) - - self.use_maskrcnn_mask = use_maskrcnn_mask - self.skip_all_unets = skip_all_unets - dims = [dim*m for m in dim_mults] - enc_blk = partial( - SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc, - use_adain=use_adain and self.enc_style, - w_dim=w_dim, - cross_attention=cross_attention, - cross_attention_len=cross_attention_len, - gradient_checkpoint_norm=gradient_checkpoint_norm - ) - dec_blk = partial( - SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec, - use_adain=use_adain and dec_style, - w_dim=w_dim, - cross_attention=cross_attention, - cross_attention_len=cross_attention_len, - gradient_checkpoint_norm=gradient_checkpoint_norm - ) - # Currently up/down sampling is done by bilinear upsampling. - # This can be simplified by replacing it with a strided upsampling layer... - self.encoder_attns = nn.ModuleList() - for lidx in range(n_layers): - gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5) - dim_in = dims[lidx] - dim_out = dims[min(lidx+1, n_layers-1)] - res_blocks = nn.ModuleList() - for i in range(num_resnet_blocks[lidx]): - is_last = num_resnet_blocks[lidx] - 1 == i - cur_dim = dim_out if is_last else dim_in - block = enc_blk(dim_in, cur_dim, skip_gain=gain) - res_blocks.append(block) - if layer_attn[lidx]: - self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain)) - else: - self.encoder_attns.append(Identity()) - encoder_layers.append(res_blocks) - self.encoder = torch.nn.ModuleList(encoder_layers) - - # initialize decoder - decoder_layers = [] - self.unet_layers = torch.nn.ModuleList() - self.decoder_attns = torch.nn.ModuleList() - for lidx in range(n_layers): - dim_in = dims[min(-lidx, -1)] - dim_out = dims[-1-lidx] - res_blocks = nn.ModuleList() - unet_skips = nn.ModuleList() - for i in range(num_resnet_blocks[-lidx-1]): - is_first = i == 0 - has_unet = is_first or skip_all_unets - is_last = i == num_resnet_blocks[-lidx-1] - 1 - cur_dim = dim_in if is_first else dim_out - if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again: # x + residual + unet + layer attn - gain = np.sqrt(1/4) - elif has_unet: # x + residual + unet - gain = np.sqrt(1/3) - elif layer_attn[-lidx-1] and fix_gain_again: # x + residual + attention - gain = np.sqrt(1/3) - else: # x + residual - gain = np.sqrt(1/2) # Only residual block - block = dec_blk(cur_dim, dim_out, skip_gain=gain) - res_blocks.append(block) - if has_unet: - unet_block = Conv2d( - cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp, - norm=nn.InstanceNorm2d(None), - gradient_checkpoint_norm=gradient_checkpoint_norm, - gain=gain) - unet_skips.append(unet_block) - else: - unet_skips.append(torch.nn.Identity()) - if layer_attn[-lidx-1]: - self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain)) - else: - self.decoder_attns.append(Identity()) - - decoder_layers.append(res_blocks) - self.unet_layers.append(unet_skips) - - middle_blocks = [] - for i in range(n_middle_blocks): - block = dec_blk(dims[-1], dims[-1]) - middle_blocks.append(block) - if n_middle_blocks != 0: - self.middle_blocks = Sequential(*middle_blocks) - self.decoder = torch.nn.ModuleList(decoder_layers) - 
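        # Skip-gain bookkeeping: every term summed into a block's output (residual branch,
        # U-Net skip, attention) is scaled by 1/sqrt(#terms), which is where the sqrt(1/2),
        # sqrt(1/3) and sqrt(1/4) factors above come from, so the variance of the sum
        # stays roughly constant no matter how many branches a layer has.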
self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp) - self.stylenet = stylenet - self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize) - self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize) - self.comodulate = comodulate - if comodulate: - assert not self.enc_style - self.to_y = nn.Sequential( - Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm), - nn.AdaptiveAvgPool2d(1), - nn.Flatten(), - FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod) - ) - self.comod_net = comod_net - - - def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs): - if z is None: - z = self.get_z(condition) - if w is None: - w = self.stylenet(z, update_emas=update_emas) - if self.use_maskrcnn_mask: - x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) - else: - x = torch.cat((condition, mask, 1-mask), dim=1) - - if self.input_keypoints: - keypoints = keypoints[:, self.input_keypoint_indices] - one_hot_pose = spatial_embed_keypoints(keypoints, x) - x = torch.cat((x, one_hot_pose), dim=1) - x = self.from_rgb(x) - x, unet_features = self.forward_enc(x, mask, w) - x, decoder_features = self.forward_dec(x, mask, w, unet_features) - x = self.to_rgb(x) - unmasked = x - if self.mask_out_train: - x = mask * condition + (1-mask) * x - out = dict(img=x, unmasked=unmasked) - if return_decoder_features: - out["decoder_features"] = decoder_features - return out - - def forward_enc(self, x, mask, w): - unet_features = [] - for i, res_blocks in enumerate(self.encoder): - is_last = i == len(self.encoder) - 1 - for block in res_blocks: - x = block(x, w=w) - unet_features.append(x) - x = self.encoder_attns[i](x, mask=mask) - if not is_last: - x = self.downsample(x) - if self.comodulate: - y = self.to_y(x) - y = torch.cat((w, y), dim=-1) - w = self.comod_net(y) - return x, unet_features - - def forward_dec(self, x, mask, w, unet_features): - if hasattr(self, "middle_blocks"): - x = self.middle_blocks(x, w=w) - features = [] - unet_features = iter(reversed(unet_features)) - for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)): - is_last = i == len(self.decoder) - 1 - for skip, block in zip(unet_skip, res_blocks): - skip_x = next(unet_features) - if not isinstance(skip, torch.nn.Identity): - skip_x = skip(skip_x) - x = x + skip_x - x = block(x, w=w) - x = self.decoder_attns[i](x, mask=mask) - features.append(x) - if not is_last: - x = self.upsample(x) - return x, features - - def get_w(self, z, update_emas): - return self.stylenet(z, update_emas=update_emas) - - @torch.no_grad() - def sample(self, truncation_value, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) - - def update_w(self, *args, **kwargs): - self.style_net.update_w(*args, **kwargs) - - @property - def style_net(self): - return self.stylenet - - @torch.no_grad() - def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - 
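        # Multi-modal truncation: instead of pulling w toward the single global average
        # (as `sample` does above), pull it toward one of the pre-computed cluster centers
        # in style_net.w_centers, so different w_indices give distinct yet still truncated
        # outputs.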
if w_indices is None: - w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) - w_centers = self.style_net.w_centers[w_indices].to(w.device) - w = w_centers.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) - - -def get_stem_unet_kwargs(cfg): - if "stem_cfg" in cfg.generator: # If the stem has another stem, recursively apply get_stem_unet_kwargs - return get_stem_unet_kwargs(cfg.generator.stem_cfg) - return dict(cfg.generator) - - -class GrowingUnet(BaseGenerator): - - def __init__( - self, - coarse_stem_cfg: str, # This can be a coarse generator or None - sr_cfg: str, # Can be a previous progressive u-net, Unet or None - residual: bool, - new_dataset: bool, # The "new dataset" creates condition first -> resizes - **unet_kwargs): - kwargs = dict() - if coarse_stem_cfg is not None: - coarse_stem_cfg = utils.load_config(coarse_stem_cfg) - kwargs = get_stem_unet_kwargs(coarse_stem_cfg) - if sr_cfg is not None: - sr_cfg = utils.load_config(sr_cfg) - sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg) - kwargs.update(sr_stem_unet_kwargs) - kwargs.update(unet_kwargs) - kwargs["stylenet"] = None - kwargs.pop("_target_") - if "sr_cfg" in kwargs: # Unet kwargs are inherited, do not pass this to the new u-net - del kwargs["sr_cfg"] - if "coarse_stem_cfg" in kwargs: - del kwargs["coarse_stem_cfg"] - super().__init__(z_channels=kwargs["z_channels"]) - if coarse_stem_cfg is not None: - z_channels = coarse_stem_cfg.generator.z_channels - super().__init__(z_channels) - self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval() - self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize) - utils.set_requires_grad(self.coarse_stem, False) - else: - assert not residual - - if sr_cfg is not None: - self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval() - del self.sr_stem.from_rgb - del self.sr_stem.to_rgb - if hasattr(self.sr_stem, "coarse_stem"): - del self.sr_stem.coarse_stem - if isinstance(self.sr_stem, UNet): - del self.sr_stem.encoder[0][0] # Delete first residual block - del self.sr_stem.decoder[-1][-1] # Delete last residual block - else: - assert isinstance(self.sr_stem, GrowingUnet) - del self.sr_stem.unet.encoder[0][0] # Delete first residual block - del self.sr_stem.unet.decoder[-1][-1] # Delete last residual block - utils.set_requires_grad(self.sr_stem, False) - - - args = kwargs.pop("_args_") - if hasattr(self, "sr_stem"): # Growing the SR stem - Add a new layer to match sr - n_layers = len(kwargs["dim_mults"]) - dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"])) - kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)] - kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False] - kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1] - self.unet = UNet( - *args, - **kwargs - ) - self.from_rgb = self.unet.from_rgb - self.to_rgb = self.unet.to_rgb - self.residual = residual - self.new_dataset = new_dataset - if residual: - nn.init.zeros_(self.to_rgb.weight.data) - del self.unet.from_rgb, self.unet.to_rgb - - def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs): - # Downsample for stem - if z is None: - z = self.get_z(img) - if w is None: - w = self.style_net(z) - if hasattr(self, "coarse_stem"): - with torch.no_grad(): - if self.new_dataset: - img_stem = utils.denormalize_img(img)*255 - condition_stem = img_stem * mask + (1-mask)*127 - condition_stem = 
condition_stem.round() - condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True) - condition_stem = condition_stem / 255 *2 - 1 - mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float() - maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float() - else: - mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float() - maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float() - img_stem = utils.denormalize_img(img)*255 - img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round() - img_stem = img_stem / 255 * 2 - 1 - condition_stem = img_stem * mask_stem - stem_out = self.coarse_stem( - condition=condition_stem, mask=mask_stem, - maskrcnn_mask=maskrcnn_stem, w=w, - keypoints=keypoints) - x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True) - condition = condition*mask + (1-mask) * x_lr - if self.unet.use_maskrcnn_mask: - x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) - else: - x = torch.cat((condition, mask, 1-mask), dim=1) - if self.unet.input_keypoints: - keypoints = keypoints[:, self.unet.input_keypoint_indices] - one_hot_pose = spatial_embed_keypoints(keypoints, x) - x = torch.cat((x, one_hot_pose), dim=1) - x = self.from_rgb(x) - x, unet_features = self.forward_enc(x, mask, w) - x = self.forward_dec(x, mask, w, unet_features) - if self.residual: - x = self.to_rgb(x) + condition - else: - x = self.to_rgb(x) - return dict( - img=condition * mask + (1-mask) * x, - unmasked=x, - x_lowres=[condition] - ) - - def forward_enc(self, x, mask, w): - x, unet_features = self.unet.forward_enc(x, mask, w) - if hasattr(self, "sr_stem"): - x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w) - else: - unet_features_stem = None - return x, [unet_features, unet_features_stem] - - def forward_dec(self, x, mask, w, unet_features): - unet_features, unet_features_stem = unet_features - if hasattr(self, "sr_stem"): - x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem) - x, unet_features = self.unet.forward_dec(x, mask, w, unet_features) - return x - - def get_z(self, *args, **kwargs): - if hasattr(self, "coarse_stem"): - return self.coarse_stem.get_z(*args, **kwargs) - if hasattr(self, "sr_stem"): - return self.sr_stem.get_z(*args, **kwargs) - raise AttributeError() - - @property - def style_net(self): - if hasattr(self, "coarse_stem"): - return self.coarse_stem.style_net - if hasattr(self, "sr_stem"): - return self.sr_stem.style_net - raise AttributeError() - - def update_w(self, *args, **kwargs): - self.style_net.update_w(*args, **kwargs) - - def get_w(self, z, update_emas): - return self.style_net(z, update_emas=update_emas) - - @torch.no_grad() - def sample(self, truncation_value, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) - - @torch.no_grad() - def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - if 
w_indices is None: - w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) - w_centers = self.style_net.w_centers[w_indices].to(w.device) - w = w_centers.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) - - -class CascadedUnet(BaseGenerator): - - def __init__( - self, - coarse_stem_cfg: str, # This can be a coarse generator or None - residual: bool, - new_dataset: bool, # The "new dataset" creates condition first -> resizes - imsize: tuple, - cascade:bool, - **unet_kwargs): - kwargs = dict() - coarse_stem_cfg = utils.load_config(coarse_stem_cfg) - kwargs = get_stem_unet_kwargs(coarse_stem_cfg) - kwargs.update(unet_kwargs) - super().__init__(z_channels=kwargs["z_channels"]) - - self.input_keypoints = kwargs["input_keypoints"] - self.input_keypoint_indices = kwargs["input_keypoint_indices"] - self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"] - self.imsize = imsize - self.residual = residual - self.new_dataset = new_dataset - - - # Setup coarse stem - stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults] - self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval() - self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize) - utils.set_requires_grad(self.coarse_stem, False) - - self.stem_res_to_layer_idx = { - self.coarse_stem.imsize[0] // 2^i: stem_dims[i] - for i in range(len(stem_dims)) - } - - dim = kwargs["dim"] - dim_mults = kwargs["dim_mults"] - n_layers = len(dim_mults) - dims = [dim*s for s in dim_mults] - layer_attn = cast_tuple(kwargs["layer_attn"], n_layers) - num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers) - attn_cls = kwargs["attn_cls"] - if not isinstance(attn_cls, partial): - attn_cls = instantiate(attn_cls) - - dec_blk = partial( - SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"], - use_adain=kwargs["use_adain"] and kwargs["dec_style"], - w_dim=kwargs["w_dim"], - cross_attention=kwargs["cross_attention"], - cross_attention_len=kwargs["cross_attention_len"], - gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"] - ) - enc_blk = partial( - SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"], - use_adain=kwargs["use_adain"] and kwargs["enc_style"], - w_dim=kwargs["w_dim"], - cross_attention=kwargs["cross_attention"], - cross_attention_len=kwargs["cross_attention_len"], - gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"] - ) - - # Currently up/down sampling is done by bilinear upsampling. - # This can be simplified by replacing it with a strided upsampling layer... 
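        # Cascade skips: encoder levels whose resolution matches one of the coarse stem's
        # decoder resolutions (see stem_res_to_layer_idx) receive a 1x1 projection of the
        # stem's features added to their input, letting the refinement network reuse what
        # the low-resolution generator has already synthesized.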
- self.encoder_attns = nn.ModuleList() - self.encoder_unet_skips = nn.ModuleDict() - self.encoder = nn.ModuleList() - for lidx in range(n_layers): - has_stem_feature = imsize[0]//2^lidx in self.stem_res_to_layer_idx and cascade - next_layer_has_stem_features = lidx+1 < n_layers and imsize[0]//2^(lidx+1) in self.stem_res_to_layer_idx and cascade - - dim_in = dims[lidx] - dim_out = dims[min(lidx+1, n_layers-1)] - res_blocks = nn.ModuleList() - if has_stem_feature: - prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1] - stem_lidx = self.stem_res_to_layer_idx[imsize[0]//2^lidx] - self.encoder_unet_skips.add_module( - str(imsize[0]//2^lidx), - Conv2d( - stem_dims[stem_lidx], dim_in, kernel_size=1, - conv_clamp=kwargs["conv_clamp"], - norm=nn.InstanceNorm2d(None), - gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], - gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3) # This + previous residual + attention - ) - ) - for i in range(num_resnet_blocks[lidx]): - is_last = num_resnet_blocks[lidx] - 1 == i - cur_dim = dim_out if is_last else dim_in - if not is_last: - gain = np.sqrt(.5) - elif next_layer_has_stem_features and layer_attn[lidx]: - gain = np.sqrt(1/4) - elif layer_attn[lidx] or next_layer_has_stem_features: - gain = np.sqrt(1/3) - else: - gain = np.sqrt(.5) - block = enc_blk(dim_in, cur_dim, skip_gain=gain) - res_blocks.append(block) - if layer_attn[lidx]: - self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True)) - else: - self.encoder_attns.append(Identity()) - self.encoder.append(res_blocks) - - # initialize decoder - self.decoder = torch.nn.ModuleList() - self.unet_layers = torch.nn.ModuleList() - self.decoder_attns = torch.nn.ModuleList() - for lidx in range(n_layers): - dim_in = dims[min(-lidx, -1)] - dim_out = dims[-1-lidx] - res_blocks = nn.ModuleList() - unet_skips = nn.ModuleList() - for i in range(num_resnet_blocks[-lidx-1]): - is_first = i == 0 - has_unet = is_first or kwargs["skip_all_unets"] - is_last = i == num_resnet_blocks[-lidx-1] - 1 - cur_dim = dim_in if is_first else dim_out - if has_unet and is_last and layer_attn[-lidx-1]: # x + residual + unet + layer attn - gain = np.sqrt(1/4) - elif has_unet: # x + residual + unet - gain = np.sqrt(1/3) - elif layer_attn[-lidx-1]: # x + residual + attention - gain = np.sqrt(1/3) - else: # x + residual - gain = np.sqrt(1/2) # Only residual block - block = dec_blk(cur_dim, dim_out, skip_gain=gain) - res_blocks.append(block) - if kwargs["skip_all_unets"] or is_first: - unet_block = Conv2d( - cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"], - norm=nn.InstanceNorm2d(None), - gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], - gain=gain) - unet_skips.append(unet_block) - else: - unet_skips.append(torch.nn.Identity()) - if layer_attn[-lidx-1]: - self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain)) - else: - self.decoder_attns.append(Identity()) - - self.decoder.append(res_blocks) - self.unet_layers.append(unet_skips) - - self.from_rgb = Conv2d( - 3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"]) - , dim, 7) - self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"]) - - self.downsample = Upfirdn2d(down=2, fix_gain=True) - self.upsample = Upfirdn2d(up=2, fix_gain=True) - self.cascade = cascade - if residual: - nn.init.zeros_(self.to_rgb.weight.data) - - def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, 
keypoints=None, return_decoder_features=False, **kwargs): - # Downsample for stem - if z is None: - z = self.get_z(img) - - with torch.no_grad(): # Forward pass stem - if w is None: - w = self.style_net(z) - img_stem = utils.denormalize_img(img)*255 - condition_stem = img_stem * mask + (1-mask)*127 - condition_stem = condition_stem.round() - condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True) - condition_stem = condition_stem / 255 *2 - 1 - mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float() - maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float() - stem_out = self.coarse_stem( - condition=condition_stem, mask=mask_stem, - maskrcnn_mask=maskrcnn_stem, w=w, - keypoints=keypoints, - return_decoder_features=True) - stem_features = stem_out["decoder_features"] - x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True) - condition = condition*mask + (1-mask) * x_lr - - if self.use_maskrcnn_mask: - x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) - else: - x = torch.cat((condition, mask, 1-mask), dim=1) - if self.input_keypoints: - keypoints = keypoints[:, self.input_keypoint_indices] - one_hot_pose = spatial_embed_keypoints(keypoints, x) - x = torch.cat((x, one_hot_pose), dim=1) - x = self.from_rgb(x) - x, unet_features = self.forward_enc(x, mask, w, stem_features) - x, decoder_features = self.forward_dec(x, mask, w, unet_features) - if self.residual: - x = self.to_rgb(x) + condition - else: - x = self.to_rgb(x) - out= dict( - img=condition * mask + (1-mask) * x, # TODO: Probably do not want masked here... or ?? - unmasked=x, - x_lowres=[condition] - ) - if return_decoder_features: - out["decoder_features"] = decoder_features - return out - - def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]): - unet_features = [] - stem_features.reverse() - for i, res_blocks in enumerate(self.encoder): - is_last = i == len(self.encoder) - 1 - res = self.imsize[0]//2^i - if str(res) in self.encoder_unet_skips.keys() and self.cascade: - y = stem_features[self.stem_res_to_layer_idx[res]] - y = self.encoder_unet_skips[i](y) - x = y + x - for block in res_blocks: - x = block(x, w=w) - unet_features.append(x) - x = self.encoder_attns[i](x, mask) - if not is_last: - x = self.downsample(x) - return x, unet_features - - def forward_dec(self, x, mask, w, unet_features): - features = [] - unet_features = iter(reversed(unet_features)) - for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)): - is_last = i == len(self.decoder) - 1 - for skip, block in zip(unet_skip, res_blocks): - skip_x = next(unet_features) - if not isinstance(skip, torch.nn.Identity): - skip_x = skip(skip_x) - x = x + skip_x - x = block(x, w=w) - x = self.decoder_attns[i](x, mask) - features.append(x) - if not is_last: - x = self.upsample(x) - return x, features - - def get_z(self, *args, **kwargs): - return self.coarse_stem.get_z(*args, **kwargs) - - @property - def style_net(self): - return self.coarse_stem.style_net - - def update_w(self, *args, **kwargs): - self.style_net.update_w(*args, **kwargs) - - def get_w(self, z, update_emas): - return self.style_net(z, update_emas=update_emas) - - @torch.no_grad() - def sample(self, truncation_value, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = 
min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) - - @torch.no_grad() - def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): - if truncation_value is None: - return self.forward(**kwargs) - truncation_value = max(0, truncation_value) - truncation_value = min(truncation_value, 1) - w = self.get_w(self.get_z(kwargs["condition"]), False) - if w_indices is None: - w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) - w_centers = self.style_net.w_centers[w_indices].to(w.device) - w = w_centers.to(w.dtype).lerp(w, truncation_value) - return self.forward(**kwargs, w=w) diff --git a/dp2/generator/stylegan_unet.py b/dp2/generator/stylegan_unet.py deleted file mode 100644 index 68c7f7706d601e4ae2eb80a4b3fd03bc127b2164..0000000000000000000000000000000000000000 --- a/dp2/generator/stylegan_unet.py +++ /dev/null @@ -1,208 +0,0 @@ -import torch -import numpy as np -from dp2.layers import Sequential -from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock -from .base import BaseStyleGAN -from typing import List, Tuple -from .utils import spatial_embed_keypoints, mask_output - - -def get_chsize(imsize, cnum, max_imsize, max_cnum_mul): - n = int(np.log2(max_imsize) - np.log2(imsize)) - mul = min(2**n, max_cnum_mul) - ch = cnum * mul - return int(ch) - -class StyleGANUnet(BaseStyleGAN): - def __init__( - self, - scale_grad: bool, - im_channels: int, - min_fmap_resolution: int, - imsize: List[int], - cnum: int, - max_cnum_mul: int, - mask_output: bool, - conv_clamp: int, - input_cse: bool, - cse_nc: int, - n_middle_blocks: int, - input_keypoints: bool, - n_keypoints: int, - input_keypoint_indices: Tuple[int], - fix_errors: bool, - **kwargs - ) -> None: - super().__init__(**kwargs) - self.n_keypoints = n_keypoints - self.input_keypoint_indices = list(input_keypoint_indices) - self.input_keypoints = input_keypoints - assert not (input_cse and input_keypoints) - cse_nc = 0 if cse_nc is None else cse_nc - self.imsize = imsize - self._cnum = cnum - self._max_cnum_mul = max_cnum_mul - self._min_fmap_resolution = min_fmap_resolution - self._image_channels = im_channels - self._max_imsize = max(imsize) - self.input_cse = input_cse - self.gain_unet = np.sqrt(1/3) - n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1 - encoder_layers = [] - self.from_rgb = Conv2d( - im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices), - cnum, 1 - ) - for i in range(n_levels): # Encoder layers - resolution = [x//2**i for x in imsize] - in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) - second_ch = in_ch - out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul) - down = 2 - - if i == 0: # first (lowest) block. 
Downsampling is performed at the start of the block - down = 1 - if i == n_levels - 1: - out_ch = second_ch - block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors) - encoder_layers.append(block) - self._encoder_out_shape = [ - get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul), - *resolution] - - self.encoder = torch.nn.ModuleList(encoder_layers) - - # initialize decoder - decoder_layers = [] - for i in range(n_levels): - resolution = [x//2**(n_levels-1-i) for x in imsize] - in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul) - out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) - if i == 0: # first (lowest) block - in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) - - up = 1 - if i != n_levels - 1: - up = 2 - block = ResidualBlock( - in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3), - w_dim=self.style_net.w_dim, norm=True, up=up, - fix_residual=fix_errors - ) - decoder_layers.append(block) - if i != 0: - unet_block = Conv2d( - in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True, - gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5)) - setattr(self, f"unet_block{i}", unet_block) - - # Initialize "middle blocks" that do not have down/up sample - middle_blocks = [] - for i in range(n_middle_blocks): - ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul) - block = ResidualBlock( - ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3), - w_dim=self.style_net.w_dim, norm=True, - ) - middle_blocks.append(block) - if n_middle_blocks != 0: - self.middle_blocks = Sequential(*middle_blocks) - self.decoder = torch.nn.ModuleList(decoder_layers) - self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp) - # Initialize "middle blocks" that do not have down/up sample - self.decoder = torch.nn.ModuleList(decoder_layers) - self.scale_grad = scale_grad - self.mask_output = mask_output - - def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs): - for i, layer in enumerate(self.decoder): - if i != 0: - unet_layer = getattr(self, f"unet_block{i}") - x = x + unet_layer(unet_features[-i]) - x = layer(x, w=w, s=s) - x = self.to_rgb(x) - if self.mask_output: - x = mask_output(True, condition, x, mask) - return dict(img=x) - - def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs): - if self.input_cse: - x = torch.cat((condition, mask, embedding, E_mask), dim=1) - else: - x = torch.cat((condition, mask), dim=1) - if self.input_keypoints: - keypoints = keypoints[:, self.input_keypoint_indices] - one_hot_pose = spatial_embed_keypoints(keypoints, x) - x = torch.cat((x, one_hot_pose), dim=1) - x = self.from_rgb(x) - - unet_features = [] - for i, layer in enumerate(self.encoder): - x = layer(x) - if i != len(self.encoder)-1: - unet_features.append(x) - if hasattr(self, "middle_blocks"): - for layer in self.middle_blocks: - x = layer(x) - return x, unet_features - - def forward( - self, condition, mask, - z=None, embedding=None, w=None, update_emas=False, x=None, - s=None, - keypoints=None, - unet_features=None, - E_mask=None, - **kwargs): - # Used to skip sampling from encoder in inference. E.g. for w projection. 
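        # A sketch of the two assumed call modes:
        #   normal use:    G(condition=c, mask=m, embedding=e, ...) runs encoder + decoder
        #   w-projection:  G(condition=c, mask=m, x=enc_x, unet_features=feats, w=w)
        # where enc_x and feats come from a previous forward_enc call, so only the decoder
        # is re-run (eval mode only, as asserted below).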
- if x is not None and unet_features is not None: - assert not self.training - else: - x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs) - if w is None: - if z is None: - z = self.get_z(condition) - w = self.get_w(z, update_emas=update_emas) - return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs) - -class ComodStyleUNet(StyleGANUnet): - - def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None: - super().__init__(**kwargs) - min_fmap = min(self._encoder_out_shape[1:]) - enc_out_ch = self._encoder_out_shape[0] - n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res))) - comod_layers = [] - in_ch = enc_out_ch - for i in range(n_down): - comod_layers.append(Conv2d(enc_out_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod)) - in_ch = 256 - if n_down == 0: - comod_layers = [Conv2d(in_ch, 256, kernel_size=3)] - comod_layers.append(torch.nn.Flatten()) - out_res = [x//2**n_down for x in self._encoder_out_shape[1:]] - in_ch_fc = np.prod(out_res) * 256 - comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod)) - self.comod_block = Sequential(*comod_layers) - self.comod_fc = FullyConnectedLayer(512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod) - - def forward_dec(self, x, w, unet_features, condition, mask, **kwargs): - y = self.comod_block(x) - y = torch.cat((y, w), dim=1) - y = self.comod_fc(y) - for i, layer in enumerate(self.decoder): - if i != 0: - unet_layer = getattr(self, f"unet_block{i}") - x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5)) - x = layer(x, w=y) - x = self.to_rgb(x) - if self.mask_output: - x = mask_output(True, condition, x, mask) - return dict(img=x) - - def get_comod_y(self, batch, w): - x, unet_features = self.forward_enc(**batch) - y = self.comod_block(x) - y = torch.cat((y, w), dim=1) - y = self.comod_fc(y) - return y diff --git a/dp2/generator/utils.py b/dp2/generator/utils.py deleted file mode 100644 index 5732b2c511a42f4bffd4b512244cf790bec96ef0..0000000000000000000000000000000000000000 --- a/dp2/generator/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import tops -import torch -from torch.cuda.amp import custom_bwd, custom_fwd - - -@torch.no_grad() -def spatial_embed_keypoints(keypoints: torch.Tensor, x): - tops.assert_shape(keypoints, (None, None, 3)) - B, N_K, _ = keypoints.shape - H, W = x.shape[-2:] - keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32) - x, y, visible = keypoints.chunk(3, dim=2) - x = (x * W).round().long().clamp(0, W-1) - y = (y * H).round().long().clamp(0, H-1) - kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1) - pos = (kp_idx*(H*W) + y*W + x + 1) - # Offset all by 1 to index invisible keypoints as 0 - pos = (pos * visible.round().long()).squeeze(dim=-1) - keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32) - keypoint_spatial.scatter_(1, pos, 1) - keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W) - return keypoint_spatial - -class MaskOutput(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, x_real, x_fake, mask): - ctx.save_for_backward(mask) - out = x_real * mask + (1-mask) * x_fake - return out - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - fake_grad = grad_output - mask, = ctx.saved_tensors - fake_grad = fake_grad * (1 - mask) - 
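        # Rescale so the gradient magnitude reaching the generator does not depend on how
        # much of the image is inpainted: dividing by the generated fraction
        # (1 - known_percentage) compensates for the known pixels zeroed out above.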
known_percentage = mask.view(mask.shape[0], -1).mean(dim=1) - fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1) - return None, fake_grad, None - - -def mask_output(scale_grad, x_real, x_fake, mask): - if scale_grad: - return MaskOutput.apply(x_real, x_fake, mask) - return x_real * mask + (1-mask) * x_fake diff --git a/dp2/infer.py b/dp2/infer.py deleted file mode 100644 index ddbac8dbfd0914979575a0799c10b4eeb00aff73..0000000000000000000000000000000000000000 --- a/dp2/infer.py +++ /dev/null @@ -1,72 +0,0 @@ -import tops -import torch -from tops import checkpointer -from tops.config import instantiate -from tops.logger import warn - -def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None): - state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"] - if ckpt_mapper is not None: - state = ckpt_mapper(state) - load_state_dict(G, state) - tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M") - print(ckpt.keys()) - if "w_centers" in ckpt: - print("Has w_centers!") - G.style_net.w_centers = ckpt["w_centers"] - tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}") - - -def build_trained_generator(cfg, map_location=None): - map_location = map_location if map_location is not None else tops.get_device() - G = instantiate(cfg.generator).to(map_location) - G.eval() - G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None - if hasattr(cfg, "ckpt_mapper"): - ckpt_mapper = instantiate(cfg.ckpt_mapper) - else: - ckpt_mapper = None - if "model_url" in cfg.common: - ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum, map_location=torch.device("cpu")) - load_generator_state(ckpt, G, ckpt_mapper) - return G - try: - ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu") - load_generator_state(ckpt, G, ckpt_mapper) - except FileNotFoundError as e: - tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}") - return G - - -def build_trained_discriminator(cfg, map_location=None): - map_location = map_location if map_location is not None else tops.get_device() - D = instantiate(cfg.discriminator).to(map_location) - D.eval() - try: - ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu") - if hasattr(cfg, "ckpt_mapper_D"): - ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"]) - D.load_state_dict(ckpt["discriminator"]) - except FileNotFoundError as e: - tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}") - return D - - -def load_state_dict(module: torch.nn.Module, state_dict: dict): - module_sd = module.state_dict() - to_remove = [] - for key, item in state_dict.items(): - if key not in module_sd: - continue - if item.shape != module_sd[key].shape: - to_remove.append(key) - warn(f"Incorrect shape. 
Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}") - for key in to_remove: - state_dict.pop(key) - for key, item in state_dict.items(): - if key not in module_sd: - warn(f"Did not fin key in model state dict: {key}") - for key, item in module_sd.items(): - if key not in state_dict: - warn(f"Did not find key in state dict: {key}") - module.load_state_dict(state_dict, strict=False) \ No newline at end of file diff --git a/dp2/layers/__init__.py b/dp2/layers/__init__.py deleted file mode 100644 index 4f54bed21f71c7facaa8129e1b689f16c5f9d13d..0000000000000000000000000000000000000000 --- a/dp2/layers/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Dict -import torch -import tops -import torch.nn as nn - -class Sequential(nn.Sequential): - - def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: - for module in self: - x = module(x, **kwargs) - return x - -class Module(nn.Module): - - def __init__(self, *args, **kwargs): - super().__init__() - - def extra_repr(self): - num_params = tops.num_parameters(self) / 10**6 - return f"Num params: {num_params:.3f}M" diff --git a/dp2/layers/sg2_layers.py b/dp2/layers/sg2_layers.py deleted file mode 100644 index 3aac03935d0cb6ec172f2cd0aabefbdc74d0141e..0000000000000000000000000000000000000000 --- a/dp2/layers/sg2_layers.py +++ /dev/null @@ -1,227 +0,0 @@ -from typing import List -import numpy as np -import torch -import tops -import torch.nn.functional as F -from sg3_torch_utils.ops import conv2d_resample -from sg3_torch_utils.ops import upfirdn2d -from sg3_torch_utils.ops import bias_act -from sg3_torch_utils.ops.fma import fma - - -class FullyConnectedLayer(torch.nn.Module): - def __init__(self, - in_features, # Number of input features. - out_features, # Number of output features. - bias = True, # Apply additive bias before the activation function? - activation = 'linear', # Activation function: 'relu', 'lrelu', etc. - lr_multiplier = 1, # Learning rate multiplier. - bias_init = 0, # Initial value for the additive bias. - ): - super().__init__() - self.repr = dict( - in_features=in_features, out_features=out_features, bias=bias, - activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init) - self.activation = activation - self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) - self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None - self.weight_gain = lr_multiplier / np.sqrt(in_features) - self.bias_gain = lr_multiplier - self.in_features = in_features - self.out_features = out_features - - def forward(self, x): - w = self.weight * self.weight_gain - b = self.bias - if b is not None and self.bias_gain != 1: - b = b * self.bias_gain - x = F.linear(x, w) - x = bias_act.bias_act(x, b, act=self.activation) - return x - - def extra_repr(self) -> str: - return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) - - -class Conv2d(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - kernel_size = 3, # Convolution kernel size. - up = 1, # Integer upsampling factor. - down = 1, # Integer downsampling factor - activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. - resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. - conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
- bias = True, - norm = False, - lr_multiplier=1, - bias_init=0, - w_dim=None, - gain=1, - ): - super().__init__() - if norm: - self.norm = torch.nn.InstanceNorm2d(None) - assert norm in [True, False] - self.up = up - self.down = down - self.activation = activation - self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain - self.out_channels = out_channels - self.in_channels = in_channels - self.padding = kernel_size // 2 - - self.repr = dict( - in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, up=up, down=down, - activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias, - ) - - if self.up == 1 and self.down == 1: - self.resample_filter = None - else: - self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) - - self.act_gain = bias_act.activation_funcs[activation].def_gain * gain - self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2)) - self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])) - self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None - self.bias_gain = lr_multiplier - if w_dim is not None: - self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) - self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0) - - def forward(self, x, w=None, s=None): - tops.assert_shape(x, [None, self.weight.shape[1], None, None]) - if s is not None: - s = s[..., :self.in_channels*2] - gamma, beta = s.view(-1, self.in_channels*2, 1, 1).chunk(2, dim=1) - x = fma(x, gamma, beta) - elif hasattr(self, "affine"): - gamma = self.affine(w).view(-1, self.in_channels, 1, 1) - beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1) - x = fma(x, gamma, beta) - w = self.weight * self.weight_gain - # Removing flip weight is not safe. - x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up, self.down, self.padding, flip_weight=self.up==1) - if hasattr(self, "norm"): - x = self.norm(x) - b = self.bias * self.bias_gain if self.bias is not None else None - x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp) - return x - - def extra_repr(self) -> str: - return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) - - -class Block(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels, 0 = first block. - out_channels, # Number of output channels. - conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. - up = 1, - down = 1, - **layer_kwargs, # Arguments for SynthesisLayer. - ): - super().__init__() - self.in_channels = in_channels - self.down = down - self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs) - self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs) - - def forward(self, x, **layer_kwargs): - x = self.conv0(x, **layer_kwargs) - x = self.conv1(x, **layer_kwargs) - return x - - -class ResidualBlock(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels, 0 = first block. - out_channels, # Number of output channels. - conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. - up = 1, - down = 1, - gain_out=np.sqrt(0.5), - fix_residual: bool = False, - **layer_kwargs, # Arguments for conv layer. 
- ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.down = down - self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs) - - self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out,**layer_kwargs) - - self.skip = Conv2d( - in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down, - activation="linear" if fix_residual else "lrelu", - gain=gain_out - ) - self.gain_out = gain_out - - def forward(self, x, w=None, s=None, **layer_kwargs): - y = self.skip(x) - s_ = next(s) if s is not None else None - x = self.conv0(x, w, s=s_, **layer_kwargs) - s_ = next(s) if s is not None else None - x = self.conv1(x, w, s=s_, **layer_kwargs) - x = y + x - return x - - -class MinibatchStdLayer(torch.nn.Module): - def __init__(self, group_size, num_channels=1): - super().__init__() - self.group_size = group_size - self.num_channels = num_channels - - def forward(self, x): - N, C, H, W = x.shape - with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants - G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N - F = self.num_channels - c = C // F - - y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. - y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. - y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. - y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. - y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. - y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. - y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. - x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. - return x - -#---------------------------------------------------------------------------- - -class DiscriminatorEpilogue(torch.nn.Module): - def __init__(self, - in_channels, # Number of input channels. - resolution: List[int], # Resolution of this block. - mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. - mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. - activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. - conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. - ): - super().__init__() - self.in_channels = in_channels - self.resolution = resolution - self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None - self.conv = Conv2d( - in_channels + mbstd_num_channels, in_channels, - kernel_size=3, activation=activation, conv_clamp=conv_clamp) - self.fc = FullyConnectedLayer(in_channels * resolution[0]*resolution[1], in_channels, activation=activation) - self.out = FullyConnectedLayer(in_channels, 1) - - def forward(self, x): - tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW] - # Main layers. 
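        # MinibatchStdLayer (defined above) appends a single extra feature channel holding
        # the per-group standard deviation of the activations, averaged over channels and
        # pixels. Feeding the discriminator this batch statistic is the usual trick for
        # discouraging low sample diversity (mode collapse).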
- if self.mbstd is not None: - x = self.mbstd(x) - x = self.conv(x) - x = self.fc(x.flatten(1)) - x = self.out(x) - return x diff --git a/dp2/loss/__init__.py b/dp2/loss/__init__.py deleted file mode 100644 index 79165e22915aeeebabbcddebcccb281deedb40c2..0000000000000000000000000000000000000000 --- a/dp2/loss/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sg2_loss import StyleGAN2Loss \ No newline at end of file diff --git a/dp2/loss/pl_regularization.py b/dp2/loss/pl_regularization.py deleted file mode 100644 index 3a557a8b92cbcb2572e8f25633e7c89e996ef906..0000000000000000000000000000000000000000 --- a/dp2/loss/pl_regularization.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import tops -import numpy as np -from sg3_torch_utils.ops import conv2d_gradfix - -pl_mean_total = torch.zeros([]) - -class PLRegularization: - - def __init__(self, weight: float, batch_shrink: int, pl_decay:float, scale_by_mask: bool,**kwargs): - self.pl_mean = torch.zeros([], device=tops.get_device()) - self.pl_weight = weight - self.batch_shrink = batch_shrink - self.pl_decay = pl_decay - self.scale_by_mask = scale_by_mask - - def __call__(self, G, batch, grad_scaler): - batch_size = batch["img"].shape[0] // self.batch_shrink - batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"} - if "embed_map" in batch: - batch["embed_map"] = batch["embed_map"] - z = G.get_z(batch["img"]) - - with torch.cuda.amp.autocast(tops.AMP()): - gen_ws = G.style_net(z) - gen_img = G(**batch, w=gen_ws)["img"].float() - pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) - with conv2d_gradfix.no_weight_gradients(): - # Sums over HWC - pl_grads = torch.autograd.grad( - outputs=[grad_scaler.scale(gen_img * pl_noise)], - inputs=[gen_ws], - create_graph=True, - grad_outputs=torch.ones_like(gen_img), - only_inputs=True)[0] - - pl_grads = pl_grads.float() / grad_scaler.get_scale() - if self.scale_by_mask: - # Percentage of pixels known - scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1) - pl_grads = pl_grads / scaling - pl_lengths = pl_grads.square().sum(1).sqrt() - pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) - if not torch.isnan(pl_mean).any(): - self.pl_mean.copy_(pl_mean.detach()) - pl_penalty = (pl_lengths - pl_mean).square() - to_log = dict(pl_penalty=pl_penalty.mean().detach()) - return pl_penalty.view(-1) * self.pl_weight, to_log \ No newline at end of file diff --git a/dp2/loss/r1_regularization.py b/dp2/loss/r1_regularization.py deleted file mode 100644 index 2098d792a591e4652259703085abe9d0bc55489b..0000000000000000000000000000000000000000 --- a/dp2/loss/r1_regularization.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import tops - -def r1_regularization( - real_img, real_score, mask, lambd: float, lazy_reg_interval: int, - lazy_regularization: bool, - scaler: torch.cuda.amp.GradScaler, mask_out: bool, - mask_out_scale: bool, - **kwargs - ): - grad = torch.autograd.grad( - outputs=scaler.scale(real_score), - inputs=real_img, - grad_outputs=torch.ones_like(real_score), - create_graph=True, - only_inputs=True, - )[0] - inv_scale = 1.0 / scaler.get_scale() - grad = grad * inv_scale - with torch.cuda.amp.autocast(tops.AMP()): - if mask_out: - grad = grad * (1 - mask) - grad = grad.square().sum(dim=[1, 2, 3]) - if mask_out and mask_out_scale: - total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3] - n_fake = (1-mask).sum(dim=[1, 2, 3]) - scaling = total_pixels / n_fake - grad = grad * scaling - if lazy_regularization: - lambd_ = 
lambd * lazy_reg_interval / 2 # From stylegan2, lazy regularization - return grad * lambd_, grad.detach() \ No newline at end of file diff --git a/dp2/loss/sg2_loss.py b/dp2/loss/sg2_loss.py deleted file mode 100644 index 4a0d1150a66bbff4d7f408d9b68612adac179cdb..0000000000000000000000000000000000000000 --- a/dp2/loss/sg2_loss.py +++ /dev/null @@ -1,94 +0,0 @@ -import functools -import torch -import tops -from tops import logger -from dp2.utils import forward_D_fake -from .utils import nsgan_d_loss, nsgan_g_loss -from .r1_regularization import r1_regularization -from .pl_regularization import PLRegularization - -class StyleGAN2Loss: - - def __init__( - self, - D, - G, - r1_opts: dict, - EP_lambd: float, - lazy_reg_interval: int, - lazy_regularization: bool, - pl_reg_opts: dict, - ) -> None: - self.gradient_step_D = 0 - self._lazy_reg_interval = lazy_reg_interval - self.D = D - self.G = G - self.EP_lambd = EP_lambd - self.lazy_regularization = lazy_regularization - self.r1_reg = functools.partial( - r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval, - lazy_regularization=lazy_regularization) - self.do_PL_Reg = False - if pl_reg_opts.weight > 0: - self.pl_reg = PLRegularization(**pl_reg_opts) - self.do_PL_Reg = True - self.pl_start_nimg = pl_reg_opts.start_nimg - - def D_loss(self, batch: dict, grad_scaler): - to_log = {} - # Forward through G and D - do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0 - if do_GP: - batch["img"] = batch["img"].detach().requires_grad_(True) - with torch.cuda.amp.autocast(enabled=tops.AMP()): - with torch.no_grad(): - G_fake = self.G(**batch, update_emas=True) - D_out_real = self.D(**batch) - - D_out_fake = forward_D_fake(batch, G_fake["img"], self.D) - - # Non saturating loss - nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"]) - tops.assert_shape(nsgan_loss, (batch["img"].shape[0], )) - to_log["d_loss"] = nsgan_loss.mean() - total_loss = nsgan_loss - epsilon_penalty = D_out_real["score"].pow(2).view(-1) - to_log["epsilon_penalty"] = epsilon_penalty.mean() - tops.assert_shape(epsilon_penalty, total_loss.shape) - total_loss = total_loss + epsilon_penalty * self.EP_lambd - - # Improved gradient penalty with lazy regularization - # Gradient penalty applies specialized autocast. 
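        # r1_reg (see r1_regularization.py above) differentiates the real score w.r.t. the
        # real image, squares and sums the gradient over C, H and W, and, because the
        # penalty is only evaluated every lazy_reg_interval steps, scales it by
        # lambd * lazy_reg_interval / 2 so its effective strength stays unchanged.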
- if do_GP: - gradient_pen, grad_unscaled = self.r1_reg(batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler) - to_log["r1_gradient_penalty"] = grad_unscaled.mean() - tops.assert_shape(gradient_pen, total_loss.shape) - total_loss = total_loss + gradient_pen - - batch["img"] = batch["img"].detach().requires_grad_(False) - if "score" in D_out_real: - to_log["real_scores"] = D_out_real["score"] - to_log["real_logits_sign"] = D_out_real["score"].sign() - to_log["fake_logits_sign"] = D_out_fake["score"].sign() - to_log["fake_scores"] = D_out_fake["score"] - to_log = {key: item.mean().detach() for key, item in to_log.items()} - self.gradient_step_D += 1 - return total_loss.mean(), to_log - - def G_loss(self, batch: dict, grad_scaler): - with torch.cuda.amp.autocast(enabled=tops.AMP()): - to_log = {} - # Forward through G and D - G_fake = self.G(**batch) - D_out_fake = forward_D_fake(batch, G_fake["img"], self.D) - # Adversarial Loss - total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1) - to_log["g_loss"] = total_loss.mean() - tops.assert_shape(total_loss, (batch["img"].shape[0], )) - - if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg: - pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler) - total_loss = total_loss + pl_reg.mean() - to_log.update(to_log_) - to_log = {key: item.mean().detach() for key, item in to_log.items()} - return total_loss.mean(), to_log diff --git a/dp2/loss/utils.py b/dp2/loss/utils.py deleted file mode 100644 index 1c9fe96156f40ffcc08ad84652d50327263db2fd..0000000000000000000000000000000000000000 --- a/dp2/loss/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import torch.nn.functional as F - -def nsgan_g_loss(fake_score): - """ - Non-saturating criterion from Goodfellow et al. 2014 - """ - return torch.nn.functional.softplus(-fake_score) - - -def nsgan_d_loss(real_score, fake_score): - """ - Non-saturating criterion from Goodfellow et al. 
2014 - """ - d_loss = F.softplus(-real_score) + F.softplus(fake_score) - return d_loss.view(-1) - - -def smooth_masked_l1_loss(x, target, mask): - """ - Pixel-wise l1 loss for the area indicated by mask - """ - # Beta=.1 <-> square loss if pixel difference <= 12.8 - l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1,2,3]) / mask.sum(dim=[1, 2, 3]) - return l1 diff --git a/dp2/metrics/__init__.py b/dp2/metrics/__init__.py deleted file mode 100644 index fc224b42cc5ceeaf2e68d6371ab9da1dade557ed..0000000000000000000000000000000000000000 --- a/dp2/metrics/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .torch_metrics import compute_metrics_iteratively -from .fid import compute_fid -from .ppl import calculate_ppl \ No newline at end of file diff --git a/dp2/metrics/fid.py b/dp2/metrics/fid.py deleted file mode 100644 index 66eb5e0060d60294c4cdf80254e583cf8fad8bc2..0000000000000000000000000000000000000000 --- a/dp2/metrics/fid.py +++ /dev/null @@ -1,72 +0,0 @@ -import tops -from dp2 import utils -from pathlib import Path -from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper -import torch -import torch_fidelity - - -class GeneratorIteratorWrapper(GenerativeModelModuleWrapper): - - def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int): - if isinstance(generator, utils.EMA): - generator = generator.generator - z_size = generator.z_channels - super().__init__(generator, z_size, "normal", 0) - self.zero_z = zero_z - self.dataloader = iter(dataloader) - self.n_diverse = n_diverse - self.cur_div_idx = 0 - - @torch.no_grad() - def forward(self, z, **kwargs): - if self.cur_div_idx == 0: - self.batch = next(self.dataloader) - if self.zero_z: - z = z.zero_() - self.cur_div_idx += 1 - self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx - with torch.cuda.amp.autocast(enabled=tops.AMP()): - img = self.module(**self.batch)["img"] - img = (utils.denormalize_img(img)*255).byte() - return img - - -def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse): - generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse) - batch_size = dataloader.batch_size - num_samples = (n_source * n_diverse) // batch_size * batch_size - assert n_diverse >= 1 - assert (not zero_z) or n_diverse == 1 - assert num_samples % batch_size == 0 - assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse) - metrics = torch_fidelity.calculate_metrics( - input1=generator, - input2=real_directory, - cuda=torch.cuda.is_available(), - fid=True, - input2_cache_name="_".join(Path(real_directory).parts) + "_cached", - input1_model_num_samples=int(num_samples), - batch_size=dataloader.batch_size - ) - return metrics["frechet_inception_distance"] - - -if __name__ == "__main__": - import click - from dp2.config import Config - from dp2.data import build_dataloader_val - from dp2.infer import build_trained_generator - @click.command() - @click.argument("config_path") - @click.option("--n_source", default=200, type=int) - @click.option("--n_diverse", default=5, type=int) - @click.option("--zero_z", default=False, is_flag=True) - def run(config_path, n_source: int, n_diverse: int, zero_z: bool): - cfg = Config.fromfile(config_path) - dataloader = build_dataloader_val(cfg) - generator, _ = build_trained_generator(cfg) - print(compute_fid( - generator, dataloader, cfg.fid_real_directory, n_source, zero_z, n_diverse)) - - run() \ No newline at end of file diff --git 
a/dp2/metrics/fid_clip.py b/dp2/metrics/fid_clip.py deleted file mode 100644 index 6712fd48503b787bc8e2197a6a737bcd73546b35..0000000000000000000000000000000000000000 --- a/dp2/metrics/fid_clip.py +++ /dev/null @@ -1,84 +0,0 @@ -import pickle -import torch -import torchvision -from pathlib import Path -from dp2 import utils -import tops -try: - import clip -except ImportError: - print("Could not import clip.") -from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric -clip_model = None -clip_preprocess = None - - -@torch.no_grad() -def compute_fid_clip( - dataloader, generator, - cache_directory, - data_len=None, - **kwargs - ) -> dict: - """ - FID CLIP following the description in The Role of ImageNet Classes in Frechet Inception Distance, Thomas Kynkaamniemi et al. - Args: - n_samples (int): Creates N samples from same image to calculate stats - """ - global clip_model, clip_preprocess - if clip_model is None: - clip_model, preprocess = clip.load("ViT-B/32", device="cpu") - normalize_fn = preprocess.transforms[-1] - img_mean = normalize_fn.mean - img_std = normalize_fn.std - clip_model = tops.to_cuda(clip_model.visual) - clip_preprocess = tops.to_cuda(torch.nn.Sequential( - torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.Normalize(img_mean, img_std) - )) - cache_directory = Path(cache_directory) - if data_len is None: - data_len = len(dataloader)*dataloader.batch_size - fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl") - has_fid_cache = fid_cache_path.is_file() - if not has_fid_cache: - fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device()) - fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device()) - eidx = 0 - n_samples_seen = 0 - for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."): - sidx = eidx - eidx = sidx + batch["img"].shape[0] - n_samples_seen += batch["img"].shape[0] - with torch.cuda.amp.autocast(tops.AMP()): - fakes = generator(**batch)["img"] - real_data = batch["img"] - fakes = utils.denormalize_img(fakes) - real_data = utils.denormalize_img(real_data) - if not has_fid_cache: - real_data = clip_preprocess(real_data) - fid_features_real[sidx:eidx] = clip_model(real_data) - fakes = clip_preprocess(fakes) - fid_features_fake[sidx:eidx] = clip_model(fakes) - fid_features_fake = fid_features_fake[:n_samples_seen] - fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu() - if has_fid_cache: - if tops.rank() == 0: - with open(fid_cache_path, "rb") as fp: - fid_stat_real = pickle.load(fp) - else: - fid_features_real = fid_features_real[:n_samples_seen] - fid_features_real = tops.all_gather_uneven(fid_features_real).cpu() - assert fid_features_real.shape == fid_features_fake.shape - if tops.rank() == 0: - fid_stat_real = fid_features_to_statistics(fid_features_real) - cache_directory.mkdir(exist_ok=True, parents=True) - with open(fid_cache_path, "wb") as fp: - pickle.dump(fid_stat_real, fp) - - if tops.rank() == 0: - print("Starting calculation of fid from features of shape:", fid_features_fake.shape) - fid_stat_fake = fid_features_to_statistics(fid_features_fake) - fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"] - return dict(fid_clip=fid_) - return dict(fid_clip=-1) diff --git a/dp2/metrics/lpips.py b/dp2/metrics/lpips.py deleted file mode 100644 index 
cfda315da83fab7bb1bf5e8d1b09f3a721ded064..0000000000000000000000000000000000000000 --- a/dp2/metrics/lpips.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import tops -import sys -from contextlib import redirect_stdout -from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average - -class SampleSimilarityLPIPS(torch.nn.Module): - SUPPORTED_DTYPES = { - 'uint8': torch.uint8, - 'float32': torch.float32, - } - - def __init__(self): - - super().__init__() - self.chns = [64, 128, 256, 512, 512] - self.L = len(self.chns) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=True) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=True) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=True) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=True) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=True) - self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] - with redirect_stdout(sys.stderr): - fp = tops.download_file(URL_VGG16_LPIPS) - state_dict = torch.load(fp, map_location="cpu") - self.load_state_dict(state_dict) - self.net = VGG16features() - self.eval() - for param in self.parameters(): - param.requires_grad = False - mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2 - inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255) - self.register_buffer("mean", mean_rescaled) - self.register_buffer("std", inv_std_rescaled) - - def normalize(self, x): - # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] - x = (x.float() - self.mean) * self.std - return x - - @staticmethod - def resize(x, size): - if x.shape[-1] > size and x.shape[-2] > size: - x = torch.nn.functional.interpolate(x, (size, size), mode='area') - else: - x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False) - return x - - def lpips_from_feats(self, feats0, feats1): - diffs = {} - for kk in range(self.L): - diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - - res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)] - val = sum(res) - return val - - def get_feats(self, x): - assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW' - if x.shape[-2] < 16: # Resize images < 16x16 - f = 16 / x.shape[-2] - size = tuple([int(f*_) for _ in x.shape[-2:]]) - x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False) - in0_input = self.normalize(x) - outs0 = self.net.forward(in0_input) - - feats = {} - for kk in range(self.L): - feats[kk] = normalize_tensor(outs0[kk]) - return feats - - def forward(self, in0, in1): - feats0 = self.get_feats(in0) - feats1 = self.get_feats(in1) - return self.lpips_from_feats(feats0, feats1), feats0, feats1 diff --git a/dp2/metrics/ppl.py b/dp2/metrics/ppl.py deleted file mode 100644 index 3d30b220546bcd8e44c36eede56361868e26879a..0000000000000000000000000000000000000000 --- a/dp2/metrics/ppl.py +++ /dev/null @@ -1,110 +0,0 @@ -import numpy as np -import torch -import tops -from dp2 import utils -from torch_fidelity.helpers import get_kwarg, vassert -from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS -from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity - - -def slerp(a, b, t): - a = a / a.norm(dim=-1, keepdim=True) - b = b / b.norm(dim=-1, keepdim=True) - d = (a * b).sum(dim=-1, keepdim=True) - p = t * torch.acos(d) - c = b - d * a - c = c / c.norm(dim=-1, keepdim=True) - d = a * 
torch.cos(p) + c * torch.sin(p) - d = d / d.norm(dim=-1, keepdim=True) - return d - - -@torch.no_grad() -def calculate_ppl( - dataloader, - generator, - latent_space=None, - data_len=None, - **kwargs) -> dict: - """ - Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py - """ - if latent_space is None: - latent_space = generator.latent_space - assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}" - epsilon = PPL_DEFAULTS["ppl_epsilon"] - interp = PPL_DEFAULTS['ppl_z_interp_mode'] - similarity_name = PPL_DEFAULTS['ppl_sample_similarity'] - sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize'] - sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype'] - discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower'] - discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher'] - - vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number') - vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile') - vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile') - if discard_percentile_lower is not None and discard_percentile_higher is not None: - vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles') - - sample_similarity = create_sample_similarity( - similarity_name, - sample_similarity_resize=sample_similarity_resize, - sample_similarity_dtype=sample_similarity_dtype, - cuda=False, - **kwargs - ) - sample_similarity = tops.to_cuda(sample_similarity) - rng = np.random.RandomState(get_kwarg('rng_seed', kwargs)) - distances = [] - if data_len is None: - data_len = len(dataloader) * dataloader.batch_size - z0 = sample_random(rng, (data_len, generator.z_channels), "normal") - z1 = sample_random(rng, (data_len, generator.z_channels), "normal") - if latent_space == "Z": - z1 = batch_interp(z0, z1, epsilon, interp) - - distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device()) - print(distances.shape) - end = 0 - n_samples = 0 - for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")): - start = end - end = start + batch["img"].shape[0] - n_samples += batch["img"].shape[0] - batch_lat_e0 = tops.to_cuda(z0[start:end]) - batch_lat_e1 = tops.to_cuda(z1[start:end]) - if latent_space == "W": - w0 = generator.get_w(batch_lat_e0, update_emas=False) - w1 = generator.get_w(batch_lat_e1, update_emas=False) - w1 = w0.lerp(w1, epsilon) # PPL end - rgb1 = generator(**batch, w=w0)["img"] - rgb2 = generator(**batch, w=w1)["img"] - else: - rgb1 = generator(**batch, z=batch_lat_e0)["img"] - rgb2 = generator(**batch, z=batch_lat_e1)["img"] - rgb1 = utils.denormalize_img(rgb1).mul(255).byte() - rgb2 = utils.denormalize_img(rgb2).mul(255).byte() - - sim = sample_similarity(rgb1, rgb2) - dist_lat_e01 = sim / (epsilon ** 2) - distances[start:end] = dist_lat_e01.view(-1) - distances = distances[:n_samples] - distances = tops.all_gather_uneven(distances).cpu().numpy() - if tops.rank() != 0: - return {"ppl/mean": -1, "ppl/std": -1} - if tops.rank() == 0: - cond, lo, hi = None, None, None - if discard_percentile_lower is not None: - lo = np.percentile(distances, discard_percentile_lower, interpolation='lower') - cond = lo <= distances - if discard_percentile_higher is not None: - hi = np.percentile(distances, discard_percentile_higher, interpolation='higher') - cond = np.logical_and(cond, 
distances <= hi) - if cond is not None: - distances = np.extract(cond, distances) - return { - "ppl/mean": float(np.mean(distances)), - "ppl/std": float(np.std(distances)), - } - else: - return {"ppl/mean"} diff --git a/dp2/metrics/torch_metrics.py b/dp2/metrics/torch_metrics.py deleted file mode 100644 index a6682afbbbe9e5205d48743ac39db8c647607666..0000000000000000000000000000000000000000 --- a/dp2/metrics/torch_metrics.py +++ /dev/null @@ -1,176 +0,0 @@ -import pickle -import numpy as np -import torch -import time -from pathlib import Path -from dp2 import utils -import tops -from .lpips import SampleSimilarityLPIPS -from torch_fidelity.defaults import DEFAULTS as trf_defaults -from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric -from torch_fidelity.utils import create_feature_extractor -lpips_model = None -fid_model = None - -@torch.no_grad() -def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: - se = (images1 - images2) ** 2 - se = se.view(images1.shape[0], -1).mean(dim=1) - return se - -@torch.no_grad() -def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: - mse_ = mse(images1, images2) - psnr = 10 * torch.log10(1 / mse_) - return psnr - -@torch.no_grad() -def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: - return _lpips_w_grad(images1, images2) - - -def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: - global lpips_model - if lpips_model is None: - lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) - - images1 = images1.mul(255) - images2 = images2.mul(255) - with torch.cuda.amp.autocast(tops.AMP()): - dists = lpips_model(images1, images2)[0].view(-1) - return dists - - - - -@torch.no_grad() -def compute_metrics_iteratively( - dataloader, generator, - cache_directory, - data_len=None, - truncation_value: float=None, - ) -> dict: - """ - Args: - n_samples (int): Creates N samples from same image to calculate stats - dataset_percentage (float): The percentage of the dataset to compute metrics on. 
- """ - - global lpips_model, fid_model - if lpips_model is None: - lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) - if fid_model is None: - fid_model = create_feature_extractor( - trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False) - fid_model = tops.to_cuda(fid_model) - cache_directory = Path(cache_directory) - start_time = time.time() - lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device()) - diversity_total = torch.zeros_like(lpips_total) - fid_cache_path = cache_directory.joinpath("fid_stats.pkl") - has_fid_cache = fid_cache_path.is_file() - if data_len is None: - data_len = len(dataloader)*dataloader.batch_size - if not has_fid_cache: - fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device()) - fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device()) - n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device()) - eidx = 0 - for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"): - sidx = eidx - eidx = sidx + batch["img"].shape[0] - n_samples_seen += batch["img"].shape[0] - with torch.cuda.amp.autocast(tops.AMP()): - fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"] - fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"] - fakes1 = utils.denormalize_img(fakes1).mul(255) - fakes2 = utils.denormalize_img(fakes2).mul(255) - real_data = utils.denormalize_img(batch["img"]).mul(255) - lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1) - fake2_lpips_feats = lpips_model.get_feats(fakes2) - lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats) - - lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2) - diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum() - if not has_fid_cache: - fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0] - fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0] - fid_features_fake = fid_features_fake[:n_samples_seen] - if has_fid_cache: - if tops.rank() == 0: - with open(fid_cache_path, "rb") as fp: - fid_stat_real = pickle.load(fp) - else: - fid_features_real = fid_features_real[:n_samples_seen] - fid_features_real = tops.all_gather_uneven(fid_features_real).cpu() - if tops.rank() == 0: - fid_stat_real = fid_features_to_statistics(fid_features_real) - cache_directory.mkdir(exist_ok=True, parents=True) - with open(fid_cache_path, "wb") as fp: - pickle.dump(fid_stat_real, fp) - fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu() - if tops.rank() == 0: - print("Starting calculation of fid from features of shape:", fid_features_fake.shape) - fid_stat_fake = fid_features_to_statistics(fid_features_fake) - fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"] - tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM) - tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM) - tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM) - lpips_total = lpips_total / n_samples_seen - diversity_total = diversity_total / n_samples_seen - to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total) - if tops.rank() == 0: - to_return["fid"] = fid_ - else: - to_return["fid"] = -1 - to_return["validation_time_s"] = time.time() - start_time - return to_return - - -@torch.no_grad() -def compute_lpips( - dataloader, generator, - truncation_value: 
float=None, - data_len=None, - ) -> dict: - """ - Args: - n_samples (int): Creates N samples from same image to calculate stats - dataset_percentage (float): The percentage of the dataset to compute metrics on. - """ - global lpips_model, fid_model - if lpips_model is None: - lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) - start_time = time.time() - lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device()) - diversity_total = torch.zeros_like(lpips_total) - if data_len is None: - data_len = len(dataloader) * dataloader.batch_size - eidx = 0 - n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device()) - for batch in utils.tqdm_(dataloader, desc="Validating on dataset."): - sidx = eidx - eidx = sidx + batch["img"].shape[0] - n_samples_seen += batch["img"].shape[0] - with torch.cuda.amp.autocast(tops.AMP()): - fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"] - fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"] - real_data = batch["img"] - fakes1 = utils.denormalize_img(fakes1).mul(255) - fakes2 = utils.denormalize_img(fakes2).mul(255) - real_data = utils.denormalize_img(real_data).mul(255) - lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1) - fake2_lpips_feats = lpips_model.get_feats(fakes2) - lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats) - - lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2) - diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum() - tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM) - tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM) - tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM) - lpips_total = lpips_total / n_samples_seen - diversity_total = diversity_total / n_samples_seen - to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total) - to_return = {k: v.cpu().item() for k, v in to_return.items()} - to_return["validation_time_s"] = time.time() - start_time - return to_return diff --git a/dp2/utils/__init__.py b/dp2/utils/__init__.py deleted file mode 100644 index a28c97be96580185208e5c6e56b1307a0fd9b9be..0000000000000000000000000000000000000000 --- a/dp2/utils/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -import pathlib -from tops.config import LazyConfig -from .torch_utils import ( - im2torch, im2numpy, denormalize_img, set_requires_grad, forward_D_fake, - binary_dilation, crop_box, remove_pad -) -from .ema import EMA -from .utils import init_tops, tqdm_, print_config, config_to_str, trange_ -from .cse import from_E_to_vertex - - -def load_config(config_path): - config_path = pathlib.Path(config_path) - assert config_path.is_file(), config_path - cfg = LazyConfig.load(str(config_path)) - cfg.output_dir = pathlib.Path(str(config_path).replace("configs", str(cfg.common.output_dir)).replace(".py", "")) - if cfg.common.experiment_name is None: - cfg.experiment_name = str(config_path) - else: - cfg.experiment_name = cfg.common.experiment_name - cfg.checkpoint_dir = cfg.output_dir.joinpath("checkpoints") - print("Saving outputs to:", cfg.output_dir) - return cfg diff --git a/dp2/utils/bufferless_video_capture.py b/dp2/utils/bufferless_video_capture.py deleted file mode 100644 index b071c1c4316ad48127c86c4f52ca40f66530edf7..0000000000000000000000000000000000000000 --- a/dp2/utils/bufferless_video_capture.py +++ /dev/null @@ -1,32 +0,0 @@ -import queue -import threading -import cv2 - - -class BufferlessVideoCapture: - - def 
__init__(self, name, width=None, height=None): - self.cap = cv2.VideoCapture(name) - if width is not None and height is not None: - self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width) - self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) - self.q = queue.Queue() - t = threading.Thread(target=self._reader) - t.daemon = True - t.start() - - # read frames as soon as they are available, keeping only most recent one - def _reader(self): - while True: - ret, frame = self.cap.read() - if not ret: - break - if not self.q.empty(): - try: - self.q.get_nowait() # discard previous (unprocessed) frame - except queue.Empty: - pass - self.q.put((ret, frame)) - - def read(self): - return self.q.get() \ No newline at end of file diff --git a/dp2/utils/cse.py b/dp2/utils/cse.py deleted file mode 100644 index f1b5a2b55cb912df0a50961eafb260eb297b0d4f..0000000000000000000000000000000000000000 --- a/dp2/utils/cse.py +++ /dev/null @@ -1,21 +0,0 @@ -import warnings -import torch -from densepose.modeling.cse.utils import get_closest_vertices_mask_from_ES - - -def from_E_to_vertex(E, M, embed_map): - """ - M is 1 for unkown regions - """ - assert len(E.shape) == 4 - assert len(E.shape) == len(M.shape), (E.shape, M.shape) - assert E.shape[0] == 1 - M = M.float() - M = torch.cat([M, 1-M], dim=1) - with warnings.catch_warnings(): # Ignore userError for pytorch interpolate from detectron2 - warnings.filterwarnings("ignore") - vertices, _ = get_closest_vertices_mask_from_ES( - E, M, E.shape[2], E.shape[3], - embed_map, device=E.device) - - return vertices.long() diff --git a/dp2/utils/ema.py b/dp2/utils/ema.py deleted file mode 100644 index 0c508213d4c445e2417607a7d3957d3ec953eb1f..0000000000000000000000000000000000000000 --- a/dp2/utils/ema.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import copy -import tops -from tops import logger -from .torch_utils import set_requires_grad - -class EMA: - """ - Expoenential moving average. - See: - Yazici, Y. et al.The unusual effectiveness of averaging in GAN training. 
ICLR 2019 - - """ - - def __init__( - self, - generator: torch.nn.Module, - batch_size: int, - rampup: float, - ): - self.rampup = rampup - self._nimg_half_time = batch_size * 10 / 32 * 1000 - self._batch_size = batch_size - with torch.no_grad(): - self.generator = copy.deepcopy(generator.cpu()).eval() - self.generator = tops.to_cuda(self.generator) - self.old_ra_beta = 0 - set_requires_grad(self.generator, False) - - def update_beta(self): - y = self._nimg_half_time - global_step = logger.global_step() - if self.rampup != None: - y = min(y, global_step*self.rampup) - self.ra_beta = 0.5 ** (self._batch_size/max(y, 1e-8)) - if self.ra_beta != self.old_ra_beta: - logger.add_scalar("stats/EMA_beta", self.ra_beta) - self.old_ra_beta = self.ra_beta - - @torch.no_grad() - def update(self, normal_G): - with torch.autograd.profiler.record_function("EMA_update"): - for ema_p, p in zip(self.generator.parameters(), - normal_G.parameters()): - ema_p.copy_(p.lerp(ema_p, self.ra_beta)) - for ema_buf, buff in zip(self.generator.buffers(), - normal_G.buffers()): - ema_buf.copy_(buff) - - def __call__(self, *args, **kwargs): - return self.generator(*args, **kwargs) - - def __getattr__(self, name: str): - if hasattr(self.generator, name): - return getattr(self.generator, name) - raise AttributeError(f"Generator object has no attribute {name}") - - def cuda(self, *args, **kwargs): - self.generator = self.generator.cuda() - return self - - def state_dict(self, *args, **kwargs): - return self.generator.state_dict(*args, **kwargs) - - def load_state_dict(self, *args, **kwargs): - return self.generator.load_state_dict(*args, **kwargs) - - def eval(self): - self.generator.eval() - - def train(self): - self.generator.train() - - @property - def module(self): - return self.generator.module - - def sample(self, *args, **kwargs): - return self.generator.sample(*args, **kwargs) diff --git a/dp2/utils/torch_utils.py b/dp2/utils/torch_utils.py deleted file mode 100644 index 46defb854b11a615aad3b918dce4a650b0b80889..0000000000000000000000000000000000000000 --- a/dp2/utils/torch_utils.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import tops - - -def denormalize_img(image, mean=0.5, std=0.5): - image = image * std + mean - image = torch.clamp(image.float(), 0, 1) - image = (image * 255) - image = torch.round(image) - return image / 255 - - -@torch.no_grad() -def im2numpy(images, to_uint8=False, denormalize=False): - if denormalize: - images = denormalize_img(images) - if images.dtype != torch.uint8: - images = images.clamp(0, 1) - return tops.im2numpy(images, to_uint8=to_uint8) - - -@torch.no_grad() -def im2torch(im, cuda=False, normalize=True, to_float=True): - im = tops.im2torch(im, cuda=cuda, to_float=to_float) - if normalize: - assert im.min() >= 0.0 and im.max() <= 1.0 - if normalize: - im = im * 2 - 1 - return im - - -@torch.no_grad() -def binary_dilation(im: torch.Tensor, kernel: torch.Tensor): - assert len(im.shape) == 4 - assert len(kernel.shape) == 2 - kernel = kernel.unsqueeze(0).unsqueeze(0) - padding = kernel.shape[-1]//2 - assert kernel.shape[-1] % 2 != 0 - if isinstance(im, torch.cuda.FloatTensor): - im, kernel = im.half(), kernel.half() - else: - im, kernel = im.float(), kernel.float() - im = torch.nn.functional.conv2d( - im, kernel, groups=im.shape[1], padding=padding) - im = im > 0.5 - return im - - -@torch.no_grad() -def binary_erosion(im: torch.Tensor, kernel: torch.Tensor): - assert len(im.shape) == 4 - assert len(kernel.shape) == 2 - kernel = kernel.unsqueeze(0).unsqueeze(0) - padding = 
kernel.shape[-1]//2 - assert kernel.shape[-1] % 2 != 0 - if isinstance(im, torch.cuda.FloatTensor): - im, kernel = im.half(), kernel.half() - else: - im, kernel = im.float(), kernel.float() - ksum = kernel.sum() - padding = (padding, padding, padding, padding) - im = torch.nn.functional.pad(im, padding, mode="reflect") - im = torch.nn.functional.conv2d( - im, kernel, groups=im.shape[1]) - return im.round() == ksum - - -def set_requires_grad(value: torch.nn.Module, requires_grad: bool): - if isinstance(value, (list, tuple)): - for param in value: - param.requires_grad = requires_grad - return - for p in value.parameters(): - p.requires_grad = requires_grad - - -def forward_D_fake(batch, fake_img, discriminator, **kwargs): - fake_batch = {k: v for k, v in batch.items() if k != "img"} - fake_batch["img"] = fake_img - return discriminator(**fake_batch, **kwargs) - - - -def remove_pad(x: torch.Tensor, bbox_XYXY, imshape): - """ - Remove padding that is shown as negative - """ - H, W = imshape - x0, y0, x1, y1 = bbox_XYXY - padding = [ - max(0, -x0), - max(0, -y0), - max(x1 - W, 0), - max(y1 - H, 0) - ] - x0, y0 = padding[:2] - x1 = x.shape[2] - padding[2] - y1 = x.shape[1] - padding[3] - return x[:, y0:y1, x0:x1] - - -def crop_box(x: torch.Tensor, bbox_XYXY) -> torch.Tensor: - """ - Crops x by bbox_XYXY. - """ - x0, y0, x1, y1 = bbox_XYXY - x0 = max(x0, 0) - y0 = max(y0, 0) - x1 = min(x1, x.shape[-1]) - y1 = min(y1, x.shape[-2]) - return x[..., y0:y1, x0:x1] \ No newline at end of file diff --git a/dp2/utils/utils.py b/dp2/utils/utils.py deleted file mode 100644 index 965eb6ae987bd6b5644e50116bfa6317e4e36769..0000000000000000000000000000000000000000 --- a/dp2/utils/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import tops -import tqdm -from tops import logger, highlight_py_str -from tops.config import LazyConfig - - -def print_config(cfg): - logger.log("\n" + highlight_py_str(LazyConfig.to_py(cfg, prefix=""))) - - -def config_to_str(cfg): - return LazyConfig.to_py(cfg, prefix=".") - - -def init_tops(cfg, reinit=False): - tops.init( - cfg.output_dir, cfg.common.logger_backend, cfg.experiment_name, - cfg.common.wandb_project, dict(cfg), reinit) - - -def tqdm_(iterator, *args, **kwargs): - if tops.rank() == 0: - return tqdm.tqdm(iterator, *args, **kwargs) - return iterator - - -def trange_(*args, **kwargs): - if tops.rank() == 0: - return tqdm.trange(*args, **kwargs) - return range(*args) \ No newline at end of file diff --git a/dp2/utils/vis_utils.py b/dp2/utils/vis_utils.py deleted file mode 100644 index ee227f359b4ab3af3a1123e2900373fed01f6c38..0000000000000000000000000000000000000000 --- a/dp2/utils/vis_utils.py +++ /dev/null @@ -1,407 +0,0 @@ -import torch -import tops -import cv2 -import torchvision.transforms.functional as F -from typing import Optional, List, Union, Tuple -from .cse import from_E_to_vertex -import numpy as np -from tops import download_file -from .torch_utils import ( - denormalize_img, binary_dilation, binary_erosion, - remove_pad, crop_box) -from torchvision.utils import _generate_color_palette -from PIL import Image, ImageColor, ImageDraw - - -def get_coco_keypoints(): - # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py - keypoints = [ - 'nose', - 'left_eye', - 'right_eye', - 'left_ear', - 'right_ear', - 'left_shoulder', - 'right_shoulder', - 'left_elbow', - 'right_elbow', - 'left_wrist', - 'right_wrist', - 'left_hip', - 'right_hip', - 'left_knee', - 'right_knee', - 'left_ankle', - 'right_ankle' - ] - keypoint_flip_map = { - 
'left_eye': 'right_eye', - 'left_ear': 'right_ear', - 'left_shoulder': 'right_shoulder', - 'left_elbow': 'right_elbow', - 'left_wrist': 'right_wrist', - 'left_hip': 'right_hip', - 'left_knee': 'right_knee', - 'left_ankle': 'right_ankle' - } - connectivity = { - "nose": "left_eye", - "left_eye": "right_eye", - "right_eye": "nose", - "left_ear": "left_eye", - "right_ear": "right_eye", - "left_shoulder": "nose", - "right_shoulder": "nose", - "left_elbow": "left_shoulder", - "right_elbow": "right_shoulder", - "left_wrist": "left_elbow", - "right_wrist": "right_elbow", - "left_hip": "left_shoulder", - "right_hip": "right_shoulder", - "left_knee": "left_hip", - "right_knee": "right_hip", - "left_ankle": "left_knee", - "right_ankle": "right_knee" - } - connectivity_indices = [ - (sidx, keypoints.index(connectivity[kp])) - for sidx, kp in enumerate(keypoints) - ] - return keypoints, keypoint_flip_map, connectivity_indices - - -@torch.no_grad() -def draw_keypoints( - image: torch.Tensor, - keypoints: torch.Tensor, - connectivity: Optional[List[Tuple[int, int]]] = None, - colors: Optional[Union[str, Tuple[int, int, int]]] = None, - radius: int = 1, - width: int = 1, -) -> torch.Tensor: - - - """ - Function taken from torchvision source code. Added in torchvision 0.12 - - Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. - - Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, - in the format [x, y]. - connectivity (List[Tuple[int, int]]]): A List of tuple where, - each tuple contains pair of keypoints to be connected. - colors (str, Tuple): The color can be represented as - PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - radius (int): Integer denoting radius of keypoint. - width (int): Integer denoting width of line connecting keypoints. - - Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. - """ - - if not isinstance(image, torch.Tensor): - raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size()[0] != 3: - raise ValueError("Pass an RGB image. 
Other Image formats are not supported") - - if keypoints.ndim != 3: - raise ValueError("keypoints must be of shape (num_instances, K, 2)") - - ndarr = image.permute(1, 2, 0).cpu().numpy() - img_to_draw = Image.fromarray(ndarr) - draw = ImageDraw.Draw(img_to_draw) - img_kpts = keypoints.to(torch.int64).tolist() - - for kpt_id, kpt_inst in enumerate(img_kpts): - for inst_id, kpt in enumerate(kpt_inst): - x1 = kpt[0] - radius - x2 = kpt[0] + radius - y1 = kpt[1] - radius - y2 = kpt[1] + radius - draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) - - if connectivity: - for connection in connectivity: - if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst): - continue - start_pt_x = kpt_inst[connection[0]][0] - start_pt_y = kpt_inst[connection[0]][1] - - end_pt_x = kpt_inst[connection[1]][0] - end_pt_y = kpt_inst[connection[1]][1] - - draw.line( - ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), - width=width, - ) - - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) - -def visualize_batch( - img: torch.Tensor, mask: torch.Tensor, - vertices: torch.Tensor=None, - E_mask: torch.Tensor=None, - embed_map: torch.Tensor=None, - semantic_mask: torch.Tensor=None, - embedding: torch.Tensor=None, - keypoints: torch.Tensor=None, - maskrcnn_mask: torch.Tensor=None, - **kwargs) -> torch.ByteTensor: - img = denormalize_img(img).mul(255).byte() - img = draw_mask(img, mask) - if maskrcnn_mask is not None: - img = draw_mask(img, maskrcnn_mask) - if vertices is not None or embedding is not None: - assert E_mask is not None - assert embed_map is not None - img = draw_cse(img, E_mask, embedding, embed_map, vertices) - elif semantic_mask is not None: - img = draw_segmentation_masks(img, semantic_mask) - if keypoints is not None: - keypoints = keypoints.clone() - keypoints[:, :, 0] *= img.shape[-1] - keypoints[:, :, 1] *= img.shape[-2] - _, _, connectivity = get_coco_keypoints() - connectivity = np.array(connectivity) - for idx in range(img.shape[0]): - if keypoints.shape[-1] == 3: - visible = (keypoints[idx:idx+1, :, 2] > 0 ).view(-1) - else: - visible = torch.ones(keypoints.shape[1], device=keypoints.device, dtype=torch.bool) - - if keypoints.shape[1] == 17: # COCO Connectivity - c = connectivity[visible.cpu().numpy()].tolist() - else: - c = None - - kp = keypoints[idx:idx+1, visible].long() - img[idx] = draw_keypoints(img[idx], kp, colors="red", connectivity=c) - return img - - -@torch.no_grad() -def draw_cse( - img: torch.Tensor, E_seg: torch.Tensor, - embedding: torch.Tensor = None, - embed_map: torch.Tensor = None, - vertices: torch.Tensor = None, t=0.7 - ): - """ - E_seg: 1 for areas with embedding - """ - assert img.dtype == torch.uint8 - img = img.view(-1, *img.shape[-3:]) - E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) - if vertices is None: - assert embedding is not None - assert embed_map is not None - embedding = embedding.view(-1, *embedding.shape[-3:]) - vertices = torch.stack( - [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map) - for e, e_seg in zip(embedding, E_seg)]) - - i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1) - colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0]) - color_embed_map, _ = np.load(download_file("https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True) - color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0] - color_embed_map -= color_embed_map.min() - color_embed_map /= color_embed_map.max() - vertx2idx = 
(color_embed_map*255).long() - vertx2colormap = colormap_JET[vertx2idx] - - vertices = vertices.view(-1, *vertices.shape[-2:]) - E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) - # This operation might be good to do on cpu... - E_color = vertx2colormap[vertices.long()] - E_color = E_color.to(E_seg.device) - E_color = E_color.permute(0, 3, 1, 2) - E_color = E_color*E_seg.byte() - - m = E_seg.bool().repeat(1, 3, 1, 1) - img[m] = (img[m] * (1-t) + t * E_color[m]).byte() - return img - - -def draw_cse_all( - embedding: List[torch.Tensor], E_mask: List[torch.Tensor], - im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7): - """ - E_seg: 1 for areas with embedding - """ - assert len(im.shape) == 3, im.shape - assert im.dtype == torch.uint8 - - N = len(E_mask) - im = im.clone() - for i in range(N): - assert len(E_mask[i].shape) == 2 - assert len(embedding[i].shape) == 3 - assert embed_map.shape[1] == embedding[i].shape[0] - assert len(boxes_XYXY[i]) == 4 - E = embedding[i] - x0, y0, x1, y1 = boxes_XYXY[i] - E = F.resize(E, (y1-y0, x1-x0), antialias=True) - s = E_mask[i].float() - s = (F.resize(s.squeeze()[None], (y1-y0, x1-x0), antialias=True) > 0).float() - box = boxes_XYXY[i] - - im_ = crop_box(im, box) - s = remove_pad(s, box, im.shape[1:]) - E = remove_pad(E, box, im.shape[1:]) - E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None],embed_map=embed_map)[0] - E_color = E_color.to(im.device) - s = s.bool().repeat(3, 1, 1) - crop_box(im, box)[s] = (im_[s] * (1-t) + t * E_color[s]).byte() - return im - - - -@torch.no_grad() -def draw_segmentation_masks( - image: torch.Tensor, - masks: torch.Tensor, - alpha: float = 0.8, - colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, -) -> torch.Tensor: - - """ - Draws segmentation masks on given RGB image. - The values of the input image should be uint8 between 0 and 255. - - Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. - masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. - alpha (float): Float number between 0 and 1 denoting the transparency of the masks. - 0 means full transparency, 1 means no transparency. - colors (list or None): List containing the colors of the masks. The colors can - be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list - with one element. By default, random colors are generated for each mask. - - Returns: - img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. - """ - - if not isinstance(image, torch.Tensor): - raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") - elif image.dim() != 3: - raise ValueError("Pass individual images, not batches") - elif image.size()[0] != 3: - raise ValueError("Pass an RGB image. Other Image formats are not supported") - if masks.ndim == 2: - masks = masks[None, :, :] - if masks.ndim != 3: - raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") - if masks.dtype != torch.bool: - raise ValueError(f"The masks must be of dtype bool. 
Got {masks.dtype}") - if masks.shape[-2:] != image.shape[-2:]: - raise ValueError("The image and the masks must have the same height and width") - num_masks = masks.size()[0] - if num_masks == 0: - return image - if colors is None: - colors = _generate_color_palette(num_masks) - if not isinstance(colors[0], (Tuple, List)): - colors = [colors for i in range(num_masks)] - if colors is not None and num_masks > len(colors): - raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") - - - if not isinstance(colors, list): - colors = [colors] - if not isinstance(colors[0], (tuple, str)): - raise ValueError("colors must be a tuple or a string, or a list thereof") - if isinstance(colors[0], tuple) and len(colors[0]) != 3: - raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") - - out_dtype = torch.uint8 - - colors_ = [] - for color in colors: - if isinstance(color, str): - color = ImageColor.getrgb(color) - color = torch.tensor(color, dtype=out_dtype, device=masks.device) - colors_.append(color) - img_to_draw = image.detach().clone() - # TODO: There might be a way to vectorize this - for mask, color in zip(masks, colors_): - img_to_draw[:, mask] = color[:, None] - - out = image * (1 - alpha) + img_to_draw * alpha - return out.to(out_dtype) - - -def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True): - """ - Visualize mask where mask = 0. - Supports multiple instances. - mask shape: [N, C, H, W], where C is different instances in same image. - """ - orig_imshape = im.shape - if mask.numel() == 0: return im - assert len(mask.shape) in (3, 4), mask.shape - mask = mask.view(-1, *mask.shape[-3:]) - im = im.view(-1, *im.shape[-3:]) - assert im.dtype == torch.uint8, im.dtype - assert 0 <= t <= 1 - if not visualize_instances: - mask = mask.any(dim=1, keepdim=True) - mask = mask.bool() - kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device) - outer_border = binary_dilation(mask, kernel).logical_xor(mask) - outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 - inner_border = binary_erosion(mask, kernel).logical_xor(mask) - inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 - mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1) - color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1)#.repeat(1, *im.shape[1:]) - color = color.repeat(im.shape[0], 1, *im.shape[-2:]) - im[mask] = (im[mask] * (1-t) + t * color[mask]).byte() - im[outer_border] = 255 - im[inner_border] = 0 - return im.view(*orig_imshape) - - -def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs): - for i, box in enumerate(boxes): - x0, y0, x1, y1 = boxes[i] - orig_shape = (y1-y0, x1-x0) - m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None] - m = remove_pad(m, boxes[i], im.shape[-2:]) - crop_box(im, boxes[i]).set_(draw_mask(crop_box(im, boxes[i]), m)) - return im - - -def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs): - n_boxes = boxes.shape[0] - tops.assert_shape(all_keypoints, (n_boxes, 17, 3)) - im = im.clone() - for i, box in enumerate(boxes): - - x0, y0, x1, y1 = boxes[i] - orig_shape = (y1-y0, x1-x0) - keypoints = all_keypoints[i].clone() - keypoints[:, 0] *= orig_shape[1] - keypoints[:, 1] *= orig_shape[0] - keypoints = keypoints.long() - _, _, connectivity = get_coco_keypoints() - connectivity = np.array(connectivity) - visible 
= (keypoints[:, 2] > 0) - if keypoints.shape[0] == 17: # COCO Connectivity - c = connectivity[visible.cpu().numpy()].tolist() - else: - c = None - # Remove padding from keypoints before visualization - keypoints[:, 0] += min(x0, 0) - keypoints[:, 1] += min(y0, 0) - im_with_kp = draw_keypoints(crop_box(im, box), keypoints[None, visible, :2], colors="red", connectivity=c) - crop_box(im, box).copy_(im_with_kp) - return im diff --git a/media b/media new file mode 120000 index 0000000000000000000000000000000000000000..4de87a74095d5613c3ca247cb84bd26f286e1ada --- /dev/null +++ b/media @@ -0,0 +1 @@ +deep_privacy2/media \ No newline at end of file diff --git a/sg3_torch_utils/LICENSE.txt b/sg3_torch_utils/LICENSE.txt deleted file mode 100644 index 6b5ee9bf994cc9441cb659c3527160b4ee5bcb33..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/LICENSE.txt +++ /dev/null @@ -1,97 +0,0 @@ -Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. - - -NVIDIA Source Code License for StyleGAN3 - - -======================================================================= - -1. Definitions - -"Licensor" means any person or entity that distributes its Work. - -"Software" means the original work of authorship made available under -this License. - -"Work" means the Software and any additions to or derivative works of -the Software that are made available under this License. - -The terms "reproduce," "reproduction," "derivative works," and -"distribution" have the meaning as provided under U.S. copyright law; -provided, however, that for the purposes of this License, derivative -works shall not include works that remain separable from, or merely -link (or bind by name) to the interfaces of, the Work. - -Works, including the Software, are "made available" under this License -by including in or with the Work either (a) a copyright notice -referencing the applicability of this License to the Work, or (b) a -copy of this License. - -2. License Grants - - 2.1 Copyright Grant. Subject to the terms and conditions of this - License, each Licensor grants to you a perpetual, worldwide, - non-exclusive, royalty-free, copyright license to reproduce, - prepare derivative works of, publicly display, publicly perform, - sublicense and distribute its Work and any resulting derivative - works in any form. - -3. Limitations - - 3.1 Redistribution. You may reproduce or distribute the Work only - if (a) you do so under this License, (b) you include a complete - copy of this License with your distribution, and (c) you retain - without modification any copyright, patent, trademark, or - attribution notices that are present in the Work. - - 3.2 Derivative Works. You may specify that additional or different - terms apply to the use, reproduction, and distribution of your - derivative works of the Work ("Your Terms") only if (a) Your Terms - provide that the use limitation in Section 3.3 applies to your - derivative works, and (b) you identify the specific derivative - works that are subject to Your Terms. Notwithstanding Your Terms, - this License (including the redistribution requirements in Section - 3.1) will continue to apply to the Work itself. - - 3.3 Use Limitation. The Work and any derivative works thereof only - may be used or intended for use non-commercially. Notwithstanding - the foregoing, NVIDIA and its affiliates may use the Work and any - derivative works commercially. As used herein, "non-commercially" - means for research or evaluation purposes only. - - 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim - against any Licensor (including any claim, cross-claim or - counterclaim in a lawsuit) to enforce any patents that you allege - are infringed by any Work, then your rights under this License from - such Licensor (including the grant in Section 2.1) will terminate - immediately. - - 3.5 Trademarks. This License does not grant any rights to use any - Licensor’s or its affiliates’ names, logos, or trademarks, except - as necessary to reproduce the notices described in this License. - - 3.6 Termination. If you violate any term of this License, then your - rights under this License (including the grant in Section 2.1) will - terminate immediately. - -4. Disclaimer of Warranty. - -THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR -NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER -THIS LICENSE. - -5. Limitation of Liability. - -EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL -THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE -SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, -INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF -OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK -(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, -LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER -COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF -THE POSSIBILITY OF SUCH DAMAGES. - -======================================================================= diff --git a/sg3_torch_utils/__init__.py b/sg3_torch_utils/__init__.py deleted file mode 100755 index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -# empty diff --git a/sg3_torch_utils/custom_ops.py b/sg3_torch_utils/custom_ops.py deleted file mode 100755 index 4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/custom_ops.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import os -import glob -import torch -import torch.utils.cpp_extension -import importlib -import hashlib -import shutil -from pathlib import Path - -from torch.utils.file_baton import FileBaton - -#---------------------------------------------------------------------------- -# Global options. 
- -verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' - -#---------------------------------------------------------------------------- -# Internal helper funcs. - -def _find_compiler_bindir(): - patterns = [ - 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', - 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', - 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', - 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', - ] - for pattern in patterns: - matches = sorted(glob.glob(pattern)) - if len(matches): - return matches[-1] - return None - -#---------------------------------------------------------------------------- -# Main entry point for compiling and loading C++/CUDA plugins. - -_cached_plugins = dict() - -def get_plugin(module_name, sources, **build_kwargs): - assert verbosity in ['none', 'brief', 'full'] - - # Already cached? - if module_name in _cached_plugins: - return _cached_plugins[module_name] - - # Print status. - if verbosity == 'full': - print(f'Setting up PyTorch plugin "{module_name}"...') - elif verbosity == 'brief': - print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) - - try: # pylint: disable=too-many-nested-blocks - # Make sure we can find the necessary compiler binaries. - if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: - compiler_bindir = _find_compiler_bindir() - if compiler_bindir is None: - raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') - os.environ['PATH'] += ';' + compiler_bindir - - # Compile and load. - verbose_build = (verbosity == 'full') - - # Incremental build md5sum trickery. Copies all the input source files - # into a cached build directory under a combined md5 digest of the input - # source files. Copying is done only if the combined digest has changed. - # This keeps input file timestamps and filenames the same as in previous - # extension builds, allowing for fast incremental rebuilds. - # - # This optimization is done only in case all the source files reside in - # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR - # environment variable is set (we take this as a signal that the user - # actually cares about this.) - source_dirs_set = set(os.path.dirname(source) for source in sources) - if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): - all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) - - # Compute a combined hash digest for all source files in the same - # custom op directory (usually .cu, .cpp, .py and .h files). - hash_md5 = hashlib.md5() - for src in all_source_files: - with open(src, 'rb') as f: - hash_md5.update(f.read()) - build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access - digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) - - if not os.path.isdir(digest_build_dir): - os.makedirs(digest_build_dir, exist_ok=True) - baton = FileBaton(os.path.join(digest_build_dir, 'lock')) - if baton.try_acquire(): - try: - for src in all_source_files: - shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) - finally: - baton.release() - else: - # Someone else is copying source files under the digest dir, - # wait until done and continue. 
- baton.wait() - digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] - torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, - verbose=verbose_build, sources=digest_sources, **build_kwargs) - else: - torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) - module = importlib.import_module(module_name) - - except: - if verbosity == 'brief': - print('Failed!') - raise - - # Print status and add to cache. - if verbosity == 'full': - print(f'Done setting up PyTorch plugin "{module_name}".') - elif verbosity == 'brief': - print('Done.') - _cached_plugins[module_name] = module - return module - -#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/misc.py b/sg3_torch_utils/misc.py deleted file mode 100755 index 10d8e31880affdd185580b6f5b98e92c79597dc3..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/misc.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import re -import contextlib -import numpy as np -import torch -import warnings - -#---------------------------------------------------------------------------- -# Cached construction of constant tensors. Avoids CPU=>GPU copy when the -# same constant is used multiple times. - -_constant_cache = dict() - -def constant(value, shape=None, dtype=None, device=None, memory_format=None): - value = np.asarray(value) - if shape is not None: - shape = tuple(shape) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device('cpu') - if memory_format is None: - memory_format = torch.contiguous_format - - key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) - tensor = _constant_cache.get(key, None) - if tensor is None: - tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) - if shape is not None: - tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) - tensor = tensor.contiguous(memory_format=memory_format) - _constant_cache[key] = tensor - return tensor - -#---------------------------------------------------------------------------- -# Replace NaN/Inf with specified numerical values. - -try: - nan_to_num = torch.nan_to_num # 1.8.0a0 -except AttributeError: - def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin - assert isinstance(input, torch.Tensor) - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - assert nan == 0 - return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) - -#---------------------------------------------------------------------------- -# Symbolic assert. - -try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access -except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 - -#---------------------------------------------------------------------------- -# Context manager to suppress known warnings in torch.jit.trace(). 
- -class suppress_tracer_warnings(warnings.catch_warnings): - def __enter__(self): - super().__enter__() - warnings.simplefilter('ignore', category=torch.jit.TracerWarning) - return self - -#---------------------------------------------------------------------------- -# Assert that the shape of a tensor matches the given list of integers. -# None indicates that the size of a dimension is allowed to vary. -# Performs symbolic assertion when used in torch.jit.trace(). - -def assert_shape(tensor, ref_shape): - if tensor.ndim != len(ref_shape): - raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') - for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): - if ref_size is None: - pass - elif isinstance(ref_size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') - elif isinstance(size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') - elif size != ref_size: - raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') - -#---------------------------------------------------------------------------- -# Function decorator that calls torch.autograd.profiler.record_function(). - -def profiled_function(fn): - def decorator(*args, **kwargs): - with torch.autograd.profiler.record_function(fn.__name__): - return fn(*args, **kwargs) - decorator.__name__ = fn.__name__ - return decorator - -#---------------------------------------------------------------------------- -# Sampler for torch.utils.data.DataLoader that loops over the dataset -# indefinitely, shuffling items as it goes. - -class InfiniteSampler(torch.utils.data.Sampler): - def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): - assert len(dataset) > 0 - assert num_replicas > 0 - assert 0 <= rank < num_replicas - assert 0 <= window_size <= 1 - super().__init__(dataset) - self.dataset = dataset - self.rank = rank - self.num_replicas = num_replicas - self.shuffle = shuffle - self.seed = seed - self.window_size = window_size - - def __iter__(self): - order = np.arange(len(self.dataset)) - rnd = None - window = 0 - if self.shuffle: - rnd = np.random.RandomState(self.seed) - rnd.shuffle(order) - window = int(np.rint(order.size * self.window_size)) - - idx = 0 - while True: - i = idx % order.size - if idx % self.num_replicas == self.rank: - yield order[i] - if window >= 2: - j = (i - rnd.randint(window)) % order.size - order[i], order[j] = order[j], order[i] - idx += 1 - -#---------------------------------------------------------------------------- -# Utilities for operating with torch.nn.Module parameters and buffers. 
- -def params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.parameters()) + list(module.buffers()) - -def named_params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.named_parameters()) + list(module.named_buffers()) - -def copy_params_and_buffers(src_module, dst_module, require_all=False): - assert isinstance(src_module, torch.nn.Module) - assert isinstance(dst_module, torch.nn.Module) - src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} - for name, tensor in named_params_and_buffers(dst_module): - assert (name in src_tensors) or (not require_all) - if name in src_tensors: - tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) - -#---------------------------------------------------------------------------- -# Context manager for easily enabling/disabling DistributedDataParallel -# synchronization. - -@contextlib.contextmanager -def ddp_sync(module, sync): - assert isinstance(module, torch.nn.Module) - if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): - yield - else: - with module.no_sync(): - yield diff --git a/sg3_torch_utils/ops/__init__.py b/sg3_torch_utils/ops/__init__.py deleted file mode 100755 index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -# empty diff --git a/sg3_torch_utils/ops/bias_act.cpp b/sg3_torch_utils/ops/bias_act.cpp deleted file mode 100755 index 5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/bias_act.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include -#include -#include "bias_act.h" - -//------------------------------------------------------------------------ - -static bool has_same_layout(torch::Tensor x, torch::Tensor y) -{ - if (x.dim() != y.dim()) - return false; - for (int64_t i = 0; i < x.dim(); i++) - { - if (x.size(i) != y.size(i)) - return false; - if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) - return false; - } - return true; -} - -//------------------------------------------------------------------------ - -static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) -{ - // Validate arguments. 
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); - TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); - TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); - TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); - TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); - TORCH_CHECK(b.dim() == 1, "b must have rank 1"); - TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); - TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); - TORCH_CHECK(grad >= 0, "grad must be non-negative"); - - // Validate layout. - TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); - TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); - TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); - TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); - TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); - - // Create output tensor. - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - torch::Tensor y = torch::empty_like(x); - TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); - - // Initialize CUDA kernel parameters. - bias_act_kernel_params p; - p.x = x.data_ptr(); - p.b = (b.numel()) ? b.data_ptr() : NULL; - p.xref = (xref.numel()) ? xref.data_ptr() : NULL; - p.yref = (yref.numel()) ? yref.data_ptr() : NULL; - p.dy = (dy.numel()) ? dy.data_ptr() : NULL; - p.y = y.data_ptr(); - p.grad = grad; - p.act = act; - p.alpha = alpha; - p.gain = gain; - p.clamp = clamp; - p.sizeX = (int)x.numel(); - p.sizeB = (int)b.numel(); - p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; - - // Choose CUDA kernel. - void* kernel; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] - { - kernel = choose_bias_act_kernel(p); - }); - TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); - - // Launch CUDA kernel. - p.loopX = 4; - int blockSize = 4 * 32; - int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; - void* args[] = {&p}; - AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); - return y; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("bias_act", &bias_act); -} - -//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/bias_act.cu b/sg3_torch_utils/ops/bias_act.cu deleted file mode 100755 index dd8fc4756d7d94727f94af738665b68d9c518880..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/bias_act.cu +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. 
Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "bias_act.h" - -//------------------------------------------------------------------------ -// Helpers. - -template struct InternalType; -template <> struct InternalType { typedef double scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; - -//------------------------------------------------------------------------ -// CUDA kernel. - -template -__global__ void bias_act_kernel(bias_act_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - int G = p.grad; - scalar_t alpha = (scalar_t)p.alpha; - scalar_t gain = (scalar_t)p.gain; - scalar_t clamp = (scalar_t)p.clamp; - scalar_t one = (scalar_t)1; - scalar_t two = (scalar_t)2; - scalar_t expRange = (scalar_t)80; - scalar_t halfExpRange = (scalar_t)40; - scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; - scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; - - // Loop over elements. - int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; - for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) - { - // Load. - scalar_t x = (scalar_t)((const T*)p.x)[xi]; - scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; - scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; - scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; - scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; - scalar_t yy = (gain != 0) ? yref / gain : 0; - scalar_t y = 0; - - // Apply bias. - ((G == 0) ? x : xref) += b; - - // linear - if (A == 1) - { - if (G == 0) y = x; - if (G == 1) y = x; - } - - // relu - if (A == 2) - { - if (G == 0) y = (x > 0) ? x : 0; - if (G == 1) y = (yy > 0) ? x : 0; - } - - // lrelu - if (A == 3) - { - if (G == 0) y = (x > 0) ? x : x * alpha; - if (G == 1) y = (yy > 0) ? x : x * alpha; - } - - // tanh - if (A == 4) - { - if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } - if (G == 1) y = x * (one - yy * yy); - if (G == 2) y = x * (one - yy * yy) * (-two * yy); - } - - // sigmoid - if (A == 5) - { - if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); - if (G == 1) y = x * yy * (one - yy); - if (G == 2) y = x * yy * (one - yy) * (one - two * yy); - } - - // elu - if (A == 6) - { - if (G == 0) y = (x >= 0) ? x : exp(x) - one; - if (G == 1) y = (yy >= 0) ? x : x * (yy + one); - if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); - } - - // selu - if (A == 7) - { - if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); - if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); - if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); - } - - // softplus - if (A == 8) - { - if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); - if (G == 1) y = x * (one - exp(-yy)); - if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } - } - - // swish - if (A == 9) - { - if (G == 0) - y = (x < -expRange) ? 0 : x / (exp(-x) + one); - else - { - scalar_t c = exp(xref); - scalar_t d = c + one; - if (G == 1) - y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); - else - y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); - yref = (xref < -expRange) ? 
0 : xref / (exp(-xref) + one) * gain; - } - } - - // Apply gain. - y *= gain * dy; - - // Clamp. - if (clamp >= 0) - { - if (G == 0) - y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; - else - y = (yref > -clamp & yref < clamp) ? y : 0; - } - - // Store. - ((T*)p.y)[xi] = (T)y; - } -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template void* choose_bias_act_kernel(const bias_act_kernel_params& p) -{ - if (p.act == 1) return (void*)bias_act_kernel; - if (p.act == 2) return (void*)bias_act_kernel; - if (p.act == 3) return (void*)bias_act_kernel; - if (p.act == 4) return (void*)bias_act_kernel; - if (p.act == 5) return (void*)bias_act_kernel; - if (p.act == 6) return (void*)bias_act_kernel; - if (p.act == 7) return (void*)bias_act_kernel; - if (p.act == 8) return (void*)bias_act_kernel; - if (p.act == 9) return (void*)bias_act_kernel; - return NULL; -} - -//------------------------------------------------------------------------ -// Template specializations. - -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/bias_act.h b/sg3_torch_utils/ops/bias_act.h deleted file mode 100755 index a32187e1fb7e3bae509d4eceaf900866866875a4..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/bias_act.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -//------------------------------------------------------------------------ -// CUDA kernel parameters. - -struct bias_act_kernel_params -{ - const void* x; // [sizeX] - const void* b; // [sizeB] or NULL - const void* xref; // [sizeX] or NULL - const void* yref; // [sizeX] or NULL - const void* dy; // [sizeX] or NULL - void* y; // [sizeX] - - int grad; - int act; - float alpha; - float gain; - float clamp; - - int sizeX; - int sizeB; - int stepB; - int loopX; -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template void* choose_bias_act_kernel(const bias_act_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/bias_act.py b/sg3_torch_utils/ops/bias_act.py deleted file mode 100755 index 7c39717268055fafe737419486cf96f1f93f4fb5..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/bias_act.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
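Note on the removed bias_act.cu/bias_act.h pair above: the kernel folds the bias add, the activation, its first and second derivatives (selected by the grad mode G = 0/1/2), the output gain, and optional clamping into a single pass. Purely as a reading aid, here is a minimal PyTorch sketch of the 'lrelu' path under those conventions, assuming gain > 0; the helper names are hypothetical and are not part of the removed files.

import torch

def lrelu_bias_act_ref(x, b, dim=1, alpha=0.2, gain=2 ** 0.5):
    # grad mode 0: add bias along `dim`, apply leaky ReLU, scale by gain
    z = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    return gain * torch.where(z > 0, z, alpha * z)

def lrelu_bias_act_grad_ref(dy, y, alpha=0.2, gain=2 ** 0.5):
    # grad mode 1: first derivative, reconstructed from the saved output y
    # (the kernel checks the sign of y / gain, which matches the sign of x + b when gain > 0)
    return gain * torch.where(y > 0, dy, alpha * dy)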
- -"""Custom PyTorch ops for efficient bias and activation.""" - -import os -import warnings -import numpy as np -import torch -import traceback - -from .. import custom_ops -from easydict import EasyDict -from torch.cuda.amp import custom_bwd, custom_fwd -#---------------------------------------------------------------------------- - -activation_funcs = { - 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), - 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), - 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), - 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), - 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), - 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), - 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), - 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), - 'swish': EasyDict(func=lambda x, **_: torch.nn.functional.silu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), -} - -#---------------------------------------------------------------------------- - -_inited = False -_plugin = None -enabled = False -_null_tensor = torch.empty([0]) - -def _init(): - global _inited, _plugin - if not _inited: - _inited = True - sources = ['bias_act.cpp', 'bias_act.cu'] - sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - try: - _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) - except: - warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) - return _plugin is not None - -#---------------------------------------------------------------------------- - -def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): - r"""Fused bias and activation function. - - Adds bias `b` to activation tensor `x`, evaluates activation function `act`, - and scales the result by `gain`. Each of the steps is optional. In most cases, - the fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports first and second order gradients, - but not third order gradients. - - Args: - x: Input activation tensor. Can be of any shape. - b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type - as `x`. The shape must be known, and it must match the dimension of `x` - corresponding to `dim`. - dim: The dimension in `x` corresponding to the elements of `b`. - The value of `dim` is ignored if `b` is not specified. - act: Name of the activation function to evaluate, or `"linear"` to disable. - Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. - See `activation_funcs` for a full list. `None` is not allowed. - alpha: Shape parameter for the activation function, or `None` to use the default. - gain: Scaling factor for the output tensor, or `None` to use default. 
- See `activation_funcs` for the default scaling of each activation function. - If unsure, consider specifying 1. - clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable - the clamping (default). - impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). - - Returns: - Tensor of the same shape and datatype as `x`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init(): - return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) - return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) - -#---------------------------------------------------------------------------- - -def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Slow reference implementation of `bias_act()` using standard TensorFlow ops. - """ - assert isinstance(x, torch.Tensor) - assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Add bias. - if b is not None: - assert isinstance(b, torch.Tensor) and b.ndim == 1 - assert 0 <= dim < x.ndim - assert b.shape[0] == x.shape[dim] - x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) - - # Evaluate activation function. - alpha = float(alpha) - x = spec.func(x, alpha=alpha) - - # Scale by gain. - gain = float(gain) - if gain != 1: - x = x * gain - - # Clamp. - if clamp >= 0: - x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type - return x - -#---------------------------------------------------------------------------- - -_bias_act_cuda_cache = dict() - -def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Fast CUDA implementation of `bias_act()` using custom ops. - """ - # Parse arguments. - assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Lookup from cache. - key = (dim, act, alpha, gain, clamp) - if key in _bias_act_cuda_cache: - return _bias_act_cuda_cache[key] - - # Forward op. 
- class BiasActCuda(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, x, b): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format - x = x.contiguous(memory_format=ctx.memory_format) - b = b.contiguous() if b is not None else _null_tensor - y = x - if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: - y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - y if 'y' in spec.ref else _null_tensor) - return y - - @staticmethod - @custom_bwd - def backward(ctx, dy): # pylint: disable=arguments-differ - dy = dy.contiguous(memory_format=ctx.memory_format) - x, b, y = ctx.saved_tensors - dx = None - db = None - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: - dx = dy - if act != 'linear' or gain != 1 or clamp >= 0: - dx = BiasActCudaGrad.apply(dy, x, b, y) - - if ctx.needs_input_grad[1]: - db = dx.sum([i for i in range(dx.ndim) if i != dim]) - - return dx, db - - # Backward op. - class BiasActCudaGrad(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format - dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - dy if spec.has_2nd_grad else _null_tensor, - x, b, y) - return dx - - @staticmethod - @custom_bwd - def backward(ctx, d_dx): # pylint: disable=arguments-differ - d_dx = d_dx.contiguous(memory_format=ctx.memory_format) - dy, x, b, y = ctx.saved_tensors - d_dy = None - d_x = None - d_b = None - d_y = None - - if ctx.needs_input_grad[0]: - d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) - - if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): - d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) - - if spec.has_2nd_grad and ctx.needs_input_grad[2]: - d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) - - return d_dy, d_x, d_b, d_y - - # Add to cache. - _bias_act_cuda_cache[key] = BiasActCuda - return BiasActCuda - -#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/conv2d_gradfix.py b/sg3_torch_utils/ops/conv2d_gradfix.py deleted file mode 100755 index e66591f19fad68760d3df7c9737a14574b70ee83..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/conv2d_gradfix.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
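Per the bias_act() docstring and the reference path above, the fused op is equivalent to adding the bias along `dim`, applying the activation, and scaling by the gain. A usage sketch, assuming the module layout of the files removed here (illustrative only; on CPU, or with the op disabled, the call resolves to the 'ref' path):

import torch
import torch.nn.functional as F
from sg3_torch_utils.ops import bias_act

x = torch.randn(4, 64, 32, 32)
b = torch.randn(64)
y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=2 ** 0.5)
# Equivalent plain PyTorch: add bias along dim=1, leaky ReLU (alpha=0.2), scale by gain.
y_ref = F.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2) * 2 ** 0.5
assert torch.allclose(y, y_ref)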
- -"""Custom replacement for `torch.nn.functional.conv2d` that supports -arbitrarily high order gradients with zero performance penalty.""" - -import warnings -import contextlib -import torch -from torch.cuda.amp import custom_bwd, custom_fwd - -# pylint: disable=redefined-builtin -# pylint: disable=arguments-differ -# pylint: disable=protected-access - -#---------------------------------------------------------------------------- - -enabled = False # Enable the custom op by setting this to true. -weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. - -@contextlib.contextmanager -def no_weight_gradients(): - global weight_gradients_disabled - old = weight_gradients_disabled - weight_gradients_disabled = True - yield - weight_gradients_disabled = old - -#---------------------------------------------------------------------------- - -def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) - return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - -def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) - return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) - -#---------------------------------------------------------------------------- - -def _should_use_custom_op(input): - assert isinstance(input, torch.Tensor) - if (not enabled) or (not torch.backends.cudnn.enabled): - return False - if input.device.type != 'cuda': - return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']): - return True - warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') - return False - -def _tuple_of_ints(xs, ndim): - xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim - assert len(xs) == ndim - assert all(isinstance(x, int) for x in xs) - return xs - -#---------------------------------------------------------------------------- - -_conv2d_gradfix_cache = dict() - -def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): - # Parse arguments. - ndim = 2 - weight_shape = tuple(weight_shape) - stride = _tuple_of_ints(stride, ndim) - padding = _tuple_of_ints(padding, ndim) - output_padding = _tuple_of_ints(output_padding, ndim) - dilation = _tuple_of_ints(dilation, ndim) - - # Lookup from cache. - key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) - if key in _conv2d_gradfix_cache: - return _conv2d_gradfix_cache[key] - - # Validate arguments. 
- assert groups >= 1 - assert len(weight_shape) == ndim + 2 - assert all(stride[i] >= 1 for i in range(ndim)) - assert all(padding[i] >= 0 for i in range(ndim)) - assert all(dilation[i] >= 0 for i in range(ndim)) - if not transpose: - assert all(output_padding[i] == 0 for i in range(ndim)) - else: # transpose - assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) - - # Helpers. - common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) - def calc_output_padding(input_shape, output_shape): - if transpose: - return [0, 0] - return [ - input_shape[i + 2] - - (output_shape[i + 2] - 1) * stride[i] - - (1 - 2 * padding[i]) - - dilation[i] * (weight_shape[i + 2] - 1) - for i in range(ndim) - ] - - # Forward & backward. - class Conv2d(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, weight, bias): - assert weight.shape == weight_shape - if not transpose: - output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) - else: # transpose - output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) - ctx.save_for_backward(input, weight) - return output - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - grad_input = None - grad_weight = None - grad_bias = None - - if ctx.needs_input_grad[0]: - p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output.float(), weight.float(), None) - assert grad_input.shape == input.shape - - if ctx.needs_input_grad[1] and not weight_gradients_disabled: - grad_weight = Conv2dGradWeight.apply(grad_output.float(), input.float()) - assert grad_weight.shape == weight_shape - - if ctx.needs_input_grad[2]: - grad_bias = grad_output.float().sum([0, 2, 3]) - - return grad_input, grad_weight, grad_bias - - # Gradient with respect to the weights. 
- class Conv2dGradWeight(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, grad_output, input): - op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') - flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] - grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) - assert grad_weight.shape == weight_shape - ctx.save_for_backward(grad_output, input) - return grad_weight - - @staticmethod - @custom_bwd - def backward(ctx, grad2_grad_weight): - grad_output, input = ctx.saved_tensors - grad2_grad_output = None - grad2_input = None - - if ctx.needs_input_grad[0]: - grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) - assert grad2_grad_output.shape == grad_output.shape - - if ctx.needs_input_grad[1]: - p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) - assert grad2_input.shape == input.shape - - return grad2_grad_output, grad2_input - - _conv2d_gradfix_cache[key] = Conv2d - return Conv2d - -#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/conv2d_resample.py b/sg3_torch_utils/ops/conv2d_resample.py deleted file mode 100755 index 4a999b58b36a5da53752024e86a9ebdf9c031d97..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/conv2d_resample.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""2D convolution with optional up/downsampling.""" - -import torch - -from .. import misc -from . import conv2d_gradfix -from . import upfirdn2d -from .upfirdn2d import _parse_padding -from .upfirdn2d import _get_filter_size - -#---------------------------------------------------------------------------- - -def _get_weight_shape(w): - with misc.suppress_tracer_warnings(): # this value will be treated as a constant - shape = [int(sz) for sz in w.shape] - misc.assert_shape(w, shape) - return shape - -#---------------------------------------------------------------------------- - -def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): - """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. - """ - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - - # Flip weight if requested. - if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). - w = w.flip([2, 3]) - - # Otherwise => execute using conv2d_gradfix. 
- op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d - return op(x, w, stride=stride, padding=padding, groups=groups) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): - r"""2D convolution with optional up/downsampling. - - Padding is performed only once at the beginning, not between the operations. - - Args: - x: Input tensor of shape - `[batch_size, in_channels, in_height, in_width]`. - w: Weight tensor of shape - `[out_channels, in_channels//groups, kernel_height, kernel_width]`. - f: Low-pass filter for up/downsampling. Must be prepared beforehand by - calling upfirdn2d.setup_filter(). None = identity (default). - up: Integer upsampling factor (default: 1). - down: Integer downsampling factor (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - groups: Split input channels into N groups (default: 1). - flip_weight: False = convolution, True = correlation (default: True). - flip_filter: False = convolution, True = correlation (default: False). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - # Validate arguments. - assert isinstance(x, torch.Tensor) and (x.ndim == 4) - assert isinstance(w, torch.Tensor) and (w.ndim == 4) - assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) - assert isinstance(up, int) and (up >= 1) - assert isinstance(down, int) and (down >= 1) - assert isinstance(groups, int) and (groups >= 1) - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - fw, fh = _get_filter_size(f) - px0, px1, py0, py1 = _parse_padding(padding) - - # Adjust padding to account for up/downsampling. - if up > 1: - px0 += (fw + up - 1) // 2 - px1 += (fw - up) // 2 - py0 += (fh + up - 1) // 2 - py1 += (fh - up) // 2 - if down > 1: - px0 += (fw - down + 1) // 2 - px1 += (fw - down) // 2 - py0 += (fh - down + 1) // 2 - py1 += (fh - down) // 2 - - # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. - if kw == 1 and kh == 1 and (down > 1 and up == 1): - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. - if kw == 1 and kh == 1 and (up > 1 and down == 1): - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) - return x - - # Fast path: downsampling only => use strided convolution. - if down > 1 and up == 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
- if up > 1: - if groups == 1: - w = w.transpose(0, 1) - else: - w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) - w = w.transpose(1, 2) - w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) - px0 -= kw - 1 - px1 -= kw - up - py0 -= kh - 1 - py1 -= kh - up - pxt = max(min(-px0, -px1), 0) - pyt = max(min(-py0, -py1), 0) - x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) - if down > 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - - # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. - if up == 1 and down == 1: - if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: - return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) - - # Fallback: Generic reference implementation. - x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - if down > 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - -#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/fma.py b/sg3_torch_utils/ops/fma.py deleted file mode 100755 index b4e8ef9169440d4c3bd95befae7d26e3c1e1f017..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/fma.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
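The conv2d_resample() docstring above spells out the contract: the FIR filter must be prepared with upfirdn2d.setup_filter(), padding is expressed relative to the upsampled image, and up/down are integer factors. A usage sketch for a 3x3 convolution fused with 2x upsampling, assuming the module layout of the files removed here (illustrative only):

import torch
from sg3_torch_utils.ops import upfirdn2d
from sg3_torch_utils.ops.conv2d_resample import conv2d_resample

x = torch.randn(1, 16, 32, 32)              # [batch, in_channels, H, W]
w = torch.randn(8, 16, 3, 3)                # [out_channels, in_channels // groups, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])    # low-pass filter, prepared as the docstring requires
y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                              # doubled spatial resolution: [1, 8, 64, 64]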
- -"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" - -import torch -from torch.cuda.amp import custom_bwd, custom_fwd - -#---------------------------------------------------------------------------- - -def fma(a, b, c): # => a * b + c - return _FusedMultiplyAdd.apply(a, b, c) - -#---------------------------------------------------------------------------- - -class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, a, b, c): # pylint: disable=arguments-differ - out = torch.addcmul(c, a, b) - ctx.save_for_backward(a, b) - ctx.c_shape = c.shape - return out - - @staticmethod - @custom_bwd - def backward(ctx, dout): # pylint: disable=arguments-differ - a, b = ctx.saved_tensors - c_shape = ctx.c_shape - da = None - db = None - dc = None - - if ctx.needs_input_grad[0]: - da = _unbroadcast(dout * b, a.shape) - - if ctx.needs_input_grad[1]: - db = _unbroadcast(dout * a, b.shape) - - if ctx.needs_input_grad[2]: - dc = _unbroadcast(dout, c_shape) - - return da, db, dc - -#---------------------------------------------------------------------------- - -def _unbroadcast(x, shape): - extra_dims = x.ndim - len(shape) - assert extra_dims >= 0 - dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] - if len(dim): - x = x.sum(dim=dim, keepdim=True) - if extra_dims: - x = x.reshape(-1, *x.shape[extra_dims+1:]) - assert x.shape == shape - return x - -#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/grid_sample_gradfix.py b/sg3_torch_utils/ops/grid_sample_gradfix.py deleted file mode 100755 index 87067e150c591b1ace91816e7a5c3ee3a4aeacd3..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/grid_sample_gradfix.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom replacement for `torch.nn.functional.grid_sample` that -supports arbitrarily high order gradients between the input and output. -Only works on 2D images and assumes -`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" - -import torch -from torch.cuda.amp import custom_bwd, custom_fwd -from pkg_resources import parse_version -# pylint: disable=redefined-builtin -# pylint: disable=arguments-differ -# pylint: disable=protected-access -_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 - - -#---------------------------------------------------------------------------- - -enabled = False # Enable the custom op by setting this to true. 
- -#---------------------------------------------------------------------------- - -def grid_sample(input, grid): - if _should_use_custom_op(): - return _GridSample2dForward.apply(input, grid) - return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) - -#---------------------------------------------------------------------------- - -def _should_use_custom_op(): - return enabled - -#---------------------------------------------------------------------------- - -class _GridSample2dForward(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, grid): - assert input.ndim == 4 - assert grid.ndim == 4 - output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) - ctx.save_for_backward(input, grid) - return output - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - input, grid = ctx.saved_tensors - grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) - return grad_input, grad_grid - -#---------------------------------------------------------------------------- - -class _GridSample2dBackward(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, grad_output, input, grid): - op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') - if _use_pytorch_1_11_api: - output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) - grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) - else: - grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) - ctx.save_for_backward(grid) - return grad_input, grad_grid - - @staticmethod - @custom_bwd - def backward(ctx, grad2_grad_input, grad2_grad_grid): - _ = grad2_grad_grid # unused - grid, = ctx.saved_tensors - grad2_grad_output = None - grad2_input = None - grad2_grid = None - - if ctx.needs_input_grad[0]: - grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) - - assert not ctx.needs_input_grad[2] - return grad2_grad_output, grad2_input, grad2_grid - -#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/upfirdn2d.cpp b/sg3_torch_utils/ops/upfirdn2d.cpp deleted file mode 100755 index 2d7177fc60040751d20e9a8da0301fa3ab64968a..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/upfirdn2d.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include -#include -#include "upfirdn2d.h" - -//------------------------------------------------------------------------ - -static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) -{ - // Validate arguments. 
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); - TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); - TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); - TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); - TORCH_CHECK(x.dim() == 4, "x must be rank 4"); - TORCH_CHECK(f.dim() == 2, "f must be rank 2"); - TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); - TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); - TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); - - // Create output tensor. - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; - int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; - TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); - torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); - TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); - - // Initialize CUDA kernel parameters. - upfirdn2d_kernel_params p; - p.x = x.data_ptr(); - p.f = f.data_ptr(); - p.y = y.data_ptr(); - p.up = make_int2(upx, upy); - p.down = make_int2(downx, downy); - p.pad0 = make_int2(padx0, pady0); - p.flip = (flip) ? 1 : 0; - p.gain = gain; - p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); - p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); - p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); - p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); - p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); - p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); - p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; - p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; - - // Choose CUDA kernel. - upfirdn2d_kernel_spec spec; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] - { - spec = choose_upfirdn2d_kernel(p); - }); - - // Set looping options. - p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; - p.loopMinor = spec.loopMinor; - p.loopX = spec.loopX; - p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; - p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; - - // Compute grid size. - dim3 blockSize, gridSize; - if (spec.tileOutW < 0) // large - { - blockSize = dim3(4, 32, 1); - gridSize = dim3( - ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, - (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, - p.launchMajor); - } - else // small - { - blockSize = dim3(256, 1, 1); - gridSize = dim3( - ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, - (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, - p.launchMajor); - } - - // Launch CUDA kernel. 
- void* args[] = {&p}; - AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); - return y; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("upfirdn2d", &upfirdn2d); -} - -//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/upfirdn2d.cu b/sg3_torch_utils/ops/upfirdn2d.cu deleted file mode 100755 index ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/upfirdn2d.cu +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "upfirdn2d.h" - -//------------------------------------------------------------------------ -// Helpers. - -template struct InternalType; -template <> struct InternalType { typedef double scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; - -static __device__ __forceinline__ int floor_div(int a, int b) -{ - int t = 1 - a / b; - return (a + t * b) / b - t; -} - -//------------------------------------------------------------------------ -// Generic CUDA implementation for large filters. - -template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - - // Calculate thread index. - int minorBase = blockIdx.x * blockDim.x + threadIdx.x; - int outY = minorBase / p.launchMinor; - minorBase -= outY * p.launchMinor; - int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; - int majorBase = blockIdx.z * p.loopMajor; - if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) - return; - - // Setup Y receptive field. - int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; - int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); - int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; - int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; - if (p.flip) - filterY = p.filterSize.y - 1 - filterY; - - // Loop over major, minor, and X. - for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) - for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) - { - int nc = major * p.sizeMinor + minor; - int n = nc / p.inSize.z; - int c = nc - n * p.inSize.z; - for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) - { - // Setup X receptive field. - int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; - int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); - int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; - int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; - if (p.flip) - filterX = p.filterSize.x - 1 - filterX; - - // Initialize pointers. 
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; - const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; - int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; - int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; - - // Inner loop. - scalar_t v = 0; - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - v += (scalar_t)(*xp) * (scalar_t)(*fp); - xp += p.inStride.x; - fp += filterStepX; - } - xp += p.inStride.y - w * p.inStride.x; - fp += filterStepY - w * filterStepX; - } - - // Store result. - v *= p.gain; - ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; - } - } -} - -//------------------------------------------------------------------------ -// Specialized CUDA implementation for small filters. - -template -static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; - const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; - __shared__ volatile scalar_t sf[filterH][filterW]; - __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; - - // Calculate tile index. - int minorBase = blockIdx.x; - int tileOutY = minorBase / p.launchMinor; - minorBase -= tileOutY * p.launchMinor; - minorBase *= loopMinor; - tileOutY *= tileOutH; - int tileOutXBase = blockIdx.y * p.loopX * tileOutW; - int majorBase = blockIdx.z * p.loopMajor; - if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) - return; - - // Load filter (flipped). - for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) - { - int fy = tapIdx / filterW; - int fx = tapIdx - fy * filterW; - scalar_t v = 0; - if (fx < p.filterSize.x & fy < p.filterSize.y) - { - int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; - int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; - v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; - } - sf[fy][fx] = v; - } - - // Loop over major and X. - for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) - { - int baseNC = major * p.sizeMinor + minorBase; - int n = baseNC / p.inSize.z; - int baseC = baseNC - n * p.inSize.z; - for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) - { - // Load input pixels. - int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; - int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; - int tileInX = floor_div(tileMidX, upx); - int tileInY = floor_div(tileMidY, upy); - __syncthreads(); - for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) - { - int relC = inIdx; - int relInX = relC / loopMinor; - int relInY = relInX / tileInW; - relC -= relInX * loopMinor; - relInX -= relInY * tileInW; - int c = baseC + relC; - int inX = tileInX + relInX; - int inY = tileInY + relInY; - scalar_t v = 0; - if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) - v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; - sx[relInY][relInX][relC] = v; - } - - // Loop over output pixels. 
- __syncthreads(); - for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) - { - int relC = outIdx; - int relOutX = relC / loopMinor; - int relOutY = relOutX / tileOutW; - relC -= relOutX * loopMinor; - relOutX -= relOutY * tileOutW; - int c = baseC + relC; - int outX = tileOutX + relOutX; - int outY = tileOutY + relOutY; - - // Setup receptive field. - int midX = tileMidX + relOutX * downx; - int midY = tileMidY + relOutY * downy; - int inX = floor_div(midX, upx); - int inY = floor_div(midY, upy); - int relInX = inX - tileInX; - int relInY = inY - tileInY; - int filterX = (inX + 1) * upx - midX - 1; // flipped - int filterY = (inY + 1) * upy - midY - 1; // flipped - - // Inner loop. - if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) - { - scalar_t v = 0; - #pragma unroll - for (int y = 0; y < filterH / upy; y++) - #pragma unroll - for (int x = 0; x < filterW / upx; x++) - v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; - v *= p.gain; - ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; - } - } - } - } -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) -{ - int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; - - upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous - if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last - - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if 
(fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - } - if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - } - if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - } - if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - } - if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 
2) // contiguous - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - } - return spec; -} - -//------------------------------------------------------------------------ -// Template specializations. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/upfirdn2d.h b/sg3_torch_utils/ops/upfirdn2d.h deleted file mode 100755 index c9e2032bcac9d2abde7a75eea4d812da348afadd..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/upfirdn2d.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
-// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include - -//------------------------------------------------------------------------ -// CUDA kernel parameters. - -struct upfirdn2d_kernel_params -{ - const void* x; - const float* f; - void* y; - - int2 up; - int2 down; - int2 pad0; - int flip; - float gain; - - int4 inSize; // [width, height, channel, batch] - int4 inStride; - int2 filterSize; // [width, height] - int2 filterStride; - int4 outSize; // [width, height, channel, batch] - int4 outStride; - int sizeMinor; - int sizeMajor; - - int loopMinor; - int loopMajor; - int loopX; - int launchMinor; - int launchMajor; -}; - -//------------------------------------------------------------------------ -// CUDA kernel specialization. - -struct upfirdn2d_kernel_spec -{ - void* kernel; - int tileOutW; - int tileOutH; - int loopMinor; - int loopX; -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/upfirdn2d.py b/sg3_torch_utils/ops/upfirdn2d.py deleted file mode 100755 index a0bbd22d245481e7c5a19315e5cb3242b1278787..0000000000000000000000000000000000000000 --- a/sg3_torch_utils/ops/upfirdn2d.py +++ /dev/null @@ -1,388 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom PyTorch ops for efficient resampling of 2D images.""" - -import os -import warnings -import numpy as np -import torch -import traceback - -from .. import custom_ops -from .. import misc -from . import conv2d_gradfix -from torch.cuda.amp import custom_bwd, custom_fwd - -#---------------------------------------------------------------------------- - -_inited = False -_plugin = None -enabled = False - -def _init(): - global _inited, _plugin - if not _inited: - sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] - sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - try: - _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) - except: - warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) - return _plugin is not None - -def _parse_scaling(scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - assert isinstance(scaling, (list, tuple)) - assert all(isinstance(x, int) for x in scaling) - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - -def _parse_padding(padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, int) for x in padding) - if len(padding) == 2: - padx, pady = padding - padding = [padx, padx, pady, pady] - padx0, padx1, pady0, pady1 = padding - return padx0, padx1, pady0, pady1 - -def _get_filter_size(f): - if f is None: - return 1, 1 - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - fw = f.shape[-1] - fh = f.shape[0] - with misc.suppress_tracer_warnings(): - fw = int(fw) - fh = int(fh) - misc.assert_shape(f, [fh, fw][:f.ndim]) - assert fw >= 1 and fh >= 1 - return fw, fh - -#---------------------------------------------------------------------------- - -def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): - r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. - - Args: - f: Torch tensor, numpy array, or python list of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), - `[]` (impulse), or - `None` (identity). - device: Result device (default: cpu). - normalize: Normalize the filter so that it retains the magnitude - for constant input signal (DC)? (default: True). - flip_filter: Flip the filter? (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - separable: Return a separable filter? (default: select automatically). - - Returns: - Float32 tensor of the shape - `[filter_height, filter_width]` (non-separable) or - `[filter_taps]` (separable). - """ - # Validate. - if f is None: - f = 1 - f = torch.as_tensor(f, dtype=torch.float32) - assert f.ndim in [0, 1, 2] - assert f.numel() > 0 - if f.ndim == 0: - f = f[np.newaxis] - - # Separable? - if separable is None: - separable = (f.ndim == 1 and f.numel() >= 8) - if f.ndim == 1 and not separable: - f = f.ger(f) - assert f.ndim == (1 if separable else 2) - - # Apply normalize, flip, gain, and device. - if normalize: - f /= f.sum() - if flip_filter: - f = f.flip(list(range(f.ndim))) - f = f * (gain ** (f.ndim / 2)) - f = f.to(device=device) - return f - -#---------------------------------------------------------------------------- - -def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Pad, upsample, filter, and downsample a batch of 2D images. - - Performs the following sequence of operations for each channel: - - 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). - - 2. Pad the image with the specified number of zeros on each side (`padding`). - Negative padding corresponds to cropping the image. - - 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it - so that the footprint of all output pixels lies within the input image. - - 4. Downsample the image by keeping every Nth pixel (`down`). - - This sequence of operations bears close resemblance to scipy.signal.upfirdn(). - The fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports gradients of arbitrary order. 
- - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init(): - return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) - return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. - """ - # Validate arguments. - assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - assert f.dtype == torch.float32 and not f.requires_grad - batch_size, num_channels, in_height, in_width = x.shape - upx, upy = _parse_scaling(up) - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - - # Upsample by inserting zeros. - x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) - x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) - x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) - x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = f.to(x.dtype) - if not flip_filter: - f = f.flip(list(range(f.ndim))) - - # Convolve with the filter. - f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) - if f.ndim == 4: - x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) - else: - x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) - x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) - - # Downsample by throwing away pixels. - x = x[:, :, ::downy, ::downx] - return x - -#---------------------------------------------------------------------------- - -_upfirdn2d_cuda_cache = dict() - -def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): - """Fast CUDA implementation of `upfirdn2d()` using custom ops. - """ - # Parse arguments. - upx, upy = _parse_scaling(up) - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - - # Lookup from cache. 
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) - if key in _upfirdn2d_cuda_cache: - return _upfirdn2d_cuda_cache[key] - - # Forward op. - class Upfirdn2dCuda(torch.autograd.Function): - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, x, f): # pylint: disable=arguments-differ - assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - y = x - if f.ndim == 2: - y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) - else: - y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) - y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) - ctx.save_for_backward(f) - ctx.x_shape = x.shape - return y - - @staticmethod - @custom_bwd - def backward(ctx, dy): # pylint: disable=arguments-differ - f, = ctx.saved_tensors - _, _, ih, iw = ctx.x_shape - _, _, oh, ow = dy.shape - fw, fh = _get_filter_size(f) - p = [ - fw - padx0 - 1, - iw * upx - ow * downx + padx0 - upx + 1, - fh - pady0 - 1, - ih * upy - oh * downy + pady0 - upy + 1, - ] - dx = None - df = None - - if ctx.needs_input_grad[0]: - dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) - - assert not ctx.needs_input_grad[1] - return dx, df - - # Add to cache. - _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda - return Upfirdn2dCuda - -#---------------------------------------------------------------------------- - -def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Filter a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape matches the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + fw // 2, - padx1 + (fw - 1) // 2, - pady0 + fh // 2, - pady1 + (fh - 1) // 2, - ] - return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) - -#---------------------------------------------------------------------------- - -def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Upsample a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape is a multiple of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. 
- - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - upx, upy = _parse_scaling(up) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw + upx - 1) // 2, - padx1 + (fw - upx) // 2, - pady0 + (fh + upy - 1) // 2, - pady1 + (fh - upy) // 2, - ] - return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) - -#---------------------------------------------------------------------------- - -def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Downsample a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape is a fraction of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the input. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw - downx + 1) // 2, - padx1 + (fw - downx) // 2, - pady0 + (fh - downy + 1) // 2, - pady1 + (fh - downy) // 2, - ] - return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) - -#----------------------------------------------------------------------------
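
Editor's note on the removal above: the deleted `sg3_torch_utils/ops/upfirdn2d.*` files implemented a fused upsample → pad → FIR-filter → downsample op as a custom CUDA extension, falling back to a plain-PyTorch reference path when the extension could not be built. For readers of this diff, the sketch below reproduces that four-step sequence using only standard PyTorch ops. It is an illustration, not code from this repository: `upfirdn2d_reference` and the example filter are hypothetical names chosen here, and only the 2D non-separable, equal-x/y-factor case is shown.

import torch
import torch.nn.functional as F

def upfirdn2d_reference(x, f, up=1, down=1, pad=(0, 0, 0, 0), gain=1.0):
    """Upsample, pad, FIR-filter, and downsample a batch of NCHW images.

    x:    float tensor of shape [N, C, H, W]
    f:    2D float32 filter of shape [fh, fw]
    up, down: integer resampling factors (applied to both axes)
    pad:  (left, right, top, bottom) padding of the upsampled image;
          negative values crop.
    """
    n, c, h, w = x.shape
    px0, px1, py0, py1 = pad

    # 1. Upsample by inserting up-1 zeros after each pixel.
    x = x.reshape(n, c, h, 1, w, 1)
    x = F.pad(x, [0, up - 1, 0, 0, 0, up - 1])
    x = x.reshape(n, c, h * up, w * up)

    # 2. Pad (or crop, for negative padding).
    x = F.pad(x, [max(px0, 0), max(px1, 0), max(py0, 0), max(py1, 0)])
    x = x[:, :, max(-py0, 0): x.shape[2] - max(-py1, 0),
              max(-px0, 0): x.shape[3] - max(-px1, 0)]

    # 3. Convolve each channel with the filter (flipped, i.e. true
    #    convolution semantics) via a depthwise/grouped conv2d.
    f = (f * gain).to(x.dtype).flip([0, 1])
    weight = f[None, None].repeat(c, 1, 1, 1)
    x = F.conv2d(x, weight, groups=c)

    # 4. Downsample by keeping every `down`-th pixel.
    return x[:, :, ::down, ::down]

# Example: 2x upsampling with a normalized [1, 3, 3, 1] outer-product filter,
# mirroring the padding/gain arithmetic of the deleted upsample2d() helper
# (fw = fh = 4, up = 2  ->  pad = (2, 1, 2, 1), gain = up * up = 4).
if __name__ == "__main__":
    f = torch.tensor([1., 3., 3., 1.])
    f = torch.outer(f, f)
    f = f / f.sum()
    x = torch.randn(1, 3, 16, 16)
    y = upfirdn2d_reference(x, f, up=2, pad=(2, 1, 2, 1), gain=4.0)
    print(y.shape)  # torch.Size([1, 3, 32, 32])

The grouped conv2d is what makes step 3 a per-channel FIR filter; the deleted CUDA kernels fused all four steps so the zero-stuffed intermediate from step 1 never had to be materialized, which is where the custom op's speed advantage came from.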