diff --git a/.gitattributes b/.gitattributes
index 1d9109d5d3f4374bb8bbe54d3b39527374d8529f..ad94a41e4af773d90e30d921cdec035d16a142cc 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
erling.jpg filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..a8736fdf09f983f34b0bd35e5d78eb1083f98958
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "deep_privacy2"]
+ path = deep_privacy2
+ url = https://github.com/hukkelas/deep_privacy2
diff --git a/app.py b/app.py
index c7ea0d4fcb0831179e845e297f21092a4a9ce3b5..f2e73970b9ce61a7170e08ea50d05e1d6f54e9af 100644
--- a/app.py
+++ b/app.py
@@ -1,78 +1,37 @@
+import gradio
+import sys
import os
+from pathlib import Path
+from tops.config import instantiate
+import gradio.inputs
os.system("pip install --upgrade pip")
os.system("pip install ftfy regex tqdm")
-os.system("pip install git+https://github.com/openai/CLIP.git")
+os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
-os.system("pip install git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
-import gradio
-import numpy as np
-import torch
-from PIL import Image
+os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
+sys.path.insert(0, str(Path(os.getcwd(), "deep_privacy2")))
+os.environ["TORCH_HOME"] = "torch_home"
from dp2 import utils
-from tops.config import instantiate
-import tops
-import gradio.inputs
-
-
-cfg_body = utils.load_config("configs/anonymizers/FB_cse.py")
-anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False)
-anonymizer_body.initialize_tracker(fps=1)
-cfg_face = utils.load_config("configs/anonymizers/face.py")
-anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
-anonymizer_face.initialize_tracker(fps=1)
-
-
-class ExampleDemo:
+from gradio_demos.modules import ExampleDemo, WebcamDemo
- def __init__(self, anonymizer, multi_modal_truncation=False) -> None:
- self.multi_modal_truncation = multi_modal_truncation
- self.anonymizer = anonymizer
- with gradio.Row():
- input_image = gradio.Image(type="pil", label="Upload your image or try the example below!")
- output_image = gradio.Image(type="numpy", label="Output")
- with gradio.Row():
- update_btn = gradio.Button("Update Anonymization").style(full_width=True)
- visualize_det = gradio.Checkbox(value=False, label="Show Detections")
- visualize_det.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
- gradio.Examples(
- ["erling.jpg", "g7-summit-leaders-distraction.jpg"], inputs=[input_image]
- )
- update_btn.click(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
- input_image.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
- self.track = False
+cfg_face = utils.load_config("deep_privacy2/configs/anonymizers/face.py")
+for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
+ if key in cfg_face.anonymizer:
+ cfg_face.anonymizer[key] = Path("deep_privacy2", cfg_face.anonymizer[key])
- def anonymize(self, img: Image, visualize_detection: bool):
- img, cache_id = pil2torch(img)
- img = tops.to_cuda(img)
- if visualize_detection:
- img = self.anonymizer.visualize_detection(img, cache_id=cache_id)
- else:
- img = self.anonymizer(
- img, truncation_value=0 if self.multi_modal_truncation else 1, multi_modal_truncation=self.multi_modal_truncation, amp=True,
- cache_id=cache_id, track=self.track)
- img = utils.im2numpy(img)
- return img
+anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
-def pil2torch(img: Image.Image):
- img = img.convert("RGB")
- img = np.array(img)
- img = np.rollaxis(img, 2)
- return torch.from_numpy(img), None
+anonymizer_face.initialize_tracker(fps=1)
with gradio.Blocks() as demo:
gradio.Markdown("#
DeepPrivacy2 - Realistic Image Anonymization ")
gradio.Markdown("### Håkon Hukkelås, Rudolf Mester, Frank Lindseth ")
- gradio.Markdown(" DeepPrivacy2 is a toolbox for realistic anonymization of humans, including a face and a full-body anonymizer. ")
gradio.Markdown(" See more information at: https://github.com/hukkelas/deep_privacy2 ")
+ with gradio.Tab("Face Anonymization"):
+ ExampleDemo(anonymizer_face)
+ with gradio.Tab("Live Webcam"):
+ WebcamDemo(anonymizer_face)
-
-
- with gradio.Tab("Full-Body Anonymization"):
- ExampleDemo(anonymizer_body, multi_modal_truncation=True)
- with gradio.Tab("Face Anonymization"):
- ExampleDemo(anonymizer_face, multi_modal_truncation=False)
-
-
-demo.launch()
\ No newline at end of file
+demo.launch()
diff --git a/configs/anonymizers/FB_cse.py b/configs/anonymizers/FB_cse.py
deleted file mode 100644
index ff44a8ef9da980d545b09609de82e071edc912ac..0000000000000000000000000000000000000000
--- a/configs/anonymizers/FB_cse.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.person_detector import CSEPersonDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-from dp2.generator.dummy_generators import MaskOutGenerator
-
-
-maskout_G = L(MaskOutGenerator)(noise="constant")
-
-detector = L(CSEPersonDetector)(
- mask_rcnn_cfg=dict(),
- cse_cfg=dict(),
- cse_post_process_cfg=dict(
- target_imsize=(288, 160),
- exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
- exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
- iou_combine_threshold=0.4,
- dilation_percentage=0.02,
- normalize_embedding=False
- ),
- score_threshold=0.3,
- cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
-)
-
-anonymizer = L(Anonymizer)(
- detector="${detector}",
- cse_person_G_cfg="configs/fdh/styleganL.py",
-)
diff --git a/configs/anonymizers/FB_cse_mask.py b/configs/anonymizers/FB_cse_mask.py
deleted file mode 100644
index ff5e3bfbefad8e1d6e480fa22256aff0f9647b35..0000000000000000000000000000000000000000
--- a/configs/anonymizers/FB_cse_mask.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.person_detector import CSEPersonDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-from dp2.generator.dummy_generators import MaskOutGenerator
-
-
-maskout_G = L(MaskOutGenerator)(noise="constant")
-
-detector = L(CSEPersonDetector)(
- mask_rcnn_cfg=dict(),
- cse_cfg=dict(),
- cse_post_process_cfg=dict(
- target_imsize=(288, 160),
- exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
- exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
- iou_combine_threshold=0.4,
- dilation_percentage=0.02,
- normalize_embedding=False
- ),
- score_threshold=0.3,
- cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
-)
-
-anonymizer = L(Anonymizer)(
- detector="${detector}",
- person_G_cfg="configs/fdh/styleganL_nocse.py",
- cse_person_G_cfg="configs/fdh/styleganL.py",
-)
diff --git a/configs/anonymizers/FB_cse_mask_face.py b/configs/anonymizers/FB_cse_mask_face.py
deleted file mode 100644
index 8d6e4eef72111c0d8f427631926a7465a9cb174d..0000000000000000000000000000000000000000
--- a/configs/anonymizers/FB_cse_mask_face.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-
-detector = L(CSeMaskFaceDetector)(
- mask_rcnn_cfg=dict(),
- face_detector_cfg=dict(),
- face_post_process_cfg=dict(target_imsize=(256, 256)),
- cse_cfg=dict(),
- cse_post_process_cfg=dict(
- target_imsize=(288, 160),
- exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
- exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
- iou_combine_threshold=0.4,
- dilation_percentage=0.02,
- normalize_embedding=False
- ),
- score_threshold=0.3,
- cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
-)
-
-anonymizer = L(Anonymizer)(
- detector="${detector}",
- face_G_cfg="configs/fdf/stylegan.py",
- person_G_cfg="configs/fdh/styleganL_nocse.py",
- cse_person_G_cfg="configs/fdh/styleganL.py",
- car_G_cfg="configs/generators/dummy/pixelation8.py"
-)
diff --git a/configs/anonymizers/face.py b/configs/anonymizers/face.py
deleted file mode 100644
index 8c39560aad3c8c3f40592bb58850065994860b46..0000000000000000000000000000000000000000
--- a/configs/anonymizers/face.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.face_detector import FaceDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-
-
-detector = L(FaceDetector)(
- face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
- face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
- score_threshold=0.3,
- cache_directory=common.output_dir.joinpath("face_detection_cache")
-)
-
-
-anonymizer = L(Anonymizer)(
- detector="${detector}",
- face_G_cfg="configs/fdf/stylegan.py",
-)
diff --git a/configs/anonymizers/market1501/blackout.py b/configs/anonymizers/market1501/blackout.py
deleted file mode 100644
index 14da21e3c4b367a942f9a99796a1d9996b773522..0000000000000000000000000000000000000000
--- a/configs/anonymizers/market1501/blackout.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
-anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
-anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
\ No newline at end of file
diff --git a/configs/anonymizers/market1501/person.py b/configs/anonymizers/market1501/person.py
deleted file mode 100644
index 51fa99b21f068ce68f796fd32c85d37d9a22bec1..0000000000000000000000000000000000000000
--- a/configs/anonymizers/market1501/person.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
\ No newline at end of file
diff --git a/configs/anonymizers/market1501/pixelation16.py b/configs/anonymizers/market1501/pixelation16.py
deleted file mode 100644
index 2569fc2abb91919f91dd12546c06a86624d235fc..0000000000000000000000000000000000000000
--- a/configs/anonymizers/market1501/pixelation16.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
-anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
-anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
\ No newline at end of file
diff --git a/configs/anonymizers/market1501/pixelation8.py b/configs/anonymizers/market1501/pixelation8.py
deleted file mode 100644
index ef49cb613d09e972adf7b8136b632eb210420686..0000000000000000000000000000000000000000
--- a/configs/anonymizers/market1501/pixelation8.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
-anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
-anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
\ No newline at end of file
diff --git a/configs/datasets/coco_cse.py b/configs/datasets/coco_cse.py
deleted file mode 100644
index 00582b39cb473c7cad1ec95ce7361ba9c25939b4..0000000000000000000000000000000000000000
--- a/configs/datasets/coco_cse.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-from pathlib import Path
-from tops.config import LazyCall as L
-import torch
-import functools
-from dp2.data.datasets import CocoCSE
-from dp2.data.build import get_dataloader
-from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
-from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from .utils import final_eval_fn
-
-
-dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
-metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
-data_dir = Path(dataset_base_dir, "coco_cse")
-data = dict(
- imsize=(288, 160),
- im_channels=3,
- semantic_nc=26,
- cse_nc=16,
- train=dict(
- dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
- loader=L(get_dataloader)(
- shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
- batch_size="${train.batch_size}",
- dataset="${..dataset}",
- infinite=True,
- gpu_transform=L(torch.nn.Sequential)(*[
- L(ToFloat)(),
- L(StyleGANAugmentPipe)(
- rotate=0.5, rotate_max=.05,
- xint=.5, xint_max=0.05,
- scale=.5, scale_std=.05,
- aniso=0.5, aniso_std=.05,
- xfrac=.5, xfrac_std=.05,
- brightness=.5, brightness_std=.05,
- contrast=.5, contrast_std=.1,
- hue=.5, hue_max=.05,
- saturation=.5, saturation_std=.5,
- imgfilter=.5, imgfilter_std=.1),
- L(RandomHorizontalFlip)(p=0.5),
- L(CreateEmbedding)(),
- L(Resize)(size="${data.imsize}"),
- L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
- L(CreateCondition)(),
- ])
- )
- ),
- val=dict(
- dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
- loader=L(get_dataloader)(
- shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
- batch_size="${train.batch_size}",
- dataset="${..dataset}",
- infinite=False,
- gpu_transform=L(torch.nn.Sequential)(*[
- L(ToFloat)(),
- L(CreateEmbedding)(),
- L(Resize)(size="${data.imsize}"),
- L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
- L(CreateCondition)(),
- ])
- )
- ),
- # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
- train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
- evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
-)
diff --git a/configs/datasets/fdf128.py b/configs/datasets/fdf128.py
deleted file mode 100644
index 8740ebd4738d7487cc9f1c6fbbcc2a2695240163..0000000000000000000000000000000000000000
--- a/configs/datasets/fdf128.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from pathlib import Path
-from functools import partial
-from dp2.data.datasets.fdf import FDFDataset
-from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn
-
-data_dir = Path(dataset_base_dir, "fdf")
-data.train.dataset.dirpath = data_dir.joinpath("train")
-data.val.dataset.dirpath = data_dir.joinpath("val")
-data.imsize = (128, 128)
-
-
-data.train_evaluation_fn = partial(
- final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
-data.evaluation_fn = partial(
- final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
-
-data.train.dataset.update(
- _target_ = FDFDataset,
- imsize="${data.imsize}"
-)
-data.val.dataset.update(
- _target_ = FDFDataset,
- imsize="${data.imsize}"
-)
\ No newline at end of file
diff --git a/configs/datasets/fdf256.py b/configs/datasets/fdf256.py
deleted file mode 100644
index 3767634c3d8f4b2253b5ad3c28bd324a59cde0ae..0000000000000000000000000000000000000000
--- a/configs/datasets/fdf256.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-from pathlib import Path
-from tops.config import LazyCall as L
-import torch
-import functools
-from dp2.data.datasets.fdf import FDF256Dataset
-from dp2.data.build import get_dataloader
-from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from dp2.metrics.fid_clip import compute_fid_clip
-from dp2.metrics.ppl import calculate_ppl
-from .utils import final_eval_fn
-
-
-def final_eval_fn(*args, **kwargs):
- result = compute_metrics_iteratively(*args, **kwargs)
- result2 = compute_fid_clip(*args, **kwargs)
- assert all(key not in result for key in result2)
- result.update(result2)
- result3 = calculate_ppl(*args, **kwargs,)
- assert all(key not in result for key in result3)
- result.update(result3)
- return result
-
-
-dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
-metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
-data_dir = Path(dataset_base_dir, "fdf256")
-data = dict(
- imsize=(256, 256),
- im_channels=3,
- semantic_nc=None,
- cse_nc=None,
- n_keypoints=None,
- train=dict(
- dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
- loader=L(get_dataloader)(
- shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
- batch_size="${train.batch_size}",
- dataset="${..dataset}",
- infinite=True,
- gpu_transform=L(torch.nn.Sequential)(*[
- L(ToFloat)(),
- L(RandomHorizontalFlip)(p=0.5),
- L(Resize)(size="${data.imsize}"),
- L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
- L(CreateCondition)(),
- ])
- )
- ),
- val=dict(
- dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
- loader=L(get_dataloader)(
- shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
- batch_size="${train.batch_size}",
- dataset="${..dataset}",
- infinite=False,
- gpu_transform=L(torch.nn.Sequential)(*[
- L(ToFloat)(),
- L(Resize)(size="${data.imsize}"),
- L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
- L(CreateCondition)(),
- ])
- )
- ),
- # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
- train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "fdf_val_train")),
- evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
-)
\ No newline at end of file
diff --git a/configs/datasets/fdh.py b/configs/datasets/fdh.py
deleted file mode 100644
index 298687d85302a029af2c44decf8dcb248b024031..0000000000000000000000000000000000000000
--- a/configs/datasets/fdh.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-from pathlib import Path
-from tops.config import LazyCall as L
-import torch
-import functools
-from dp2.data.datasets.fdh import get_dataloader_fdh_wds
-from dp2.data.utils import get_coco_flipmap
-from dp2.data.transforms.transforms import (
- Normalize,
- ToFloat,
- CreateCondition,
- RandomHorizontalFlip,
- CreateEmbedding,
-)
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from dp2.metrics.fid_clip import compute_fid_clip
-from .utils import final_eval_fn
-
-
-def train_eval_fn(*args, **kwargs):
- result = compute_metrics_iteratively(*args, **kwargs)
- result2 = compute_fid_clip(*args, **kwargs)
- assert all(key not in result for key in result2)
- result.update(result2)
- return result
-
-
-dataset_base_dir = (
- os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
-)
-metrics_cache = (
- os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
-)
-data_dir = Path(dataset_base_dir, "fdh")
-data = dict(
- imsize=(288, 160),
- im_channels=3,
- cse_nc=16,
- n_keypoints=17,
- train=dict(
- loader=L(get_dataloader_fdh_wds)(
- path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
- batch_size="${train.batch_size}",
- num_workers=6,
- transform=L(torch.nn.Sequential)(
- L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
- ),
- gpu_transform=L(torch.nn.Sequential)(
- L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
- L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
- L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
- L(CreateCondition)(),
- ),
- infinite=True,
- shuffle=True,
- partial_batches=False,
- load_embedding=True,
- )
- ),
- val=dict(
- loader=L(get_dataloader_fdh_wds)(
- path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
- batch_size="${train.batch_size}",
- num_workers=6,
- transform=None,
- gpu_transform=L(torch.nn.Sequential)(
- L(ToFloat)(keys=["img", "mask", "E_mask", "maskrcnn_mask"], norm=False),
- L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
- L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
- L(CreateCondition)(),
- ),
- infinite=False,
- shuffle=False,
- partial_batches=True,
- load_embedding=True,
- )
- ),
- # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
- train_evaluation_fn=functools.partial(
- train_eval_fn,
- cache_directory=Path(metrics_cache, "fdh_v7_train"),
- data_len=int(30e3),
- ),
- evaluation_fn=functools.partial(
- final_eval_fn,
- cache_directory=Path(metrics_cache, "fdh_v6_val"),
- data_len=int(30e3),
- ),
-)
diff --git a/configs/datasets/utils.py b/configs/datasets/utils.py
deleted file mode 100644
index ccc72dc4c403ffc83db9e79e6d8ee97d9b181622..0000000000000000000000000000000000000000
--- a/configs/datasets/utils.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from dp2.metrics.ppl import calculate_ppl
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from dp2.metrics.fid_clip import compute_fid_clip
-
-
-def final_eval_fn(*args, **kwargs):
- result = compute_metrics_iteratively(*args, **kwargs)
- result2 = calculate_ppl(*args, **kwargs,)
- result2 = compute_fid_clip(*args, **kwargs)
- assert all(key not in result for key in result2)
- result.update(result2)
- return result
diff --git a/configs/defaults.py b/configs/defaults.py
deleted file mode 100644
index 9767dfa920d6261b749ce7064d4881a64ac73798..0000000000000000000000000000000000000000
--- a/configs/defaults.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import pathlib
-import os
-import torch
-from tops.config import LazyCall as L
-
-if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
- PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
-else:
- PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
-if "BASE_OUTPUT_DIR" in os.environ:
- BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
-else:
- BASE_OUTPUT_DIR = pathlib.Path("outputs")
-
-
-
-common = dict(
- logger_backend=["wandb", "stdout", "json", "image_dumper"],
- wandb_project="fba_test",
- output_dir=BASE_OUTPUT_DIR,
- experiment_name=None, # Optional experiment name to show on wandb
-)
-
-train = dict(
- batch_size=32,
- seed=0,
- ims_per_log=1024,
- ims_per_val=int(200e3),
- max_images_to_train=int(12e6),
- amp=dict(
- enabled=True,
- scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
- scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
- ),
- fp16_ddp_accumulate=False, # All gather gradients in fp16?
- broadcast_buffers=False,
- bias_act_plugin_enabled=True,
- grid_sample_gradfix_enabled=True,
- conv2d_gradfix_enabled=False,
- channels_last=False,
-)
-
-# exponential moving average
-EMA = dict(rampup=0.05)
-
diff --git a/configs/discriminators/sg2_discriminator.py b/configs/discriminators/sg2_discriminator.py
deleted file mode 100644
index e692450da87e04afe62c150dabe8eea3208b1382..0000000000000000000000000000000000000000
--- a/configs/discriminators/sg2_discriminator.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from tops.config import LazyCall as L
-from dp2.discriminator import SG2Discriminator
-import torch
-from dp2.loss import StyleGAN2Loss
-
-
-discriminator = L(SG2Discriminator)(
- imsize="${data.imsize}",
- im_channels="${data.im_channels}",
- min_fmap_resolution=4,
- max_cnum_mul=8,
- cnum=80,
- input_condition=True,
- conv_clamp=256,
- input_cse=False,
- cse_nc="${data.cse_nc}"
-)
-
-
-loss_fnc = L(StyleGAN2Loss)(
- lazy_regularization=True,
- lazy_reg_interval=16,
- r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
- EP_lambd=0.001,
- pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01)
-)
-
-def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
- if lazy_regularization:
- # From Analyzing and improving the image quality of stylegan, CVPR 2020
- c = lazy_reg_interval / (lazy_reg_interval + 1)
- betas = [beta ** c for beta in betas]
- lr *= c
- print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
- return type(lr=lr, betas=betas, **kwargs)
-
-
-D_optim = L(build_D_optim)(
- type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
- lazy_regularization="${loss_fnc.lazy_regularization}",
- lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
-G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
diff --git a/configs/fdf/stylegan.py b/configs/fdf/stylegan.py
deleted file mode 100644
index a4da2c3ad76d3d1fb6e1d91e832cde5c735bf32a..0000000000000000000000000000000000000000
--- a/configs/fdf/stylegan.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from ..generators.stylegan_unet import generator
-from ..datasets.fdf256 import data
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..defaults import train, common, EMA
-
-train.max_images_to_train = int(35e6)
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-generator.input_cse = False
-loss_fnc.r1_opts.lambd = 1
-train.ims_per_val = int(2e6)
-
-common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
-common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
\ No newline at end of file
diff --git a/configs/fdf/stylegan_fdf128.py b/configs/fdf/stylegan_fdf128.py
deleted file mode 100644
index ff5c4065732d155f92a8a551efff546899589a4f..0000000000000000000000000000000000000000
--- a/configs/fdf/stylegan_fdf128.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..datasets.fdf128 import data
-from ..generators.stylegan_unet import generator
-from ..defaults import train, common, EMA
-from tops.config import LazyCall as L
-
-train.max_images_to_train = int(25e6)
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-generator.cnum = 128
-generator.max_cnum_mul = 4
-generator.input_cse = False
-loss_fnc.r1_opts.lambd = .1
diff --git a/configs/fdh/styleganL.py b/configs/fdh/styleganL.py
deleted file mode 100644
index 48fcf09b43a7141a270fbe5c69bd7932414270fe..0000000000000000000000000000000000000000
--- a/configs/fdh/styleganL.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from tops.config import LazyCall as L
-from ..generators.stylegan_unet import generator
-from ..datasets.fdh import data
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..defaults import train, common, EMA
-
-train.max_images_to_train = int(50e6)
-train.batch_size = 64
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-data.train.loader.num_workers = 4
-train.ims_per_val = int(1e6)
-loss_fnc.r1_opts.lambd = .1
-
-common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
-common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
\ No newline at end of file
diff --git a/configs/fdh/styleganL_nocse.py b/configs/fdh/styleganL_nocse.py
deleted file mode 100644
index 210fd68743f0b872f89f4407dfaac7c9bf5f0e32..0000000000000000000000000000000000000000
--- a/configs/fdh/styleganL_nocse.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from tops.config import LazyCall as L
-from ..generators.stylegan_unet import generator
-from ..datasets.fdh import data
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..defaults import train, common, EMA
-
-train.max_images_to_train = int(50e6)
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-generator.input_cse = False
-data.load_embeddings = False
-common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
-common.model_md5sum = "fda0d809741bc67487abada793975c37"
-generator.fix_errors = False
\ No newline at end of file
diff --git a/configs/generators/stylegan_unet.py b/configs/generators/stylegan_unet.py
deleted file mode 100644
index 638859263a1cb549f533b75b2b19609665b3443e..0000000000000000000000000000000000000000
--- a/configs/generators/stylegan_unet.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from dp2.generator.stylegan_unet import StyleGANUnet
-from tops.config import LazyCall as L
-
-generator = L(StyleGANUnet)(
- imsize="${data.imsize}",
- im_channels="${data.im_channels}",
- min_fmap_resolution=8,
- cnum=64,
- max_cnum_mul=8,
- n_middle_blocks=0,
- z_channels=512,
- mask_output=True,
- conv_clamp=256,
- input_cse=True,
- scale_grad=True,
- cse_nc="${data.cse_nc}",
- w_dim=512,
- n_keypoints="${data.n_keypoints}",
- input_keypoints=False,
- input_keypoint_indices=[],
- fix_errors=True
-)
\ No newline at end of file
diff --git a/deep_privacy2 b/deep_privacy2
new file mode 160000
index 0000000000000000000000000000000000000000..37dcbeb23a1f51121d53bcd80d32d086d6822b7b
--- /dev/null
+++ b/deep_privacy2
@@ -0,0 +1 @@
+Subproject commit 37dcbeb23a1f51121d53bcd80d32d086d6822b7b
diff --git a/dp2/__init__.py b/dp2/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/dp2/anonymizer/__init__.py b/dp2/anonymizer/__init__.py
deleted file mode 100644
index 3fb33d7e6ad3b247938dc20ab2311728f286eb14..0000000000000000000000000000000000000000
--- a/dp2/anonymizer/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .anonymizer import Anonymizer
\ No newline at end of file
diff --git a/dp2/anonymizer/anonymizer.py b/dp2/anonymizer/anonymizer.py
deleted file mode 100644
index d850384b3fa33b08b5d5b6770b584c5b64adce44..0000000000000000000000000000000000000000
--- a/dp2/anonymizer/anonymizer.py
+++ /dev/null
@@ -1,159 +0,0 @@
-from pathlib import Path
-from typing import Union, Optional
-import numpy as np
-import torch
-import tops
-import torchvision.transforms.functional as F
-from motpy import Detection, MultiObjectTracker
-from dp2.utils import load_config
-from dp2.infer import build_trained_generator
-from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
-
-
-def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
- cfg = load_config(cfg_path)
- G = build_trained_generator(cfg)
- tops.logger.log(f"Loaded generator from: {cfg_path}")
- return G
-
-
-def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs):
- img = F.resize(img, imsize, antialias=True)
- mask = (F.resize(mask, imsize, antialias=True) > 0.99).float()
- maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float()
-
- condition = img * mask
- return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition)
-
-
-class Anonymizer:
-
- def __init__(
- self,
- detector,
- load_cache: bool,
- person_G_cfg: Optional[Union[str, Path]] = None,
- cse_person_G_cfg: Optional[Union[str, Path]] = None,
- face_G_cfg: Optional[Union[str, Path]] = None,
- car_G_cfg: Optional[Union[str, Path]] = None,
- ) -> None:
- self.detector = detector
- self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
- self.load_cache = load_cache
- if cse_person_G_cfg is not None:
- self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
- if person_G_cfg is not None:
- self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
- if face_G_cfg is not None:
- self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
- if car_G_cfg is not None:
- self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
-
- def initialize_tracker(self, fps: float):
- self.tracker = MultiObjectTracker(dt=1/fps)
- self.track_to_z_idx = dict()
- self.cur_z_idx = 0
-
- @torch.no_grad()
- def anonymize_detections(self,
- im, detection, truncation_value: float,
- multi_modal_truncation: bool, amp: bool, z_idx,
- all_styles=None,
- update_identity=None,
- ):
- G = self.generators[type(detection)]
- if G is None:
- return im
- C, H, W = im.shape
- orig_im = im.clone()
- if update_identity is None:
- update_identity = [True for i in range(len(detection))]
- for idx in range(len(detection)):
- if not update_identity[idx]:
- continue
- batch = detection.get_crop(idx, im)
- x0, y0, x1, y1 = batch.pop("boxes")[0]
- batch = {k: tops.to_cuda(v) for k, v in batch.items()}
- batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
- batch["img"] = batch["img"].float()
- batch["condition"] = batch["mask"] * batch["img"]
- orig_shape = None
- if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]:
- orig_shape = batch["img"].shape[-2:]
- batch = resize_batch(**batch, imsize=G.imsize)
- with torch.cuda.amp.autocast(amp):
- if all_styles is not None:
- anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
- elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"):
- w_indices = None
- if z_idx is not None:
- w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
- anonymized_im = G.multi_modal_truncate(
- **batch, truncation_value=truncation_value,
- w_indices=w_indices)["img"]
- else:
- z = None
- if z_idx is not None:
- state = np.random.RandomState(seed=z_idx[idx])
- z = state.normal(size=(1, G.z_channels))
- z = tops.to_cuda(torch.from_numpy(z))
- anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
- if orig_shape is not None:
- anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True)
- anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte()
-
- # Resize and denormalize image
- gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True)
- mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
- # Remove padding
- pad = [max(-x0,0), max(-y0,0)]
- pad = [*pad, max(x1-W,0), max(y1-H,0)]
- remove_pad = lambda x: x[...,pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
- gim = remove_pad(gim)
- mask = remove_pad(mask)
- x0, y0 = max(x0, 0), max(y0, 0)
- x1, y1 = min(x1, W), min(y1, H)
- mask = mask.logical_not()[None].repeat(3, 1, 1)
- im[:, y0:y1, x0:x1][mask] = gim[mask]
-
- return im
-
- def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
- all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
- for det in all_detections:
- im = det.visualize(im)
- return im
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor:
- assert im.dtype == torch.uint8
- im = tops.to_cuda(im)
- all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
- if hasattr(self, "tracker") and track:
- [_.pre_process() for _ in all_detections]
- import numpy as np
- boxes = np.concatenate([_.boxes for _ in all_detections])
- boxes = [Detection(box) for box in boxes]
- self.tracker.step(boxes)
- track_ids = self.tracker.detections_matched_ids
- z_idx = []
- for track_id in track_ids:
- if track_id not in self.track_to_z_idx:
- self.track_to_z_idx[track_id] = self.cur_z_idx
- self.cur_z_idx += 1
- z_idx.append(self.track_to_z_idx[track_id])
- z_idx = np.array(z_idx)
- idx_offset = 0
-
- for detection in all_detections:
- zs = None
- if hasattr(self, "tracker") and track:
- zs = z_idx[idx_offset:idx_offset+len(detection)]
- idx_offset += len(detection)
- im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
-
- return im.cpu()
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
diff --git a/dp2/data/__init__.py b/dp2/data/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/dp2/data/build.py b/dp2/data/build.py
deleted file mode 100644
index 07f2a4e3630b405bf5a84f6926a733044345d741..0000000000000000000000000000000000000000
--- a/dp2/data/build.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import io
-import torch
-import tops
-from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder
-
-def get_dataloader(
- dataset, gpu_transform: torch.nn.Module,
- num_workers,
- batch_size,
- infinite: bool,
- drop_last: bool,
- prefetch_factor: int,
- shuffle,
- channels_last=False
- ):
- sampler = None
- dl_kwargs = dict(
- pin_memory=True,
- )
- if infinite:
- sampler = tops.InfiniteSampler(
- dataset, rank=tops.rank(),
- num_replicas=tops.world_size(),
- shuffle=shuffle
- )
- elif tops.world_size() > 1:
- sampler = torch.utils.data.DistributedSampler(
- dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
- dl_kwargs["drop_last"] = drop_last
- else:
- dl_kwargs["shuffle"] = shuffle
- dl_kwargs["drop_last"] = drop_last
- dataloader = torch.utils.data.DataLoader(
- dataset, sampler=sampler, collate_fn=collate_fn,
- batch_size=batch_size,
- num_workers=num_workers, prefetch_factor=prefetch_factor,
- **dl_kwargs
- )
- dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
- return dataloader
-
-
-def get_dataloader_places2_wds(
- path,
- batch_size: int,
- num_workers: int,
- transform: torch.nn.Module,
- gpu_transform: torch.nn.Module,
- infinite: bool,
- shuffle: bool,
- partial_batches: bool,
- sample_shuffle=10_000,
- tar_shuffle=100,
- channels_last=False,
- ):
- import webdataset as wds
- import os
- os.environ["RANK"] = str(tops.rank())
- os.environ["WORLD_SIZE"] = str(tops.world_size())
-
- if infinite:
- pipeline = [wds.ResampledShards(str(path))]
- else:
- pipeline = [wds.SimpleShardList(str(path))]
- if shuffle:
- pipeline.append(wds.shuffle(tar_shuffle))
- pipeline.extend([
- wds.split_by_node,
- wds.split_by_worker,
- ])
- if shuffle:
- pipeline.append(wds.shuffle(sample_shuffle))
-
- pipeline.extend([
- wds.tarfile_to_samples(),
- wds.decode("torchrgb8"),
- wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
- ])
- if transform is not None:
- pipeline.append(wds.map(transform))
- pipeline.extend([
- wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
- ])
- pipeline = wds.DataPipeline(*pipeline)
- if infinite:
- pipeline = pipeline.repeat(nepochs=1000000)
- loader = wds.WebLoader(
- pipeline, batch_size=None, shuffle=False,
- num_workers=get_num_workers(num_workers),
- persistent_workers=True,
- )
- loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
- return loader
-
-
-
-
-def get_dataloader_celebAHQ_wds(
- path,
- batch_size: int,
- num_workers: int,
- transform: torch.nn.Module,
- gpu_transform: torch.nn.Module,
- infinite: bool,
- shuffle: bool,
- partial_batches: bool,
- sample_shuffle=10_000,
- tar_shuffle=100,
- channels_last=False,
- ):
- import webdataset as wds
- import os
- os.environ["RANK"] = str(tops.rank())
- os.environ["WORLD_SIZE"] = str(tops.world_size())
-
- if infinite:
- pipeline = [wds.ResampledShards(str(path))]
- else:
- pipeline = [wds.SimpleShardList(str(path))]
- if shuffle:
- pipeline.append(wds.shuffle(tar_shuffle))
- pipeline.extend([
- wds.split_by_node,
- wds.split_by_worker,
- ])
- if shuffle:
- pipeline.append(wds.shuffle(sample_shuffle))
-
- pipeline.extend([
- wds.tarfile_to_samples(),
- wds.decode(wds.handle_extension(".png", png_decoder)),
- wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
- ])
- if transform is not None:
- pipeline.append(wds.map(transform))
- pipeline.extend([
- wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
- ])
- pipeline = wds.DataPipeline(*pipeline)
- if infinite:
- pipeline = pipeline.repeat(nepochs=1000000)
- loader = wds.WebLoader(
- pipeline, batch_size=None, shuffle=False,
- num_workers=get_num_workers(num_workers),
- persistent_workers=True,
- )
- loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
- return loader
diff --git a/dp2/data/datasets/__init__.py b/dp2/data/datasets/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/dp2/data/datasets/coco_cse.py b/dp2/data/datasets/coco_cse.py
deleted file mode 100644
index 27fa6dfb94f118b939788e0134e6ac42c613b297..0000000000000000000000000000000000000000
--- a/dp2/data/datasets/coco_cse.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import pickle
-import torchvision
-import torch
-import pathlib
-import numpy as np
-from typing import Callable, Optional, Union
-from torch.hub import get_dir as get_hub_dir
-
-
-def cache_embed_stats(embed_map: torch.Tensor):
- mean = embed_map.mean(dim=0, keepdim=True)
- rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
-
- cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
- path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
- path.parent.mkdir(exist_ok=True, parents=True)
- torch.save(cache, path)
-
-
-class CocoCSE(torch.utils.data.Dataset):
-
- def __init__(self,
- dirpath: Union[str, pathlib.Path],
- transform: Optional[Callable],
- normalize_E: bool,):
- dirpath = pathlib.Path(dirpath)
- self.dirpath = dirpath
-
- self.transform = transform
- assert self.dirpath.is_dir(),\
- f"Did not find dataset at: {dirpath}"
- self.image_paths, self.embedding_paths = self._load_impaths()
- self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
- mean = self.embed_map.mean(dim=0, keepdim=True)
- rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
- self.embed_map = (self.embed_map - mean) * rstd
- cache_embed_stats(self.embed_map)
-
- def _load_impaths(self):
- image_dir = self.dirpath.joinpath("images")
- image_paths = list(image_dir.glob("*.png"))
- image_paths.sort()
- embedding_paths = [
- self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
- ]
- return image_paths, embedding_paths
-
- def __len__(self):
- return len(self.image_paths)
-
- def __getitem__(self, idx):
- im = torchvision.io.read_image(str(self.image_paths[idx]))
- vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
- vertices = torch.from_numpy(vertices.squeeze()).long()
- mask = torch.from_numpy(mask.squeeze()).float()
- border = torch.from_numpy(border.squeeze()).float()
- E_mask = 1 - mask - border
- batch = {
- "img": im,
- "vertices": vertices[None],
- "mask": mask[None],
- "embed_map": self.embed_map,
- "border": border[None],
- "E_mask": E_mask[None]
- }
- if self.transform is None:
- return batch
- return self.transform(batch)
-
-
-class CocoCSEWithFace(CocoCSE):
-
- def __init__(self,
- dirpath: Union[str, pathlib.Path],
- transform: Optional[Callable],
- **kwargs):
- super().__init__(dirpath, transform, **kwargs)
- with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
- self.face_boxes = pickle.load(fp)
-
- def __getitem__(self, idx):
- item = super().__getitem__(idx)
- item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
- return item
-
-
-class CocoCSESemantic(torch.utils.data.Dataset):
-
- def __init__(self,
- dirpath: Union[str, pathlib.Path],
- transform: Optional[Callable],
- **kwargs):
- dirpath = pathlib.Path(dirpath)
- self.dirpath = dirpath
-
- self.transform = transform
- assert self.dirpath.is_dir(),\
- f"Did not find dataset at: {dirpath}"
- self.image_paths, self.embedding_paths = self._load_impaths()
- self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy")))
- self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
-
- def _load_impaths(self):
- image_dir = self.dirpath.joinpath("images")
- image_paths = list(image_dir.glob("*.png"))
- image_paths.sort()
- embedding_paths = [
- self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
- ]
- return image_paths, embedding_paths
-
- def __len__(self):
- return len(self.image_paths)
-
- def __getitem__(self, idx):
- im = torchvision.io.read_image(str(self.image_paths[idx]))
- vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
- vertices = torch.from_numpy(vertices.squeeze()).long()
- mask = torch.from_numpy(mask.squeeze()).float()
- border = torch.from_numpy(border.squeeze()).float()
- E_mask = 1 - mask - border
- batch = {
- "img": im,
- "vertices": vertices[None],
- "mask": mask[None],
- "border": border[None],
- "vertx2cat": self.vertx2cat,
- "embed_map": self.embed_map,
- }
- if self.transform is None:
- return batch
- return self.transform(batch)
-
-
-class CocoCSESemanticWithFace(CocoCSESemantic):
-
- def __init__(self,
- dirpath: Union[str, pathlib.Path],
- transform: Optional[Callable],
- **kwargs):
- super().__init__(dirpath, transform, **kwargs)
- with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
- self.face_boxes = pickle.load(fp)
-
- def __getitem__(self, idx):
- item = super().__getitem__(idx)
- item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
- return item
diff --git a/dp2/data/datasets/fdf.py b/dp2/data/datasets/fdf.py
deleted file mode 100644
index b05c75692515c5e294af72417371af9fcefbfad8..0000000000000000000000000000000000000000
--- a/dp2/data/datasets/fdf.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import pathlib
-from typing import Tuple
-import numpy as np
-import torch
-import pathlib
-try:
- import pyspng
- PYSPNG_IMPORTED = True
-except ImportError:
- PYSPNG_IMPORTED = False
- print("Could not load pyspng. Defaulting to pillow image backend.")
- from PIL import Image
-from tops import logger
-
-
-class FDFDataset:
-
- def __init__(self,
- dirpath,
- imsize: Tuple[int],
- load_keypoints: bool,
- transform):
- dirpath = pathlib.Path(dirpath)
- self.dirpath = dirpath
- self.transform = transform
- self.imsize = imsize[0]
- self.load_keypoints = load_keypoints
- assert self.dirpath.is_dir(),\
- f"Did not find dataset at: {dirpath}"
- image_dir = self.dirpath.joinpath("images", str(self.imsize))
- self.image_paths = list(image_dir.glob("*.png"))
- assert len(self.image_paths) > 0,\
- f"Did not find images in: {image_dir}"
- self.image_paths.sort(key=lambda x: int(x.stem))
- self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
-
- self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
- assert len(self.image_paths) == len(self.bounding_boxes)
- assert len(self.image_paths) == len(self.landmarks)
- logger.log(
- f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")
-
- def get_mask(self, idx):
- mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
- bounding_box = self.bounding_boxes[idx]
- x0, y0, x1, y1 = bounding_box
- mask[:, y0:y1, x0:x1] = 0
- return mask
-
- def __len__(self):
- return len(self.image_paths)
-
- def __getitem__(self, index):
- impath = self.image_paths[index]
- if PYSPNG_IMPORTED:
- with open(impath, "rb") as fp:
- im = pyspng.load(fp.read())
- else:
- with Image.open(impath) as fp:
- im = np.array(fp)
- im = torch.from_numpy(np.rollaxis(im, -1, 0))
- masks = self.get_mask(index)
- landmark = self.landmarks[index]
- batch = {
- "img": im,
- "mask": masks,
- }
- if self.load_keypoints:
- batch["keypoints"] = landmark
- if self.transform is None:
- return batch
- return self.transform(batch)
-
-
-class FDF256Dataset:
-
- def __init__(self,
- dirpath,
- load_keypoints: bool,
- transform):
- dirpath = pathlib.Path(dirpath)
- self.dirpath = dirpath
- self.transform = transform
- self.load_keypoints = load_keypoints
- assert self.dirpath.is_dir(),\
- f"Did not find dataset at: {dirpath}"
- image_dir = self.dirpath.joinpath("images")
- self.image_paths = list(image_dir.glob("*.png"))
- assert len(self.image_paths) > 0,\
- f"Did not find images in: {image_dir}"
- self.image_paths.sort(key=lambda x: int(x.stem))
- self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
- self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
- assert len(self.image_paths) == len(self.bounding_boxes)
- assert len(self.image_paths) == len(self.landmarks)
- logger.log(
- f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")
-
- def get_mask(self, idx):
- mask = torch.ones((1, 256, 256), dtype=torch.bool)
- bounding_box = self.bounding_boxes[idx]
- x0, y0, x1, y1 = bounding_box
- mask[:, y0:y1, x0:x1] = 0
- return mask
-
- def __len__(self):
- return len(self.image_paths)
-
- def __getitem__(self, index):
- impath = self.image_paths[index]
- if PYSPNG_IMPORTED:
- with open(impath, "rb") as fp:
- im = pyspng.load(fp.read())
- else:
- with Image.open(impath) as fp:
- im = np.array(fp)
- im = torch.from_numpy(np.rollaxis(im, -1, 0))
- masks = self.get_mask(index)
- landmark = self.landmarks[index]
- batch = {
- "img": im,
- "mask": masks,
- }
- if self.load_keypoints:
- batch["keypoints"] = landmark
- if self.transform is None:
- return batch
- return self.transform(batch)
-
diff --git a/dp2/data/datasets/fdh.py b/dp2/data/datasets/fdh.py
deleted file mode 100644
index 5eb654ba71604f1b088df586e6683418033b4b16..0000000000000000000000000000000000000000
--- a/dp2/data/datasets/fdh.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import torch
-import tops
-import numpy as np
-import io
-import webdataset as wds
-import os
-from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
-
-
-def kp_decoder(x):
- # Keypoints are between [0, 1] for webdataset
- keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
- keypoints[:, 0] /= 160
- keypoints[:, 1] /= 288
- check_outside = lambda x: (x < 0).logical_or(x > 1)
- is_outside = check_outside(keypoints[:, 0]).logical_or(
- check_outside(keypoints[:, 1])
- )
- keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
- return keypoints
-
-
-def vertices_decoder(x):
- vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
- return vertices.squeeze()[None]
-
-
-def get_dataloader_fdh_wds(
- path,
- batch_size: int,
- num_workers: int,
- transform: torch.nn.Module,
- gpu_transform: torch.nn.Module,
- infinite: bool,
- shuffle: bool,
- partial_batches: bool,
- load_embedding: bool,
- sample_shuffle=10_000,
- tar_shuffle=100,
- read_condition=False,
- channels_last=False,
- ):
- # Need to set this for split_by_node to work.
- os.environ["RANK"] = str(tops.rank())
- os.environ["WORLD_SIZE"] = str(tops.world_size())
- if infinite:
- pipeline = [wds.ResampledShards(str(path))]
- else:
- pipeline = [wds.SimpleShardList(str(path))]
- if shuffle:
- pipeline.append(wds.shuffle(tar_shuffle))
- pipeline.extend([
- wds.split_by_node,
- wds.split_by_worker,
- ])
- if shuffle:
- pipeline.append(wds.shuffle(sample_shuffle))
-
- decoder = [
- wds.handle_extension("image.png", png_decoder),
- wds.handle_extension("mask.png", mask_decoder),
- wds.handle_extension("maskrcnn_mask.png", mask_decoder),
- wds.handle_extension("keypoints.npy", kp_decoder),
- ]
-
- rename_keys = [
- ["img", "image.png"], ["mask", "mask.png"],
- ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"]
- ]
- if load_embedding:
- decoder.extend([
- wds.handle_extension("vertices.npy", vertices_decoder),
- wds.handle_extension("E_mask.png", mask_decoder)
- ])
- rename_keys.extend([
- ["vertices", "vertices.npy"],
- ["E_mask", "e_mask.png"]
- ])
-
- if read_condition:
- decoder.append(
- wds.handle_extension("condition.png", png_decoder)
- )
- rename_keys.append(["condition", "condition.png"])
-
- pipeline.extend([
- wds.tarfile_to_samples(),
- wds.decode(*decoder),
- wds.rename_keys(*rename_keys),
- wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
- ])
- if transform is not None:
- pipeline.append(wds.map(transform))
- pipeline = wds.DataPipeline(*pipeline)
- if infinite:
- pipeline = pipeline.repeat(nepochs=1000000)
-
- loader = wds.WebLoader(
- pipeline, batch_size=None, shuffle=False,
- num_workers=get_num_workers(num_workers),
- persistent_workers=True,
- )
- loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
- return loader
diff --git a/dp2/data/transforms/__init__.py b/dp2/data/transforms/__init__.py
deleted file mode 100644
index 66a9e160b6513cc81a00b0a62606721f37b70c61..0000000000000000000000000000000000000000
--- a/dp2/data/transforms/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
-from .stylegan2_transform import StyleGANAugmentPipe
\ No newline at end of file
diff --git a/dp2/data/transforms/functional.py b/dp2/data/transforms/functional.py
deleted file mode 100644
index 6d5c695944a47d67cac7e03a8a7f5b400c94417b..0000000000000000000000000000000000000000
--- a/dp2/data/transforms/functional.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import torchvision.transforms.functional as F
-import torch
-import pickle
-from tops import download_file, assert_shape
-from typing import Dict
-from functools import lru_cache
-
-global symmetry_transform
-
-@lru_cache(maxsize=1)
-def get_symmetry_transform(symmetry_url):
- file_name = download_file(symmetry_url)
- with open(file_name, "rb") as fp:
- symmetry = pickle.load(fp)
- return torch.from_numpy(symmetry["vertex_transforms"]).long()
-
-
-hflip_handled_cases = set([
- "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
- "embedding", "vertx2cat", "maskrcnn_mask", "__key__",
- "img_hr", "condition_hr", "mask_hr"])
-
-def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
- container["img"] = F.hflip(container["img"])
- if "condition" in container:
- container["condition"] = F.hflip(container["condition"])
- if "embedding" in container:
- container["embedding"] = F.hflip(container["embedding"])
- assert all([key in hflip_handled_cases for key in container]), container.keys()
- if "keypoints" in container:
- assert flip_map is not None
- if container["keypoints"].ndim == 3:
- keypoints = container["keypoints"][:, flip_map, :]
- keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
- else:
- assert_shape(container["keypoints"], (None, 3))
- keypoints = container["keypoints"][flip_map, :]
- keypoints[:, 0] = 1 - keypoints[:, 0]
- container["keypoints"] = keypoints
- if "mask" in container:
- container["mask"] = F.hflip(container["mask"])
- if "border" in container:
- container["border"] = F.hflip(container["border"])
- if "semantic_mask" in container:
- container["semantic_mask"] = F.hflip(container["semantic_mask"])
- if "vertices" in container:
- symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
- container["vertices"] = F.hflip(container["vertices"])
- symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
- container["vertices"] = symmetry_transform_[container["vertices"].long()]
- if "E_mask" in container:
- container["E_mask"] = F.hflip(container["E_mask"])
- if "maskrcnn_mask" in container:
- container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
- if "img_hr" in container:
- container["img_hr"] = F.hflip(container["img_hr"])
- if "condition_hr" in container:
- container["condition_hr"] = F.hflip(container["condition_hr"])
- if "mask_hr" in container:
- container["mask_hr"] = F.hflip(container["mask_hr"])
- return container
diff --git a/dp2/data/transforms/stylegan2_transform.py b/dp2/data/transforms/stylegan2_transform.py
deleted file mode 100644
index 49a143cddf9673d079b87ac7d725c433713e54c5..0000000000000000000000000000000000000000
--- a/dp2/data/transforms/stylegan2_transform.py
+++ /dev/null
@@ -1,394 +0,0 @@
-import numpy as np
-import scipy.signal
-import torch
-try:
- from sg3_torch_utils import misc
- from sg3_torch_utils.ops import upfirdn2d
- from sg3_torch_utils.ops import grid_sample_gradfix
- from sg3_torch_utils.ops import conv2d_gradfix
-except ImportError:
- pass
-#----------------------------------------------------------------------------
-# Coefficients of various wavelet decomposition low-pass filters.
-
-wavelets = {
- 'haar': [0.7071067811865476, 0.7071067811865476],
- 'db1': [0.7071067811865476, 0.7071067811865476],
- 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
- 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
- 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
- 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
- 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
- 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
- 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
- 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
- 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
- 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
- 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
- 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
- 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
- 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
-}
-
-#----------------------------------------------------------------------------
-# Helpers for constructing transformation matrices.
-
-
-def matrix(*rows, device=None):
- assert all(len(row) == len(rows[0]) for row in rows)
- elems = [x for row in rows for x in row]
- ref = [x for x in elems if isinstance(x, torch.Tensor)]
- if len(ref) == 0:
- return misc.constant(np.asarray(rows), device=device)
- assert device is None or device == ref[0].device
- elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
- return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
-
-
-def translate2d(tx, ty, **kwargs):
- return matrix(
- [1, 0, tx],
- [0, 1, ty],
- [0, 0, 1],
- **kwargs)
-
-
-def translate3d(tx, ty, tz, **kwargs):
- return matrix(
- [1, 0, 0, tx],
- [0, 1, 0, ty],
- [0, 0, 1, tz],
- [0, 0, 0, 1],
- **kwargs)
-
-
-def scale2d(sx, sy, **kwargs):
- return matrix(
- [sx, 0, 0],
- [0, sy, 0],
- [0, 0, 1],
- **kwargs)
-
-
-def scale3d(sx, sy, sz, **kwargs):
- return matrix(
- [sx, 0, 0, 0],
- [0, sy, 0, 0],
- [0, 0, sz, 0],
- [0, 0, 0, 1],
- **kwargs)
-
-
-def rotate2d(theta, **kwargs):
- return matrix(
- [torch.cos(theta), torch.sin(-theta), 0],
- [torch.sin(theta), torch.cos(theta), 0],
- [0, 0, 1],
- **kwargs)
-
-
-def rotate3d(v, theta, **kwargs):
- vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
- s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
- return matrix(
- [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
- [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
- [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
- [0, 0, 0, 1],
- **kwargs)
-
-
-def translate2d_inv(tx, ty, **kwargs):
- return translate2d(-tx, -ty, **kwargs)
-
-
-def scale2d_inv(sx, sy, **kwargs):
- return scale2d(1 / sx, 1 / sy, **kwargs)
-
-
-def rotate2d_inv(theta, **kwargs):
- return rotate2d(-theta, **kwargs)
-
-
-class StyleGANAugmentPipe(torch.nn.Module):
- def __init__(self,
- rotate90=0, xint=0, xint_max=0.125,
- scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
- brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
- hue_max=1, saturation_std=1,
- imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
- ):
- super().__init__()
- self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
-
- # Pixel blitting.
- self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
- self.xint = float(xint) # Probability multiplier for integer translation.
- self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
-
- # General geometric transformations.
- self.scale = float(scale) # Probability multiplier for isotropic scaling.
- self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
- self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
- self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
- self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
- self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
- self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
- self.xfrac_std = float(xfrac_std) # Standard deviation of fractional translation, relative to image dimensions.
-
- # Color transformations.
- self.brightness = float(brightness) # Probability multiplier for brightness.
- self.contrast = float(contrast) # Probability multiplier for contrast.
- self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
- self.hue = float(hue) # Probability multiplier for hue rotation.
- self.saturation = float(saturation) # Probability multiplier for saturation.
- self.brightness_std = float(brightness_std) # Standard deviation of brightness.
- self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
- self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
- self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
-
- # Image-space filtering.
- self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
- self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
- self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
-
- # Setup orthogonal lowpass filter for geometric augmentations.
- self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
-
- # Construct filter bank for image-space filtering.
- Hz_lo = np.asarray(wavelets['sym2']) # H(z)
- Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
- Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
- Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
- Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
- for i in range(1, Hz_fbank.shape[0]):
- Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
- Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
- Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
- self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
-
- def forward(self, batch, debug_percentile=None):
- images = batch["img"]
- batch["vertices"] = batch["vertices"].float()
- assert isinstance(images, torch.Tensor) and images.ndim == 4
- batch_size, num_channels, height, width = images.shape
- device = images.device
- self.Hz_fbank = self.Hz_fbank.to(device)
- self.Hz_geom = self.Hz_geom.to(device)
- if debug_percentile is not None:
- debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
-
- # -------------------------------------
- # Select parameters for pixel blitting.
- # -------------------------------------
-
- # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
- I_3 = torch.eye(3, device=device)
- G_inv = I_3
-
- # Apply integer translation with probability (xint * strength).
- if self.xint > 0:
- t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
- t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
- if debug_percentile is not None:
- t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
- G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
-
- # --------------------------------------------------------
- # Select parameters for general geometric transformations.
- # --------------------------------------------------------
-
- # Apply isotropic scaling with probability (scale * strength).
- if self.scale > 0:
- s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
- s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
- if debug_percentile is not None:
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
- G_inv = G_inv @ scale2d_inv(s, s)
-
- # Apply pre-rotation with probability p_rot.
- p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
- if self.rotate > 0:
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
- theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
- if debug_percentile is not None:
- theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
- G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
-
- # Apply anisotropic scaling with probability (aniso * strength).
- if self.aniso > 0:
- s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
- s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
- if debug_percentile is not None:
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
- G_inv = G_inv @ scale2d_inv(s, 1 / s)
-
- # Apply post-rotation with probability p_rot.
- if self.rotate > 0:
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
- theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
- if debug_percentile is not None:
- theta = torch.zeros_like(theta)
- G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
-
- # Apply fractional translation with probability (xfrac * strength).
- if self.xfrac > 0:
- t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
- t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
- if debug_percentile is not None:
- t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
- G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
-
- # ----------------------------------
- # Execute geometric transformations.
- # ----------------------------------
-
- # Execute if the transform is not identity.
- if G_inv is not I_3:
- # Calculate padding.
- cx = (width - 1) / 2
- cy = (height - 1) / 2
- cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
- cp = G_inv @ cp.t() # [batch, xyz, idx]
- Hz_pad = self.Hz_geom.shape[0] // 4
- margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
- margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
- margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
- margin = margin.max(misc.constant([0, 0] * 2, device=device))
- margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
- mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
-
- # Pad image and adjust origin.
- images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
- batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
- batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
- batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
- G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
-
- # Upsample.
- images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
- batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
- batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
- batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
- G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
- G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
-
- # Execute transformation.
- shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
- G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
- grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
- images = grid_sample_gradfix.grid_sample(images, grid)
-
- batch["mask"] = torch.nn.functional.grid_sample(
- input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
- batch["E_mask"] = torch.nn.functional.grid_sample(
- input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
- batch["vertices"] = torch.nn.functional.grid_sample(
- input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
-
-
- # Downsample and crop.
- images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
- batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
- batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
- batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
- # --------------------------------------------
- # Select parameters for color transformations.
- # --------------------------------------------
-
- # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
- I_4 = torch.eye(4, device=device)
- C = I_4
-
- # Apply brightness with probability (brightness * strength).
- if self.brightness > 0:
- b = torch.randn([batch_size], device=device) * self.brightness_std
- b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
- if debug_percentile is not None:
- b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
- C = translate3d(b, b, b) @ C
-
- # Apply contrast with probability (contrast * strength).
- if self.contrast > 0:
- c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
- c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
- if debug_percentile is not None:
- c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
- C = scale3d(c, c, c) @ C
-
- # Apply luma flip with probability (lumaflip * strength).
- v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
-
- # Apply hue rotation with probability (hue * strength).
- if self.hue > 0 and num_channels > 1:
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
- theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
- if debug_percentile is not None:
- theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
- C = rotate3d(v, theta) @ C # Rotate around v.
-
- # Apply saturation with probability (saturation * strength).
- if self.saturation > 0 and num_channels > 1:
- s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
- s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
- if debug_percentile is not None:
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
- C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
-
- # ------------------------------
- # Execute color transformations.
- # ------------------------------
-
- # Execute if the transform is not identity.
- if C is not I_4:
- images = images.reshape([batch_size, num_channels, height * width])
- if num_channels == 3:
- images = C[:, :3, :3] @ images + C[:, :3, 3:]
- elif num_channels == 1:
- C = C[:, :3, :].mean(dim=1, keepdims=True)
- images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
- else:
- raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
- images = images.reshape([batch_size, num_channels, height, width])
-
- # ----------------------
- # Image-space filtering.
- # ----------------------
-
- if self.imgfilter > 0:
- num_bands = self.Hz_fbank.shape[0]
- assert len(self.imgfilter_bands) == num_bands
- expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
-
- # Apply amplification for each band with probability (imgfilter * strength * band_strength).
- g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
- for i, band_strength in enumerate(self.imgfilter_bands):
- t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
- t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
- if debug_percentile is not None:
- t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
- t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
- t[:, i] = t_i # Replace i'th element.
- t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
- g = g * t # Accumulate into global gain.
-
- # Construct combined amplification filter.
- Hz_prime = g @ self.Hz_fbank # [batch, tap]
- Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
- Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
-
- # Apply filter.
- p = self.Hz_fbank.shape[1] // 2
- images = images.reshape([1, batch_size * num_channels, height, width])
- images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
- images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
- images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
- images = images.reshape([batch_size, num_channels, height, width])
-
- # ------------------------
- # Image-space corruptions.
- # ------------------------
- batch["img"] = images
- batch["vertices"] = batch["vertices"].long()
- batch["border"] = 1 - batch["E_mask"] - batch["mask"]
- return batch
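The pipe gates every augmentation on the buffer `p` ("overall multiplier for augmentation probability"). How `p` is driven during training is not part of this file; the sketch below shows one common ADA-style adjustment, where the heuristic, target and step size are assumptions rather than dp2 code:

```python
import torch

@torch.no_grad()
def adjust_augment_p(pipe: torch.nn.Module, real_logits_sign: float,
                     target: float = 0.6, step: float = 1e-3) -> None:
    # `pipe.p` is assumed to be the registered 0-dim buffer above.
    # Raise augmentation strength when the discriminator looks overconfident
    # on real images, lower it otherwise; keep p in [0, 1].
    delta = step if real_logits_sign > target else -step
    pipe.p.copy_((pipe.p + delta).clamp(0.0, 1.0))
```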
diff --git a/dp2/data/transforms/transforms.py b/dp2/data/transforms/transforms.py
deleted file mode 100644
index 1221a9121d7f59b1ac33c28e4189339e8df6dadf..0000000000000000000000000000000000000000
--- a/dp2/data/transforms/transforms.py
+++ /dev/null
@@ -1,247 +0,0 @@
-from pathlib import Path
-from typing import Dict, List
-import torchvision
-import torch
-import tops
-import torchvision.transforms.functional as F
-from .functional import hflip
-
-
-class RandomHorizontalFlip(torch.nn.Module):
-
- def __init__(self, p: float, flip_map=None,**kwargs):
- super().__init__()
- self.flip_ratio = p
- self.flip_map = flip_map
- if self.flip_ratio is None:
- self.flip_ratio = 0.5
- assert 0 <= self.flip_ratio <= 1
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- if torch.rand(1) > self.flip_ratio:
- return container
- return hflip(container, self.flip_map)
-
-
-class CenterCrop(torch.nn.Module):
- """
- Performs the transform on the image.
- NOTE: Does not transform the mask to improve runtime.
- """
-
- def __init__(self, size: List[int]):
- super().__init__()
- self.size = tuple(size)
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- min_size = min(container["img"].shape[1], container["img"].shape[2])
- if min_size < self.size[0]:
- container["img"] = F.center_crop(container["img"], min_size)
- container["img"] = F.resize(container["img"], self.size)
- return container
- container["img"] = F.center_crop(container["img"], self.size)
- return container
-
-
-class Resize(torch.nn.Module):
- """
- Resizes the image and, when present, auxiliary maps
- (semantic_mask, embedding, mask, E_mask, maskrcnn_mask, vertices).
- """
-
- def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
- super().__init__()
- self.size = tuple(size)
- self.interpolation = interpolation
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
- if "semantic_mask" in container:
- container["semantic_mask"] = F.resize(
- container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
- if "embedding" in container:
- container["embedding"] = F.resize(
- container["embedding"], self.size, self.interpolation)
- if "mask" in container:
- container["mask"] = F.resize(
- container["mask"], self.size, F.InterpolationMode.NEAREST)
- if "E_mask" in container:
- container["E_mask"] = F.resize(
- container["E_mask"], self.size, F.InterpolationMode.NEAREST)
- if "maskrcnn_mask" in container:
- container["maskrcnn_mask"] = F.resize(
- container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
- if "vertices" in container:
- container["vertices"] = F.resize(
- container["vertices"], self.size, F.InterpolationMode.NEAREST)
- return container
-
- def __repr__(self):
- repr = super().__repr__()
- vars_ = dict(size=self.size, interpolation=self.interpolation)
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
-
-
-class InsertHRImage(torch.nn.Module):
- """
- Resizes mask by maxpool and assumes condition is already created
- """
- def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
- super().__init__()
- self.size = tuple(size)
- self.interpolation = interpolation
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- assert container["img"].dtype == torch.float32
- container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
- container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
- mask = container["mask"] > 0
- container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
- container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"]
- return container
-
- def __repr__(self):
- repr = super().__repr__()
- vars_ = dict(size=self.size, interpolation=self.interpolation)
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
-
-
-class CopyHRImage(torch.nn.Module):
- def __init__(self) -> None:
- super().__init__()
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- container["img_hr"] = container["img"]
- container["condition_hr"] = container["condition"]
- container["mask_hr"] = container["mask"]
- return container
-
-
-class Resize2(torch.nn.Module):
- """
- Resizes mask by maxpool and assumes condition is already created
- """
- def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True):
- super().__init__()
- self.size = tuple(size)
- self.interpolation = interpolation
- self.downsample_condition = downsample_condition
- self.mask_condition = mask_condition
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-# assert container["img"].dtype == torch.float32
- container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
- mask = container["mask"] > 0
- container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
-
- if self.downsample_condition:
- container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
- if self.mask_condition:
- container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"]
- return container
-
- def __repr__(self):
- repr = super().__repr__()
- vars_ = dict(size=self.size, interpolation=self.interpolation)
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
-
-
-
-class Normalize(torch.nn.Module):
- """
- Performs the transform on the image.
- NOTE: Does not transform the mask to improve runtime.
- """
-
- def __init__(self, mean, std, inplace, keys=["img"]):
- super().__init__()
- self.mean = mean
- self.std = std
- self.inplace = inplace
- self.keys = keys
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- for key in self.keys:
- container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
- return container
-
- def __repr__(self):
- repr = super().__repr__()
- vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
-
-
-class ToFloat(torch.nn.Module):
-
- def __init__(self, keys=["img"], norm=True) -> None:
- super().__init__()
- self.keys = keys
- self.gain = 255 if norm else 1
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- for key in self.keys:
- container[key] = container[key].float() / self.gain
- return container
-
-
-class RandomCrop(torchvision.transforms.RandomCrop):
- """
- Performs the transform on the image.
- NOTE: Does not transform the mask to improve runtime.
- """
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- container["img"] = super().forward(container["img"])
- return container
-
-
-class CreateCondition(torch.nn.Module):
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- if container["img"].dtype == torch.uint8:
- container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
- return container
- container["condition"] = container["img"] * container["mask"]
- return container
-
-
-class CreateEmbedding(torch.nn.Module):
-
- def __init__(self, embed_path: Path, cuda=True) -> None:
- super().__init__()
- self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
- if cuda:
- self.embed_map = tops.to_cuda(self.embed_map)
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- vertices = container["vertices"]
- if vertices.ndim == 3:
- embedding = self.embed_map[vertices.long()].squeeze(dim=0)
- embedding = embedding.permute(2, 0, 1) * container["E_mask"]
- pass
- else:
- assert vertices.ndim == 4
- embedding = self.embed_map[vertices.long()].squeeze(dim=1)
- embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
- container["embedding"] = embedding
- container["embed_map"] = self.embed_map.clone()
- return container
-
-
-class UpdateMask(torch.nn.Module):
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float()
- return container
-
-
-class LoadClassEmbedding(torch.nn.Module):
-
- def __init__(self, embedding_path: Path) -> None:
- super().__init__()
- self.embedding = torch.load(embedding_path, map_location="cpu")
-
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
- key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1])
- container["class_embedding"] = self.embedding[key].view(-1)
- return container
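All of these transforms share one convention: each is an `nn.Module` that receives a dict (the "container") of tensors and returns it with keys added or modified, so they can be chained freely. A toy sketch of the uint8 branch of `CreateCondition` under that convention (shapes and values are illustrative):

```python
import torch

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
mask = torch.ones(1, 64, 64)          # 1 = keep pixel, 0 = region to anonymize
mask[:, 16:48, 16:48] = 0
container = {"img": img, "mask": mask}

# uint8 branch of CreateCondition: masked-out pixels are filled with gray (127).
m = container["mask"].byte()
container["condition"] = container["img"] * m + (1 - m) * 127
```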
diff --git a/dp2/data/utils.py b/dp2/data/utils.py
deleted file mode 100644
index 1ec03f0ec0091e3263f5aa9b2962ad97a1c36a0f..0000000000000000000000000000000000000000
--- a/dp2/data/utils.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-from PIL import Image
-import numpy as np
-import multiprocessing
-import io
-from tops import logger
-from torch.utils.data._utils.collate import default_collate
-
-try:
- import pyspng
-
- PYSPNG_IMPORTED = True
-except ImportError:
- PYSPNG_IMPORTED = False
- print("Could not load pyspng. Defaulting to pillow image backend.")
- from PIL import Image
-
-
-def get_coco_keypoints():
- return [
- "nose",
- "left_eye",
- "right_eye",
- "left_ear",
- "right_ear",
- "left_shoulder",
- "right_shoulder",
- "left_elbow",
- "right_elbow",
- "left_wrist",
- "right_wrist",
- "left_hip",
- "right_hip",
- "left_knee",
- "right_knee",
- "left_ankle",
- "right_ankle",
- ]
-
-
-def get_coco_flipmap():
- keypoints = get_coco_keypoints()
- keypoint_flip_map = {
- "left_eye": "right_eye",
- "left_ear": "right_ear",
- "left_shoulder": "right_shoulder",
- "left_elbow": "right_elbow",
- "left_wrist": "right_wrist",
- "left_hip": "right_hip",
- "left_knee": "right_knee",
- "left_ankle": "right_ankle",
- }
- for key, value in list(keypoint_flip_map.items()):
- keypoint_flip_map[value] = key
- keypoint_flip_map["nose"] = "nose"
- keypoint_flip_map_idx = []
- for source in keypoints:
- keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
- return keypoint_flip_map_idx
-
-
-def mask_decoder(x):
- mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
- mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
- return mask
-
-
-def png_decoder(x):
- if PYSPNG_IMPORTED:
- return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
- with Image.open(io.BytesIO(x)) as im:
- im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
- return im
-
-
-def jpg_decoder(x):
- with Image.open(io.BytesIO(x)) as im:
- im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
- return im
-
-
-def get_num_workers(num_workers: int):
- n_cpus = multiprocessing.cpu_count()
- if num_workers > n_cpus:
- logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
- return n_cpus
- return num_workers
-
-
-def collate_fn(batch):
- elem = batch[0]
- ignore_keys = set(["embed_map", "vertx2cat"])
- batch_ = {
- key: default_collate([d[key] for d in batch])
- for key in elem
- if key not in ignore_keys
- }
- if "embed_map" in elem:
- batch_["embed_map"] = elem["embed_map"]
- if "vertx2cat" in elem:
- batch_["vertx2cat"] = elem["vertx2cat"]
- return batch_
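`collate_fn` batches per-sample keys with `default_collate` but passes sample-independent tensors (`embed_map`, `vertx2cat`) through unstacked. A small sketch of the effect, with an assumed embedding-map shape:

```python
import torch
from torch.utils.data._utils.collate import default_collate

embed_map = torch.randn(27554, 16)           # shared across the whole dataset
batch = [
    {"img": torch.zeros(3, 8, 8), "embed_map": embed_map},
    {"img": torch.ones(3, 8, 8), "embed_map": embed_map},
]
out = {"img": default_collate([d["img"] for d in batch])}
out["embed_map"] = batch[0]["embed_map"]     # kept as-is, not stacked
print(out["img"].shape, out["embed_map"].shape)  # [2, 3, 8, 8] and [27554, 16]
```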
diff --git a/dp2/detection/__init__.py b/dp2/detection/__init__.py
deleted file mode 100644
index 613969b28384cd1c64fc8db685e7622f4cc02615..0000000000000000000000000000000000000000
--- a/dp2/detection/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .cse_mask_face_detector import CSeMaskFaceDetector
-from .person_detector import CSEPersonDetector
-from .structures import PersonDetection, VehicleDetection, FaceDetection
diff --git a/dp2/detection/base.py b/dp2/detection/base.py
deleted file mode 100644
index 32ebba893878e95ddc1da451a9f2027799ffa044..0000000000000000000000000000000000000000
--- a/dp2/detection/base.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import pickle
-import torch
-import lzma
-from pathlib import Path
-from tops import logger
-
-
-class BaseDetector:
-
-
- def __init__(self, cache_directory: str) -> None:
- if cache_directory is not None:
- self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
- self.cache_directory.mkdir(exist_ok=True, parents=True)
-
- def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
- logger.log(f"Caching detection to: {cache_path}")
- with lzma.open(cache_path, "wb") as fp:
- torch.save(
- [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
- pickle_protocol=pickle.HIGHEST_PROTOCOL)
-
- def load_from_cache(self, cache_path: Path):
- logger.log(f"Loading detection from cache path: {cache_path}")
- with lzma.open(cache_path, "rb") as fp:
- state_dict = torch.load(fp)
- return [
- state["cls"].from_state_dict(state_dict=state) for state in state_dict
- ]
-
- def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
- if cache_id is None:
- return self.forward(im)
- cache_path = self.cache_directory.joinpath(cache_id + ".torch")
- if cache_path.is_file() and load_cache:
- try:
- return self.load_from_cache(cache_path)
- except Exception:
- logger.warn(f"The cache file was corrupted: {cache_path}")
- exit()
- detections = self.forward(im)
- self.save_to_cache(detections, cache_path)
- return detections
-
-
\ No newline at end of file
diff --git a/dp2/detection/box_utils.py b/dp2/detection/box_utils.py
deleted file mode 100644
index 6091b122a1e72d05b9cad2a25f4111b425eabb93..0000000000000000000000000000000000000000
--- a/dp2/detection/box_utils.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import numpy as np
-
-
-def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
- x0, y0, x1, y1 = [int(_) for _ in bbox]
- h, w = y1 - y0, x1 - x0
- cur_ratio = h / w
-
- if cur_ratio == target_aspect_ratio:
- return [x0, y0, x1, y1]
- if cur_ratio < target_aspect_ratio:
- target_height = int(w*target_aspect_ratio)
- y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
- else:
- target_width = int(h/target_aspect_ratio)
- x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
- return x0, y0, x1, y1
-
-
-def expand_axis(start, end, target_width, limit):
- # Can return a bbox outside of limit
- cur_width = end - start
- start = start - (target_width-cur_width)//2
- end = end + (target_width-cur_width)//2
- if end - start != target_width:
- end += 1
- assert end - start == target_width
- if start < 0 and end > limit:
- return start, end
- if start < 0 and end < limit:
- to_shift = min(0 - start, limit - end)
- start += to_shift
- end += to_shift
- if end > limit and start > 0:
- to_shift = min(end - limit, start)
- end -= to_shift
- start -= to_shift
- assert end - start == target_width
- return start, end
-
-
-def expand_box(bbox, imshape, mask, percentage_background: float):
- assert isinstance(bbox[0], int)
- assert 0 < percentage_background < 1
- # Fraction of the box currently covered by the mask
- mask_pixels = mask.long().sum().cpu()
- total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
- percentage_mask = mask_pixels / total_pixels
- if (1 - percentage_mask) > percentage_background:
- return bbox
- target_pixels = mask_pixels / (1 - percentage_background)
- x0, y0, x1, y1 = bbox
- H = y1 - y0
- W = x1 - x0
- p = np.sqrt(target_pixels/(H*W))
- target_width = int(np.ceil(p * W))
- target_height = int(np.ceil(p * H))
- x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
- y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
- return [x0, y0, x1, y1]
-
-
-def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
- x0, y0, x1, y1 = bbox_XYXY
- H = y1 - y0
- W = x1 - x0
- expansion = int(((H*W)**0.5) * percentage)
- new_width = W + expansion
- new_height = H + expansion
- x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
- y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
- return [x0, y0, x1, y1]
-
-
-def get_expanded_bbox(
- bbox_XYXY,
- imshape,
- mask,
- percentage_background: float,
- axis_minimum_expansion: float,
- target_aspect_ratio: float):
- bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
- # Expand each axis of the bounding box by a minimum percentage
- bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
- # Find the minimum bbox with the aspect ratio. Can be outside of imshape
- bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
- # Expands square box such that X% of the bbox is background
- bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
- assert isinstance(bbox_XYXY[0], (int, np.int64))
- return bbox_XYXY
-
-
-def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
- def area_inside_ratio(bbox, imshape):
- area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
- area_inside = (min(bbox[2], imshape[1]) - max(0,bbox[0])) * (min(imshape[0],bbox[3]) - max(0,bbox[1]))
- return area_inside / area
- ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
- area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
- if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
- return False
- if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
- return False
- return True
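`get_expanded_bbox` chains three steps: a minimum percentage expansion of both axes, growth of the shorter axis until the target aspect ratio is reached, and further growth until the requested fraction of the box is background. A stripped-down sketch of the aspect-ratio step only, ignoring image-boundary handling:

```python
def expand_to_ratio(x0, y0, x1, y1, target_ratio=2.0):
    # Grow the shorter axis symmetrically until height / width == target_ratio.
    h, w = y1 - y0, x1 - x0
    if h / w < target_ratio:            # too wide: grow height
        grow = int(w * target_ratio) - h
        y0, y1 = y0 - grow // 2, y1 + (grow - grow // 2)
    elif h / w > target_ratio:          # too tall: grow width
        grow = int(h / target_ratio) - w
        x0, x1 = x0 - grow // 2, x1 + (grow - grow // 2)
    return x0, y0, x1, y1

print(expand_to_ratio(10, 10, 20, 20))  # (10, 5, 20, 25): height/width is now 2
```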
diff --git a/dp2/detection/box_utils_fdf.py b/dp2/detection/box_utils_fdf.py
deleted file mode 100644
index 48e4e8c6ef067eb495ff8a021d2d236606e2e906..0000000000000000000000000000000000000000
--- a/dp2/detection/box_utils_fdf.py
+++ /dev/null
@@ -1,203 +0,0 @@
-"""
-The FDF dataset expands bounding boxes differently from the expansion used for CSE.
-"""
-
-import numpy as np
-
-
-def quadratic_bounding_box(x0, y0, width, height, imshape):
- # We assume that we can create a square (quadratic) crop without
- # shrinking either of the sides
- assert width <= min(imshape[:2])
- assert height <= min(imshape[:2])
- min_side = min(height, width)
- if height != width:
- side_diff = abs(height - width)
- # Want to extend the shortest side
- if min_side == height:
- # Vertical side
- height += side_diff
- if height > imshape[0]:
- # Take full frame, and shrink width
- y0 = 0
- height = imshape[0]
-
- side_diff = abs(height - width)
- width -= side_diff
- x0 += side_diff // 2
- else:
- y0 -= side_diff // 2
- y0 = max(0, y0)
- else:
- # Horizontal side
- width += side_diff
- if width > imshape[1]:
- # Take full frame width, and shrink height
- x0 = 0
- width = imshape[1]
-
- side_diff = abs(height - width)
- height -= side_diff
- y0 += side_diff // 2
- else:
- x0 -= side_diff // 2
- x0 = max(0, x0)
- # Shift the bbox back if it extends beyond the image border
- x1 = x0 + width
- y1 = y0 + height
- if imshape[1] < x1:
- diff = x1 - imshape[1]
- x0 -= diff
- if imshape[0] < y1:
- diff = y1 - imshape[0]
- y0 -= diff
- assert x0 >= 0, "Bounding box outside image."
- assert y0 >= 0, "Bounding box outside image."
- assert x0 + width <= imshape[1], "Bounding box outside image."
- assert y0 + height <= imshape[0], "Bounding box outside image."
- return x0, y0, width, height
-
-
-def expand_bounding_box(bbox, percentage, imshape):
- orig_bbox = bbox.copy()
- x0, y0, x1, y1 = bbox
- width = x1 - x0
- height = y1 - y0
- x0, y0, width, height = quadratic_bounding_box(
- x0, y0, width, height, imshape)
- expanding_factor = int(max(height, width) * percentage)
-
- possible_max_expansion = [(imshape[0] - width) // 2,
- (imshape[1] - height) // 2,
- expanding_factor]
-
- expanding_factor = min(possible_max_expansion)
- # Expand height
-
- if expanding_factor > 0:
-
- y0 = y0 - expanding_factor
- y0 = max(0, y0)
-
- height += expanding_factor * 2
- if height > imshape[0]:
- y0 -= (imshape[0] - height)
- height = imshape[0]
-
- if height + y0 > imshape[0]:
- y0 -= (height + y0 - imshape[0])
-
- # Expand width
- x0 = x0 - expanding_factor
- x0 = max(0, x0)
-
- width += expanding_factor * 2
- if width > imshape[1]:
- x0 -= (imshape[1] - width)
- width = imshape[1]
-
- if width + x0 > imshape[1]:
- x0 -= (width + x0 - imshape[1])
- y1 = y0 + height
- x1 = x0 + width
- assert y0 >= 0, "y0 is negative"
- assert height <= imshape[0], "Height is larger than image."
- assert x0 + width <= imshape[1]
- assert y0 + height <= imshape[0]
- assert width == height, "Height is not equal to width."
- assert x0 >= 0, "x0 is negative"
- assert width <= imshape[1], "Width is larger than image."
- # Check that original bbox is within new
- x0_o, y0_o, x1_o, y1_o = orig_bbox
- assert x0 <= x0_o, f"New bbox is outside of original. O:{x0_o}, N: {x0}"
- assert x1 >= x1_o, f"New bbox is outside of original. O:{x1_o}, N: {x1}"
- assert y0 <= y0_o, f"New bbox is outside of original. O:{y0_o}, N: {y0}"
- assert y1 >= y1_o, f"New bbox is outside of original. O:{y1_o}, N: {y1}"
-
- x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
- x1 = x0 + width
- y1 = y0 + height
- return np.array([x0, y0, x1, y1])
-
-
-def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
- keypoint = keypoint[:, :3] # only nose + eyes are relevant
- kp_X = keypoint[0, :]
- kp_Y = keypoint[1, :]
- within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
- within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
- return within_X and within_Y
-
-
-def expand_bbox_simple(bbox, percentage):
- x0, y0, x1, y1 = bbox.astype(float)
- width = x1 - x0
- height = y1 - y0
- x_c = int(x0) + width // 2
- y_c = int(y0) + height // 2
- avg_size = max(width, height)
- new_width = avg_size * (1 + percentage)
- x0 = x_c - new_width // 2
- y0 = y_c - new_width // 2
- x1 = x_c + new_width // 2
- y1 = y_c + new_width // 2
- return np.array([x0, y0, x1, y1]).astype(int)
-
-
-def pad_image(im, bbox, pad_value):
- x0, y0, x1, y1 = bbox
- if x0 < 0:
- pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
- dtype=np.uint8) + pad_value
- im = np.concatenate((pad_im, im), axis=1)
- x1 += abs(x0)
- x0 = 0
- if y0 < 0:
- pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
- dtype=np.uint8) + pad_value
- im = np.concatenate((pad_im, im), axis=0)
- y1 += abs(y0)
- y0 = 0
- if x1 >= im.shape[1]:
- pad_im = np.zeros(
- (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
- dtype=np.uint8) + pad_value
- im = np.concatenate((im, pad_im), axis=1)
- if y1 >= im.shape[0]:
- pad_im = np.zeros(
- (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
- dtype=np.uint8) + pad_value
- im = np.concatenate((im, pad_im), axis=0)
- return im[y0:y1, x0:x1]
-
-
-def clip_box(bbox, im):
- bbox[0] = max(0, bbox[0])
- bbox[1] = max(0, bbox[1])
- bbox[2] = min(im.shape[1] - 1, bbox[2])
- bbox[3] = min(im.shape[0] - 1, bbox[3])
- return bbox
-
-
-def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
- outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
- if simple_expand or (outside_im and pad_im):
- return pad_image(im, bbox, pad_value)
- bbox = clip_box(bbox, im)
- x0, y0, x1, y1 = bbox
- return im[y0:y1, x0:x1]
-
-
-def expand_bbox(
- bbox_ltrb, imshape, simple_expand, default_to_simple=False,
- expansion_factor=0.35):
- assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
- bbox = bbox_ltrb.astype(float)
- # FDF256 uses simple expand with ratio 0.4
- if simple_expand:
- return expand_bbox_simple(bbox, 0.4)
- try:
- return expand_bounding_box(bbox, expansion_factor, imshape)
- except AssertionError:
- return expand_bbox_simple(bbox, expansion_factor * 2)
-
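`expand_bbox` prefers the strict square expansion and falls back to `expand_bbox_simple` when its assertions fail (e.g. near image borders). A condensed sketch of the simple variant; rounding is handled slightly differently than in the deleted code:

```python
import numpy as np

def expand_bbox_simple(bbox: np.ndarray, percentage: float) -> np.ndarray:
    # Square box centred on the original, side = max(w, h) * (1 + percentage).
    x0, y0, x1, y1 = bbox.astype(float)
    cx, cy = x0 + (x1 - x0) / 2, y0 + (y1 - y0) / 2
    half = max(x1 - x0, y1 - y0) * (1 + percentage) / 2
    return np.array([cx - half, cy - half, cx + half, cy + half]).astype(int)

print(expand_bbox_simple(np.array([40, 30, 80, 90]), 0.4))  # [ 18  18 102 102]
```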
diff --git a/dp2/detection/cse_mask_face_detector.py b/dp2/detection/cse_mask_face_detector.py
deleted file mode 100644
index 74a8cf43eb35516e5e2c2c3354e15fc50ea88016..0000000000000000000000000000000000000000
--- a/dp2/detection/cse_mask_face_detector.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import torch
-import lzma
-import tops
-from pathlib import Path
-from dp2.detection.base import BaseDetector
-from .utils import combine_cse_maskrcnn_dets
-from face_detection import build_detector as build_face_detector
-from .models.cse import CSEDetector
-from .models.mask_rcnn import MaskRCNNDetector
-from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
-from tops import logger
-
-
-def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
- assert len(box1.shape) == 2
- assert len(box2.shape) == 2
- box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
- # This can be batched
- for i, box in enumerate(box1):
- is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
- is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
- is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
- box1_inside[i] = is_outside.logical_not().any()
- return box1_inside
-
-
-class CSeMaskFaceDetector(BaseDetector):
-
- def __init__(
- self,
- mask_rcnn_cfg,
- face_detector_cfg: dict,
- cse_cfg: dict,
- face_post_process_cfg: dict,
- cse_post_process_cfg,
- score_threshold: float,
- **kwargs
- ) -> None:
- super().__init__(**kwargs)
- self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
- if "confidence_threshold" not in face_detector_cfg:
- face_detector_cfg["confidence_threshold"] = score_threshold
- if "score_thres" not in cse_cfg:
- cse_cfg["score_thres"] = score_threshold
- self.cse_detector = CSEDetector(**cse_cfg)
- self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
- self.cse_post_process_cfg = cse_post_process_cfg
- self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
- self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
- self.face_post_process_cfg = face_post_process_cfg
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def _detect_faces(self, im: torch.Tensor):
- H, W = im.shape[1:]
- im = im.float() - self.face_mean
- im = self.face_detector.resize(im[None], 1.0)
- boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
- boxes_XYXY[:, [0, 2]] *= W
- boxes_XYXY[:, [1, 3]] *= H
- return boxes_XYXY.round().long()
-
- def load_from_cache(self, cache_path: Path):
- logger.log(f"Loading detection from cache path: {cache_path}",)
- with lzma.open(cache_path, "rb") as fp:
- state_dict = torch.load(fp, map_location="cpu")
- kwargs = dict(
- post_process_cfg=self.cse_post_process_cfg,
- embed_map=self.cse_detector.embed_map,
- **self.face_post_process_cfg
- )
- return [
- state["cls"].from_state_dict(**kwargs, state_dict=state)
- for state in state_dict
- ]
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor):
- maskrcnn_dets = self.mask_rcnn(im)
- cse_dets = self.cse_detector(im)
- embed_map = self.cse_detector.embed_map
- print("Calling face detector.")
- face_boxes = self._detect_faces(im).cpu()
- maskrcnn_person = {
- k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
- }
- maskrcnn_other = {
- k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
- }
- maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
- combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
- maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
-
- persons_with_cse = CSEPersonDetection(
- combined_segmentation, cse_dets, **self.cse_post_process_cfg,
- embed_map=embed_map,orig_imshape_CHW=im.shape
- )
- persons_with_cse.pre_process()
- not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
- persons_without_cse = PersonDetection(
- maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
- orig_imshape_CHW=im.shape
- )
- persons_without_cse.pre_process()
-
- face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
- box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
- )
- face_boxes = face_boxes[face_boxes_covered.logical_not()]
- face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
-
- # Order matters. The anonymizer will anonymize FIFO.
- # Later detections will overwrite.
- all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
- return all_detections
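Face boxes are only discarded when they lie strictly inside a dilated person box, which is what `box1_inside_box2` tests. A quick standalone check of the same logic on toy boxes:

```python
import torch

faces = torch.tensor([[12, 12, 20, 20], [0, 0, 5, 5]])
persons = torch.tensor([[10, 10, 30, 30]])
inside = torch.zeros(len(faces), dtype=torch.bool)
for i, box in enumerate(faces):
    outside_lefttop = (box[None, :2] <= persons[:, :2]).any(dim=1)
    outside_rightbot = (box[None, 2:] >= persons[:, 2:]).any(dim=1)
    inside[i] = (outside_lefttop | outside_rightbot).logical_not().any()
print(inside)  # tensor([ True, False])
```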
diff --git a/dp2/detection/face_detector.py b/dp2/detection/face_detector.py
deleted file mode 100644
index b05565bc3bc095edf1760c24c4238fe20b9962dc..0000000000000000000000000000000000000000
--- a/dp2/detection/face_detector.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import torch
-import lzma
-import tops
-from pathlib import Path
-from dp2.detection.base import BaseDetector
-from face_detection import build_detector as build_face_detector
-from .structures import FaceDetection
-from tops import logger
-
-
-def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
- assert len(box1.shape) == 2
- assert len(box2.shape) == 2
- box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
- # This can be batched
- for i, box in enumerate(box1):
- is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
- is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
- is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
- box1_inside[i] = is_outside.logical_not().any()
- return box1_inside
-
-
-class FaceDetector(BaseDetector):
-
- def __init__(
- self,
- face_detector_cfg: dict,
- score_threshold: float,
- face_post_process_cfg: dict,
- **kwargs
- ) -> None:
- super().__init__(**kwargs)
- self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
- self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
- self.face_post_process_cfg = face_post_process_cfg
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def _detect_faces(self, im: torch.Tensor):
- H, W = im.shape[1:]
- im = im.float() - self.face_mean
- im = self.face_detector.resize(im[None], 1.0)
- boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
- boxes_XYXY[:, [0, 2]] *= W
- boxes_XYXY[:, [1, 3]] *= H
- return boxes_XYXY.round().long().cpu()
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor):
- face_boxes = self._detect_faces(im)
- face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
- return [face_boxes]
-
- def load_from_cache(self, cache_path: Path):
- logger.log(f"Loading detection from cache path: {cache_path}")
- with lzma.open(cache_path, "rb") as fp:
- state_dict = torch.load(fp)
- return [
- state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
- ]
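`_detect_faces` receives boxes in normalized [0, 1] coordinates from `_batched_detect` and scales them back to pixel coordinates before rounding. The scaling step in isolation, with made-up numbers:

```python
import torch

boxes_norm = torch.tensor([[0.25, 0.10, 0.50, 0.40]])  # x0, y0, x1, y1 in [0, 1]
H, W = 480, 640
boxes_px = boxes_norm.clone()
boxes_px[:, [0, 2]] *= W
boxes_px[:, [1, 3]] *= H
print(boxes_px.round().long())  # tensor([[160,  48, 320, 192]])
```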
diff --git a/dp2/detection/models/__init__.py b/dp2/detection/models/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/dp2/detection/models/cse.py b/dp2/detection/models/cse.py
deleted file mode 100644
index fe6b0f6c876ed86d542604ea0f4d7274afe31584..0000000000000000000000000000000000000000
--- a/dp2/detection/models/cse.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import torch
-from typing import List
-import tops
-from torchvision.transforms.functional import InterpolationMode, resize
-from densepose.data.utils import get_class_to_mesh_name_mapping
-from densepose import add_densepose_config
-from densepose.structures import DensePoseEmbeddingPredictorOutput
-from densepose.vis.extractor import DensePoseOutputsExtractor
-from densepose.modeling import build_densepose_embedder
-from detectron2.config import get_cfg
-from detectron2.data.transforms import ResizeShortestEdge
-from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
-from detectron2.modeling import build_model
-
-
-model_urls = {
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
-}
-
-
-def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
- assert len(S.shape) == 3
- H, W = imshape
- N = len(boxes_XYXY)
- segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
- boxes_XYXY = boxes_XYXY.long()
- for i in range(N):
- x0, y0, x1, y1 = boxes_XYXY[i]
- assert x0 >= 0 and y0 >= 0
- assert x1 <= imshape[1]
- assert y1 <= imshape[0]
- h = y1 - y0
- w = x1 - x0
- segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
- return segmentation
-
-
-class CSEDetector:
-
- def __init__(
- self,
- cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
- cfg_2_download: List[str] = [
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
- score_thres: float = 0.9,
- nms_thresh: float = None,
- ) -> None:
- with tops.logger.capture_log_stdout():
- cfg = get_cfg()
- self.device = tops.get_device()
- add_densepose_config(cfg)
- cfg_path = tops.download_file(cfg_url)
- for p in cfg_2_download:
- tops.download_file(p)
- with tops.logger.capture_log_stdout():
- cfg.merge_from_file(cfg_path)
- assert cfg_url in model_urls, cfg_url
- model_path = tops.download_file(model_urls[cfg_url])
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
- if nms_thresh is not None:
- cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
- cfg.MODEL.WEIGHTS = str(model_path)
- cfg.MODEL.DEVICE = str(self.device)
- cfg.freeze()
- with tops.logger.capture_log_stdout():
- self.model = build_model(cfg)
- self.model.eval()
- DetectionCheckpointer(self.model).load(str(model_path))
- self.input_format = cfg.INPUT.FORMAT
- self.densepose_extractor = DensePoseOutputsExtractor()
- self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
-
- self.embedder = build_densepose_embedder(cfg)
- self.mesh_vertex_embeddings = {
- mesh_name: self.embedder(mesh_name).to(self.device)
- for mesh_name in self.class_to_mesh_name.values()
- if self.embedder.has_embeddings(mesh_name)
- }
- self.cfg = cfg
- self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
- tops.logger.log("CSEDetector built.")
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def resize_im(self, im):
- H, W = im.shape[1:]
- newH, newW = ResizeShortestEdge.get_output_shape(
- H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
- return resize(
- im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)
-
- @torch.no_grad()
- def forward(self, im):
- assert im.dtype == torch.uint8
- if self.input_format == "BGR":
- im = im.flip(0)
- H, W = im.shape[1:]
- im = self.resize_im(im)
- output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
- scores = output.get("scores")
- if len(scores) == 0:
- return dict(
- instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
- instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
- embed_map=self.mesh_vertex_embeddings["smpl_27554"],
- bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
- im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
- scores=torch.empty((0), dtype=torch.float, device=im.device)
- )
- pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
- assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
- S = pred_densepose.coarse_segm.argmax(dim=1) # coarse_segm has shape Nx2xHxW (2 classes: background/foreground); argmax gives the per-pixel class
- E = pred_densepose.embedding
- mesh_name = self.class_to_mesh_name[classes[0]]
- assert mesh_name == "smpl_27554"
- x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
- boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
- boxes_XYXY = boxes_XYXY.round_().long()
-
- non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
- S = S[non_empty_boxes]
- E = E[non_empty_boxes]
- boxes_XYXY = boxes_XYXY[non_empty_boxes]
- scores = scores[non_empty_boxes]
- im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
- return dict(
- instance_segmentation=S, instance_embedding=E,
- bbox_XYXY=boxes_XYXY,
- im_segmentation=im_segmentation,
- scores=scores.view(-1))
-
diff --git a/dp2/detection/models/keypoint_maskrcnn.py b/dp2/detection/models/keypoint_maskrcnn.py
deleted file mode 100644
index 4fc3fd9e19aa8a023ad8135f6e6997135049c3db..0000000000000000000000000000000000000000
--- a/dp2/detection/models/keypoint_maskrcnn.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import numpy as np
-import torch
-from detectron2.checkpoint import DetectionCheckpointer
-from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
-from detectron2.data.transforms import ResizeShortestEdge
-from detectron2.structures import Instances
-from detectron2 import model_zoo
-from detectron2.config import instantiate
-from detectron2.config import LazyCall as L
-from PIL import Image
-import tops
-import functools
-from torchvision.transforms.functional import resize
-
-
-def get_rn50_fpn_keypoint_rcnn(weight_path: str):
- from detectron2.modeling.poolers import ROIPooler
- from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
- from detectron2.layers import ShapeSpec
- model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
- model.roi_heads.update(
- num_classes=1,
- keypoint_in_features=["p2", "p3", "p4", "p5"],
- keypoint_pooler=L(ROIPooler)(
- output_size=14,
- scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
- sampling_ratio=0,
- pooler_type="ROIAlignV2",
- ),
- keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
- input_shape=ShapeSpec(channels=256, width=14, height=14),
- num_keypoints=17,
- conv_dims=[512] * 8,
- loss_normalizer="visible",
- ),
- )
-
- # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
- # 1000 proposals per-image is found to hurt box AP.
- # Therefore we increase it to 1500 per-image.
- model.proposal_generator.post_nms_topk = (1500, 1000)
-
- # Keypoint AP degrades (though box AP improves) when using plain L1 loss
- model.roi_heads.box_predictor.smooth_l1_beta = 0.5
- model = instantiate(model)
-
- dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
- test_transform = instantiate(dataloader.test.mapper.augmentations)
- DetectionCheckpointer(model).load(weight_path)
- return model, test_transform
-
-
-models = {
- "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
-}
-
-
-
-
-class KeypointMaskRCNN:
-
- def __init__(self, model_name: str, score_threshold: float) -> None:
- assert model_name in models, f"Did not find {model_name} in models"
- model, test_transform = models[model_name]()
- self.model = model.eval().to(tops.get_device())
- if isinstance(self.model.roi_heads, CascadeROIHeads):
- for head in self.model.roi_heads.box_predictors:
- assert hasattr(head, "test_score_thresh")
- head.test_score_thresh = score_threshold
- else:
- assert isinstance(self.model.roi_heads, StandardROIHeads)
- assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
- self.model.roi_heads.box_predictor.test_score_thresh = score_threshold
-
- self.test_transform = test_transform
- assert len(self.test_transform) == 1
- self.test_transform = self.test_transform[0]
- assert isinstance(self.test_transform, ResizeShortestEdge)
- assert self.test_transform.interp == Image.BILINEAR
- self.image_format = self.model.input_format
-
- def resize_im(self, im):
- H, W = im.shape[-2:]
- if self.test_transform.is_range:
- size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
- else:
- size = np.random.choice(self.test_transform.short_edge_length)
- newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
- return resize(
- im, (newH, newW), antialias=True)
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor) -> Instances:
- assert im.ndim == 3
- if self.image_format == "BGR":
- im = im.flip(0)
- H, W = im.shape[-2:]
- im = self.resize_im(im)
- im = im.float()
- inputs = dict(image=im, height=H, width=W)
- # instances contains
- # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
- instances = self.model([inputs])[0]["instances"]
- return dict(
- scores=instances.get("scores").cpu(),
- segmentation=instances.get("pred_masks").cpu(),
- keypoints=instances.get("pred_keypoints").cpu()
- )
diff --git a/dp2/detection/models/mask_rcnn.py b/dp2/detection/models/mask_rcnn.py
deleted file mode 100644
index 1f87d151709c8adede45aa00de0c4bce0287114e..0000000000000000000000000000000000000000
--- a/dp2/detection/models/mask_rcnn.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import torch
-import tops
-from detectron2.modeling import build_model
-from detectron2.checkpoint import DetectionCheckpointer
-from detectron2.structures import Boxes
-from detectron2.data import MetadataCatalog
-from detectron2 import model_zoo
-from typing import Dict
-from detectron2.data.transforms import ResizeShortestEdge
-from torchvision.transforms.functional import resize
-
-
-
-model_urls = {
- "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl",
-
-}
-class MaskRCNNDetector:
-
- def __init__(
- self,
- cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
- score_thres: float = 0.9,
- class_filter=["person"], #["car", "bicycle","truck", "bus", "backpack"]
- fp16_inference: bool = False
- ) -> None:
- cfg = model_zoo.get_config(cfg_name)
- cfg.MODEL.DEVICE = str(tops.get_device())
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
- cfg.freeze()
- self.cfg = cfg
- with tops.logger.capture_log_stdout():
- self.model = build_model(cfg)
- DetectionCheckpointer(self.model).load(model_urls[cfg_name])
- self.model.eval()
- self.input_format = cfg.INPUT.FORMAT
- self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
- self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter])
- self.person_class = self.class_names.index("person")
- self.fp16_inference = fp16_inference
- tops.logger.log("Mask R-CNN built.")
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def resize_im(self, im):
- H, W = im.shape[1:]
- newH, newW = ResizeShortestEdge.get_output_shape(
- H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
- return resize(
- im, (newH, newW), antialias=True)
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor):
- if self.input_format == "BGR":
- im = im.flip(0)
- else:
- assert self.input_format == "RGB"
- H, W = im.shape[-2:]
- im = self.resize_im(im)
- with torch.cuda.amp.autocast(enabled=self.fp16_inference):
- output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
- scores = output.get("scores")
- N = len(scores)
- classes = output.get("pred_classes")
- idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep]
- classes = classes[idx2keep]
- assert isinstance(output.get("pred_boxes"), Boxes)
- segmentation = output.get("pred_masks")[idx2keep]
- assert segmentation.dtype == torch.bool
- is_person = classes == self.person_class
- return {
- "scores": output.get("scores")[idx2keep],
- "segmentation": segmentation,
- "classes": output.get("pred_classes")[idx2keep],
- "is_person": is_person
- }
-
diff --git a/dp2/detection/person_detector.py b/dp2/detection/person_detector.py
deleted file mode 100644
index 1bbd0df8c2aa44839a5de8bd9a6aeede054ff2ee..0000000000000000000000000000000000000000
--- a/dp2/detection/person_detector.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import torch
-import lzma
-from dp2.detection.base import BaseDetector
-from .utils import combine_cse_maskrcnn_dets
-from .models.cse import CSEDetector
-from .models.mask_rcnn import MaskRCNNDetector
-from .models.keypoint_maskrcnn import KeypointMaskRCNN
-from .structures import CSEPersonDetection, PersonDetection
-from pathlib import Path
-
-
-class CSEPersonDetector(BaseDetector):
- def __init__(
- self,
- score_threshold: float,
- mask_rcnn_cfg: dict,
- cse_cfg: dict,
- cse_post_process_cfg: dict,
- **kwargs
- ) -> None:
- super().__init__(**kwargs)
- self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
- self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
- self.post_process_cfg = cse_post_process_cfg
- self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def load_from_cache(self, cache_path: Path):
- with lzma.open(cache_path, "rb") as fp:
- state_dict = torch.load(fp)
- kwargs = dict(
- post_process_cfg=self.post_process_cfg,
- embed_map=self.cse_detector.embed_map,
- )
- return [
- state["cls"].from_state_dict(**kwargs, state_dict=state)
- for state in state_dict
- ]
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor, cse_dets=None):
- mask_dets = self.mask_rcnn(im)
- if cse_dets is None:
- cse_dets = self.cse_detector(im)
- segmentation = mask_dets["segmentation"]
- segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
- segmentation, cse_dets, self.iou_combine_threshold
- )
- det = CSEPersonDetection(
- segmentation=segmentation,
- cse_dets=cse_dets,
- embed_map=self.cse_detector.embed_map,
- orig_imshape_CHW=im.shape,
- **self.post_process_cfg
- )
- return [det]
-
-
-class MaskRCNNPersonDetector(BaseDetector):
- def __init__(
- self,
- score_threshold: float,
- mask_rcnn_cfg: dict,
- cse_post_process_cfg: dict,
- **kwargs
- ) -> None:
- super().__init__(**kwargs)
- self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
- self.post_process_cfg = cse_post_process_cfg
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def load_from_cache(self, cache_path: Path):
- with lzma.open(cache_path, "rb") as fp:
- state_dict = torch.load(fp)
- kwargs = dict(
- post_process_cfg=self.post_process_cfg,
- )
- return [
- state["cls"].from_state_dict(**kwargs, state_dict=state)
- for state in state_dict
- ]
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor):
- mask_dets = self.mask_rcnn(im)
- segmentation = mask_dets["segmentation"]
- det = PersonDetection(
- segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape
- )
- return [det]
-
-
-class KeypointMaskRCNNPersonDetector(BaseDetector):
- def __init__(
- self,
- score_threshold: float,
- mask_rcnn_cfg: dict,
- cse_post_process_cfg: dict,
- **kwargs
- ) -> None:
- super().__init__(**kwargs)
- self.mask_rcnn = KeypointMaskRCNN(
- **mask_rcnn_cfg, score_threshold=score_threshold
- )
- self.post_process_cfg = cse_post_process_cfg
-
- def __call__(self, *args, **kwargs):
- return self.forward(*args, **kwargs)
-
- def load_from_cache(self, cache_path: Path):
- with lzma.open(cache_path, "rb") as fp:
- state_dict = torch.load(fp)
- kwargs = dict(
- post_process_cfg=self.post_process_cfg,
- )
- return [
- state["cls"].from_state_dict(**kwargs, state_dict=state)
- for state in state_dict
- ]
-
- @torch.no_grad()
- def forward(self, im: torch.Tensor):
- mask_dets = self.mask_rcnn(im)
- segmentation = mask_dets["segmentation"]
- det = PersonDetection(
- segmentation,
- **self.post_process_cfg,
- orig_imshape_CHW=im.shape,
- keypoints=mask_dets["keypoints"]
- )
- return [det]
diff --git a/dp2/detection/structures.py b/dp2/detection/structures.py
deleted file mode 100644
index 3c78de781612b118b2d12e318f78100301439a95..0000000000000000000000000000000000000000
--- a/dp2/detection/structures.py
+++ /dev/null
@@ -1,464 +0,0 @@
-import torch
-import numpy as np
-from dp2 import utils
-from dp2.utils import vis_utils, crop_box
-from .utils import (
- cut_pad_resize, masks_to_boxes,
- get_kernel, transform_embedding, initialize_cse_boxes
- )
-from .box_utils import get_expanded_bbox, include_box
-import torchvision
-import tops
-from .box_utils_fdf import expand_bbox as expand_bbox_fdf
-
-
-class VehicleDetection:
-
- def __init__(self, segmentation: torch.BoolTensor) -> None:
- self.segmentation = segmentation
- self.boxes = masks_to_boxes(segmentation)
- assert self.boxes.shape[1] == 4, self.boxes.shape
- self.n_detections = self.segmentation.shape[0]
- area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
-
- sorted_idx = torch.argsort(area, descending=True)
- self.segmentation = self.segmentation[sorted_idx]
- self.boxes = self.boxes[sorted_idx].cpu()
-
- def pre_process(self):
- pass
-
- def get_crop(self, idx: int, im):
- assert idx < len(self)
- box = self.boxes[idx]
- im = crop_box(self.im, box)
- mask = crop_box(self.segmentation[idx])
- mask = mask == 0
- return dict(img=im, mask=mask.float(), boxes=box)
-
- def visualize(self, im):
- if len(self) == 0:
- return im
- im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
- return im
-
- def __len__(self):
- return self.n_detections
-
- @staticmethod
- def from_state_dict(state_dict, **kwargs):
- numel = np.prod(state_dict["shape"])
- arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
- segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
- return VehicleDetection(segmentation)
-
- def state_dict(self, **kwargs):
- segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
- return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
-
-
-class FaceDetection:
-
- def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None:
- self.boxes = boxes_ltrb.cpu()
- assert self.boxes.shape[1] == 4, self.boxes.shape
- self.target_imsize = tuple(target_imsize)
- # Sort by area so that the largest faces are pasted in last
- area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
- idx = area.argsort(descending=False)
- self.boxes = self.boxes[idx]
- self.fdf128_expand = fdf128_expand
-
- def visualize(self, im):
- if len(self) == 0:
- return im
- orig_device = im.device
- for box in self.boxes:
- simple_expand = False if self.fdf128_expand else True
- e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
- im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
- im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
-
- return im.to(device=orig_device)
-
- def get_crop(self, idx: int, im):
- assert idx < len(self)
- box = self.boxes[idx].numpy()
- simple_expand = False if self.fdf128_expand else True
- expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand=simple_expand)
- im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
- area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
-
- # Find the square mask corresponding to box.
- box_mask = box.copy().astype(float)
- box_mask[[0, 2]] -= expanded_boxes[0]
- box_mask[[1, 3]] -= expanded_boxes[1]
-
- width = expanded_boxes[2] - expanded_boxes[0]
- resize_factor = self.target_imsize[0] / width
- box_mask = (box_mask * resize_factor).astype(int)
- mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
- crop_box(mask, box_mask).fill_(0)
- return dict(
- img=im[None], mask=mask[None],
- boxes=torch.from_numpy(expanded_boxes).view(1, -1))
-
- def __len__(self):
- return len(self.boxes)
-
- @staticmethod
- def from_state_dict(state_dict, **kwargs):
- return FaceDetection(state_dict["boxes"].cpu(), **kwargs)
-
- def state_dict(self, **kwargs):
- return dict(boxes=self.boxes, cls=self.__class__)
-
- def pre_process(self):
- pass
-
-
-def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
- """
- Dilation happens after padding, which can place dilated pixels inside the padded area.
- Zero out those pixels.
- """
- x0, y0, x1, y1 = exp_box
- H, W = orig_imshape
- # Padding in original image space
- p_y0 = max(0, -y0)
- p_y1 = max(y1 - H, 0)
- p_x0 = max(0, -x0)
- p_x1 = max(x1 - W, 0)
- resize_ratio = mask.shape[-2] / (y1-y0)
- p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
- mask[..., :p_y0, :] = 0
- mask[..., :p_x0] = 0
- mask[..., mask.shape[-2] - p_y1:, :] = 0
- mask[..., mask.shape[-1] - p_x1:] = 0
-
-
-class CSEPersonDetection:
-
- def __init__(self,
- segmentation, cse_dets,
- target_imsize,
- exp_bbox_cfg, exp_bbox_filter,
- dilation_percentage: float,
- embed_map: torch.Tensor,
- orig_imshape_CHW,
- normalize_embedding: bool) -> None:
- self.segmentation = segmentation
- self.cse_dets = cse_dets
- self.target_imsize = list(target_imsize)
- self.pre_processed = False
- self.exp_bbox_cfg = exp_bbox_cfg
- self.exp_bbox_filter = exp_bbox_filter
- self.dilation_percentage = dilation_percentage
- self.embed_map = embed_map
- self.normalize_embedding = normalize_embedding
- if self.normalize_embedding:
- embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
- embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
- self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
- self.orig_imshape_CHW = orig_imshape_CHW
-
- @torch.no_grad()
- def pre_process(self):
- if self.pre_processed:
- return
- boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
- expanded_boxes = []
- included_boxes = []
- for i in range(len(boxes)):
- exp_box = get_expanded_bbox(
- boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
- target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
- if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
- continue
- included_boxes.append(i)
- expanded_boxes.append(exp_box)
- expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
- self.segmentation = self.segmentation[included_boxes]
- self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}
-
- self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
- area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
- for i, box in enumerate(expanded_boxes):
- self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
-
- dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
- self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
- self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
- [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
- self.boxes = expanded_boxes.cpu()
- self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
-
- self.pre_processed = True
- self.n_detections = len(self.boxes)
- self.mask = self.mask.logical_not()
-
- E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
- self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
- for i in range(self.n_detections):
- E_, E_mask[i] = transform_embedding(
- self.cse_dets["instance_embedding"][i],
- self.cse_dets["instance_segmentation"][i],
- self.boxes[i],
- self.cse_dets["bbox_XYXY"][i].cpu(),
- self.target_imsize
- )
- self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
- self.E_mask = E_mask
-
- sorted_idx = torch.argsort(area, descending=False)
- self.mask = self.mask[sorted_idx]
- self.boxes = self.boxes[sorted_idx.cpu()]
- self.vertices = self.vertices[sorted_idx]
- self.E_mask = self.E_mask[sorted_idx]
- self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
-
- def get_crop(self, idx: int, im):
- self.pre_process()
- assert idx < len(self)
- box = self.boxes[idx]
- mask = self.mask[idx]
- im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
-
- vertices_ = self.vertices[idx]
- E_mask_ = self.E_mask[idx].float()
- if self.normalize_embedding:
- embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
- else:
- embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
-
- return dict(
- img=im,
- mask=mask.float()[None],
- boxes=box.reshape(1, -1),
- E_mask=E_mask_[None],
- vertices=vertices_[None],
- embed_map=self.embed_map,
- embedding=embedding[None],
- maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
- )
-
- def __len__(self):
- self.pre_process()
- return self.n_detections
-
- def state_dict(self, after_preprocess=False):
- """
- The processed annotations occupy more space than the original detections.
- """
- if not after_preprocess:
- return {
- "combined_segmentation": self.segmentation.bool(),
- "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
- "cse_instance_embedding": self.cse_dets["instance_embedding"],
- "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
- "cls": self.__class__,
- "orig_imshape_CHW": self.orig_imshape_CHW
- }
- self.pre_process()
- return dict(
- E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())),
- mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())),
- maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())),
- vertices=self.vertices.to(torch.int16).cpu(),
- cls=self.__class__,
- boxes=self.boxes,
- orig_imshape_CHW=self.orig_imshape_CHW,
- )
-
- @staticmethod
- def from_state_dict(
- state_dict, embed_map,
- post_process_cfg, **kwargs):
- after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
- if after_preprocess:
- detection = CSEPersonDetection(
- segmentation=None, cse_dets=None, embed_map=embed_map,
- orig_imshape_CHW=state_dict["orig_imshape_CHW"],
- **post_process_cfg)
- detection.vertices = tops.to_cuda(state_dict["vertices"].long())
- numel = np.prod(detection.vertices.shape)
- detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
- detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape)
- detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
- detection.n_detections = len(detection.mask)
- detection.pre_processed = True
-
- if isinstance(state_dict["boxes"], np.ndarray):
- state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
- detection.boxes = state_dict["boxes"]
- return detection
-
- cse_dets = dict(
- instance_segmentation=state_dict["cse_instance_segmentation"],
- instance_embedding=state_dict["cse_instance_embedding"],
- embed_map=embed_map,
- bbox_XYXY=state_dict["cse_bbox_XYXY"])
- cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}
-
- segmentation = state_dict["combined_segmentation"]
- return CSEPersonDetection(
- segmentation, cse_dets, embed_map=embed_map,
- orig_imshape_CHW=state_dict["orig_imshape_CHW"],
- **post_process_cfg)
-
- def visualize(self, im):
- self.pre_process()
- if len(self) == 0:
- return im
- im = vis_utils.draw_cropped_masks(
- im.clone(), self.mask, self.boxes, visualize_instances=False)
- E = self.embed_map[self.vertices.long()].squeeze(1).permute(0,3, 1, 2)
- im = im.to(E.device)
- im = vis_utils.draw_cse_all(
- E, self.E_mask.squeeze(1).bool(), im,
- self.boxes, self.embed_map)
- im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
- return im
-
-
-def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
- keypoints = keypoints.clone()
- N = boxes.shape[0]
- tops.assert_shape(keypoints, (N, None, 3))
- tops.assert_shape(boxes, (N, 4))
- x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
-
- w = x1 - x0
- h = y1 - y0
- keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
- keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h
- check_outside = lambda x: (x < 0).logical_or(x > 1)
- is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
- keypoints[:, :, 2] = keypoints[:, :, 2] >= 0
- keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
- return keypoints
-
-
-class PersonDetection:
-
- def __init__(
- self,
- segmentation,
- target_imsize,
- exp_bbox_cfg, exp_bbox_filter,
- dilation_percentage: float,
- orig_imshape_CHW,
- keypoints=None,
- **kwargs) -> None:
- self.segmentation = segmentation
- self.target_imsize = list(target_imsize)
- self.pre_processed = False
- self.exp_bbox_cfg = exp_bbox_cfg
- self.exp_bbox_filter = exp_bbox_filter
- self.dilation_percentage = dilation_percentage
- self.orig_imshape_CHW = orig_imshape_CHW
- self.keypoints = keypoints
-
- @torch.no_grad()
- def pre_process(self):
- if self.pre_processed:
- return
- boxes = masks_to_boxes(self.segmentation).cpu()
- expanded_boxes = []
- included_boxes = []
- for i in range(len(boxes)):
- exp_box = get_expanded_bbox(
- boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
- target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
- if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
- continue
- included_boxes.append(i)
- expanded_boxes.append(exp_box)
- expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
- self.segmentation = self.segmentation[included_boxes]
- if self.keypoints is not None:
- self.keypoints = self.keypoints[included_boxes]
- area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
- self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
- for i, box in enumerate(expanded_boxes):
- self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
- if self.keypoints is not None:
- self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)
- dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
- self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
- self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
-
- [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
- self.boxes = expanded_boxes
- self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
-
- self.pre_processed = True
- self.n_detections = len(self.boxes)
- self.mask = self.mask.logical_not()
-
- sorted_idx = torch.argsort(area, descending=False)
- self.mask = self.mask[sorted_idx]
- self.boxes = self.boxes[sorted_idx.cpu()]
- self.segmentation = self.segmentation[sorted_idx]
- self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
- if self.keypoints is not None:
- self.keypoints = self.keypoints[sorted_idx]
-
- def get_crop(self, idx: int, im: torch.Tensor):
- assert idx < len(self)
- self.pre_process()
- box = self.boxes[idx]
- mask = self.mask[idx][None].float()
- im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
- batch = dict(
- img=im, mask=mask, boxes=box.reshape(1, -1),
- maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
- if self.keypoints is not None:
- batch["keypoints"] = self.keypoints[idx:idx+1]
- return batch
-
- def __len__(self):
- self.pre_process()
- return self.n_detections
-
- def state_dict(self, **kwargs):
- return dict(
- segmentation=self.segmentation.bool(),
- cls=self.__class__,
- orig_imshape_CHW=self.orig_imshape_CHW,
- keypoints=self.keypoints
- )
-
- @staticmethod
- def from_state_dict(
- state_dict,
- post_process_cfg, **kwargs):
- return PersonDetection(
- state_dict["segmentation"],
- orig_imshape_CHW=state_dict["orig_imshape_CHW"],
- **post_process_cfg,
- keypoints=state_dict["keypoints"])
-
- def visualize(self, im):
- self.pre_process()
- im = im.cpu()
- if len(self) == 0:
- return im
- im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False)
- im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
- return im
-
-
-def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
- """
- mask: resized mask
- """
- assert exp_bbox.shape[0] == mask.shape[0]
- boxes = masks_to_boxes(mask.squeeze(1)).cpu()
- H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
- boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
- boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
- boxes[:, [0, 2]] += exp_bbox[:, 0:1]
- boxes[:, [1, 3]] += exp_bbox[:, 1:2]
- return boxes
-
diff --git a/dp2/detection/utils.py b/dp2/detection/utils.py
deleted file mode 100644
index 85dbd29c1832d2f48b20ed14c3d0357c958732f6..0000000000000000000000000000000000000000
--- a/dp2/detection/utils.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import cv2
-import numpy as np
-import torch
-import tops
-from skimage.morphology import disk
-from torchvision.transforms.functional import resize, InterpolationMode
-from functools import lru_cache
-
-
-@lru_cache(maxsize=200)
-def get_kernel(n: int):
- kernel = disk(n, dtype=bool)
- return tops.to_cuda(torch.from_numpy(kernel).bool())
-
-
-def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
- """
- Transforms the detected embedding/mask directly to the target image shape
- """
-
- C, HE, WE = E.shape
- assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
- assert E_bbox[2] >= exp_bbox[0]
- assert E_bbox[1] >= exp_bbox[1]
- assert E_bbox[3] >= exp_bbox[1]
- assert E_bbox[2] <= exp_bbox[2]
- assert E_bbox[3] <= exp_bbox[3]
-
- x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
- x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
- y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
- y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
- new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
- new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool)
-
- E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
- new_E[:, y0:y1, x0:x1] = E
- S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
- new_S[y0:y1, x0:x1] = S
- return new_E, new_S
-
-
-def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
- """
- mask: shape [N, H, W]
- """
- assert len(mask1.shape) == 3
- assert len(mask2.shape) == 3
- assert mask1.device == mask2.device, (mask1.device, mask2.device)
- assert mask1.dtype == mask2.dtype
- assert mask1.dtype == torch.bool
- assert mask1.shape[1:] == mask2.shape[1:]
- N1, H1, W1 = mask1.shape
- N2, H2, W2 = mask2.shape
- iou = torch.zeros((N1, N2), dtype=torch.float32)
- for i in range(N1):
- cur = mask1[i:i+1]
- inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
- union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
- iou[i] = inter / union
- return iou
-
-
-def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
- N1 = mask1.shape[0]
- N2 = mask2.shape[0]
- ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
- indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
- ious = ious.flatten()
- mask = ious >= iou_threshold
- ious = ious[mask]
- indices = indices[mask]
-
- # Do not sort by IoU, to preserve the original Mask R-CNN / CSE ordering.
- taken1 = np.zeros((N1), dtype=bool)
- taken2 = np.zeros((N2), dtype=bool)
- matches = []
- for i, j in indices:
- if taken1[i].any() or taken2[j].any():
- continue
- matches.append((i, j))
- taken1[i] = True
- taken2[j] = True
- return matches
-
-
-def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
- assert 0 < iou_threshold <= 1
- matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
- H, W = segmentation.shape[1:]
- new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
- cse_im_seg = cse_dets["im_segmentation"]
- for idx, (i, j) in enumerate(matches):
- new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
- cse_dets = dict(
- instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]],
- instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]],
- bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]],
- scores=cse_dets["scores"][[j for (i, j) in matches]],
- )
- return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
-
-
-def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
- """
- cse_boxes can extend outside the segmentation; the returned boxes are the union of both.
- """
- boxes = masks_to_boxes(segmentation)
-
- assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
- combined = torch.stack((boxes, cse_boxes), dim=-1)
- boxes = torch.cat((
- combined[:, :2].min(dim=2).values,
- combined[:, 2:].max(dim=2).values,
- ), dim=1)
- return boxes
-
-
-def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
- """
- Crops or pads x to fit the bbox, then resizes to the target shape.
- """
- C, H, W = x.shape
- x0, y0, x1, y1 = bbox
-
- if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
- new_x = x[:, y0:y1, x0:x1]
- else:
- new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device)
- y0_t = max(0, -y0)
- y1_t = min(y1-y0, (y1-y0)-(y1-H))
- x0_t = max(0, -x0)
- x1_t = min(x1-x0, (x1-x0)-(x1-W))
- x0 = max(0, x0)
- y0 = max(0, y0)
- x1 = min(x1, W)
- y1 = min(y1, H)
- new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
- if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
- return new_x
- if x.dtype == torch.bool:
- new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
- elif x.dtype == torch.float32:
- new_x = resize(new_x, target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True)
- elif x.dtype == torch.uint8:
- if fdf_resize: # FDF dataset is created with cv2 INTER_AREA.
- # Incorrect resizing generates noticeably poorer inpaintings.
- upsampling = ((y1-y0) *(x1-x0)) < (target_shape[0] * target_shape[1])
- if upsampling:
- new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC, antialias=True).round().clamp(0, 255).byte()
- else:
- device = new_x.device
- new_x = new_x.permute(1, 2, 0).cpu().numpy()
- new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
- new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
- else:
- new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True).round().clamp(0, 255).byte()
- else:
- raise ValueError(f"Not supported dtype: {x.dtype}")
- return new_x
-
-
-
-def masks_to_boxes(segmentation: torch.Tensor):
- assert len(segmentation.shape) == 3
- x = segmentation.any(dim=1).byte() # Compress rows
- x0 = x.argmax(dim=1)
-
- x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)
- y = segmentation.any(dim=2).byte()
- y0 = y.argmax(dim=1)
- y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
- return torch.stack([x0, y0, x1, y1], dim=1)
-
diff --git a/dp2/discriminator/__init__.py b/dp2/discriminator/__init__.py
deleted file mode 100644
index 77a4a773eafecde739caf086660063a19cf2160f..0000000000000000000000000000000000000000
--- a/dp2/discriminator/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .sg2_discriminator import SG2Discriminator
\ No newline at end of file
diff --git a/dp2/discriminator/sg2_discriminator.py b/dp2/discriminator/sg2_discriminator.py
deleted file mode 100644
index 5f44e8f9d6cfb46c0763cc8ab5d96037fcd2c4e2..0000000000000000000000000000000000000000
--- a/dp2/discriminator/sg2_discriminator.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from sg3_torch_utils.ops import upfirdn2d
-import torch
-import numpy as np
-import torch.nn as nn
-from .. import layers
-from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block
-
-
-class SG2Discriminator(layers.Module):
-
- def __init__(
- self,
- cnum: int,
- max_cnum_mul: int,
- imsize,
- min_fmap_resolution: int,
- im_channels: int,
- input_condition: bool,
- conv_clamp: int,
- input_cse: bool,
- cse_nc: int):
- super().__init__()
-
- cse_nc = 0 if cse_nc is None else cse_nc
- self._max_imsize = max(imsize)
- self._cnum = cnum
- self._max_cnum_mul = max_cnum_mul
- self._min_fmap_resolution = min_fmap_resolution
- self._input_condition = input_condition
- self.input_cse = input_cse
- self.layers = nn.ModuleList()
-
- out_ch = self.get_chsize(self._max_imsize)
- self.from_rgb = Block(
- im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1),
- out_ch, conv_clamp=conv_clamp
- )
- n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
-
- for i in range(n_levels):
- resolution = [x//2**i for x in imsize]
- in_ch = self.get_chsize(max(resolution))
- out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution))
-
- down = 2
- if i == 0:
- down = 1
- block = ResidualBlock(
- in_ch, out_ch, down=down, conv_clamp=conv_clamp
- )
- self.layers.append(block)
- self.output_layer = DiscriminatorEpilogue(
- out_ch, resolution, conv_clamp=conv_clamp)
-
- self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
-
- def forward(self, img, condition, mask, embedding=None, E_mask=None,**kwargs):
- to_cat = [img]
- if self._input_condition:
- to_cat.extend([condition, mask,])
- if self.input_cse:
- to_cat.extend([embedding, E_mask])
- x = torch.cat(to_cat, dim=1)
- x = self.from_rgb(x)
-
- for i, layer in enumerate(self.layers):
- x = layer(x)
-
- x = self.output_layer(x)
- return dict(score=x)
-
- def get_chsize(self, imsize):
- n = int(np.log2(self._max_imsize) - np.log2(imsize))
- mul = min(2 ** n, self._max_cnum_mul)
- ch = self._cnum * mul
- return int(ch)
diff --git a/dp2/gan_trainer.py b/dp2/gan_trainer.py
deleted file mode 100644
index e9cdbcd86f596c7be984b5463e486aafe6a0890e..0000000000000000000000000000000000000000
--- a/dp2/gan_trainer.py
+++ /dev/null
@@ -1,324 +0,0 @@
-import atexit
-from collections import defaultdict
-import logging
-import typing
-import torch
-import time
-from dp2.utils import vis_utils
-from dp2 import utils
-from tops import logger, checkpointer
-import tops
-from easydict import EasyDict
-
-
-def accumulate_gradients(params, fp16_ddp_accumulate):
- if len(params) == 0:
- return
- params = [param for param in params if param.grad is not None]
- flat = torch.cat([param.grad.flatten() for param in params])
- orig_dtype = flat.dtype
- if tops.world_size() > 1:
- if fp16_ddp_accumulate:
- flat = flat.half() / tops.world_size()
- else:
- flat /= tops.world_size()
- torch.distributed.all_reduce(flat)
- flat = flat.to(orig_dtype)
- grads = flat.split([param.numel() for param in params])
- for param, grad in zip(params, grads):
- param.grad = grad.reshape(param.shape)
-
-
-def accumulate_buffers(module: torch.nn.Module):
- buffers = [buf for buf in module.buffers()]
- if len(buffers) == 0:
- return
- flat = torch.cat([buf.flatten() for buf in buffers])
- if tops.world_size() > 1:
- torch.distributed.all_reduce(flat)
- flat /= tops.world_size()
- bufs = flat.split([buf.numel() for buf in buffers])
- for old, new in zip(buffers, bufs):
- old.copy_(new.reshape(old.shape), non_blocking=True)
-
-
-def check_ddp_consistency(module):
- if tops.world_size() == 1:
- return
- assert isinstance(module, torch.nn.Module)
- assert isinstance(module, torch.nn.Module)
- for name, tensor in params_buffs:
- fullname = type(module).__name__ + '.' + name
- tensor = tensor.detach()
- if tensor.is_floating_point():
- tensor = torch.nan_to_num(tensor)
- other = tensor.clone()
- torch.distributed.broadcast(tensor=other, src=0)
- assert (tensor == other).all(), fullname
-
-class AverageMeter():
- def __init__(self) -> None:
- self.to_log = dict()
- self.n = defaultdict(int)
- pass
-
- @torch.no_grad()
- def update(self, values: dict):
- for key, value in values.items():
- self.n[key] += 1
- if key in self.to_log:
- self.to_log[key] += value.mean().detach()
- else:
- self.to_log[key] = value.mean().detach()
-
- def get_average(self):
- return {key: value / self.n[key] for key, value in self.to_log.items()}
-
-
-class GANTrainer:
-
- def __init__(
- self,
- G: torch.nn.Module,
- D: torch.nn.Module,
- G_EMA: torch.nn.Module,
- D_optim: torch.optim.Optimizer,
- G_optim: torch.optim.Optimizer,
- dl_train: typing.Iterator,
- dl_val: typing.Iterable,
- scaler_D: torch.cuda.amp.GradScaler,
- scaler_G: torch.cuda.amp.GradScaler,
- ims_per_log: int,
- max_images_to_train: int,
- loss_handler,
- ims_per_val: int,
- evaluate_fn,
- batch_size: int,
- broadcast_buffers: bool,
- fp16_ddp_accumulate: bool,
- save_state: bool,
- *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- self.G = G
- self.D = D
- self.G_EMA = G_EMA
- self.D_optim = D_optim
- self.G_optim = G_optim
- self.dl_train = dl_train
- self.dl_val = dl_val
- self.scaler_D = scaler_D
- self.scaler_G = scaler_G
- self.loss_handler = loss_handler
- self.max_images_to_train = max_images_to_train
- self.images_per_val = ims_per_val
- self.images_per_log = ims_per_log
- self.evaluate_fn = evaluate_fn
- self.batch_size = batch_size
- self.broadcast_buffers = broadcast_buffers
- self.fp16_ddp_accumulate = fp16_ddp_accumulate
-
- self.train_state = EasyDict(
- next_log_step=0,
- next_val_step=ims_per_val,
- total_time=0
- )
-
- checkpointer.register_models(dict(
- generator=G, discriminator=D, EMA_generator=G_EMA,
- D_optimizer=D_optim,
- G_optimizer=G_optim,
- train_state=self.train_state,
- scaler_D=self.scaler_D,
- scaler_G=self.scaler_G
- ))
- if checkpointer.has_checkpoint():
- checkpointer.load_registered_models()
- logger.log(f"Resuming training from: global step: {logger.global_step()}")
- else:
- logger.add_dict({
- "stats/discriminator_parameters": tops.num_parameters(self.D),
- "stats/generator_parameters": tops.num_parameters(self.G),
- }, commit=False)
- if save_state:
- # If the job is unexpectedly killed, there could be a mismatch between the previously saved checkpoint and the current one.
- atexit.register(checkpointer.save_registered_models)
-
- self._ims_per_log = ims_per_log
-
- self.to_log = AverageMeter()
- self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad]
- self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad]
- logger.add_dict({
- "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D),
- "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G),
- }, commit=False, level=logging.INFO)
- check_ddp_consistency(self.D)
- check_ddp_consistency(self.G)
- check_ddp_consistency(self.G_EMA.generator)
-
- def train_loop(self):
- self.log_time()
- while logger.global_step() <= self.max_images_to_train:
- batch = next(self.dl_train)
- self.G_EMA.update_beta()
- self.to_log.update(self.step_D(batch))
- self.to_log.update(self.step_G(batch))
- self.G_EMA.update(self.G)
-
- if logger.global_step() >= self.train_state.next_log_step:
- to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()}
- to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()})
- to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()})
- self.to_log = AverageMeter()
- logger.add_dict(to_log, commit=True)
- self.train_state.next_log_step += self.images_per_log
- if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8:
- print("Stopping training as gradient scale < 1e-8")
- logger.log("Stopping training as gradient scale < 1e-8")
- break
-
- if logger.global_step() >= self.train_state.next_val_step:
- self.evaluate()
- self.log_time()
- self.save_images()
- self.train_state.next_val_step += self.images_per_val
- logger.step(self.batch_size*tops.world_size())
- logger.log(f"Reached end of training at step {logger.global_step()}.")
- checkpointer.save_registered_models()
-
- def estimate_ims_per_hour(self):
- batch = next(self.dl_train)
- n_ims = int(100e3)
- n_steps = int(n_ims / (self.batch_size * tops.world_size()))
- n_ims = n_steps * self.batch_size * tops.world_size()
- for i in range(10): # Warmup
- self.G_EMA.update_beta()
- self.step_D(batch)
- self.step_G(batch)
- self.G_EMA.update(self.G)
- start_time = time.time()
- for i in utils.tqdm_(list(range(n_steps))):
- self.G_EMA.update_beta()
- self.step_D(batch)
- self.step_G(batch)
- self.G_EMA.update(self.G)
- total_time = time.time() - start_time
- ims_per_sec = n_ims / total_time
- ims_per_hour = ims_per_sec * 60*60
- ims_per_day = ims_per_hour * 24
- logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M")
- logger.log(f"Images per day: {ims_per_day/1e6:.3f}M")
- import math
- ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4))
- logger.log(f"Images per 4 days: {ims_per_4_day}")
- logger.add_dict({
- "stats/ims_per_day": ims_per_day,
- "stats/ims_per_4_day": ims_per_4_day
- })
-
- def log_time(self):
- if not hasattr(self, "start_time"):
- self.start_time = time.time()
- self.last_time_step = logger.global_step()
- return
- n_images = logger.global_step() - self.last_time_step
- if n_images == 0:
- return
- n_secs = time.time() - self.start_time
- n_ims_per_sec = n_images / n_secs
- training_time_hours = n_secs / 60/ 60
- self.train_state.total_time += training_time_hours
- remaining_images = self.max_images_to_train - logger.global_step()
- remaining_time = remaining_images / n_ims_per_sec / 60 / 60
- logger.add_dict({
- "stats/n_ims_per_sec": n_ims_per_sec,
- "stats/total_traing_time_hours": self.train_state.total_time,
- "stats/remaining_time_hours": remaining_time
- })
- self.last_time_step = logger.global_step()
- self.start_time = time.time()
-
- def save_images(self):
- dl_val = iter(self.dl_val)
- batch = next(dl_val)
- # TRUNCATED visualization
- ims_to_log = 8
- self.G_EMA.eval()
- z = self.G.get_z(batch["img"])
- fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"]
- fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu()
- if "__key__" in batch:
- batch.pop("__key__")
- real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
- to_vis = torch.cat((real, fakes_truncated))
- logger.add_images("images/truncated", to_vis, nrow=2)
-
- # Diverse images
- ims_diverse = 3
- batch = next(dl_val)
- to_vis = []
-
- for i in range(ims_diverse):
- z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1)
- fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu()
- to_vis.append(fakes)
- if "__key__" in batch:
- batch.pop("__key__")
- reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
- to_vis.insert(0, reals)
- to_vis = torch.cat(to_vis)
- logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1)
-
- self.G_EMA.train()
- pass
-
- def evaluate(self):
- logger.log("Stating evaluation.")
- self.G_EMA.eval()
- try:
- checkpointer.save_registered_models(max_keep=3)
- except Exception:
- logger.log("Could not save checkpoint.")
- if self.broadcast_buffers:
- check_ddp_consistency(self.G)
- check_ddp_consistency(self.D)
- metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val)
- metrics = {f"metrics/{k}": v for k,v in metrics.items()}
- logger.add_dict(metrics, level=logger.logger.INFO)
-
- def step_D(self, batch):
- utils.set_requires_grad(self.trainable_params_D, True)
- utils.set_requires_grad(self.trainable_params_G, False)
- tops.zero_grad(self.D)
- loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D)
- with torch.autograd.profiler.record_function("D_step"):
- self.scaler_D.scale(loss).backward()
- accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
- if self.broadcast_buffers:
- accumulate_buffers(self.D)
- accumulate_buffers(self.G)
- # Step will not unscale again if unscale was called previously.
- self.scaler_D.step(self.D_optim)
- self.scaler_D.update()
- utils.set_requires_grad(self.trainable_params_D, False)
- utils.set_requires_grad(self.trainable_params_G, False)
- return to_log
-
- def step_G(self, batch):
- utils.set_requires_grad(self.trainable_params_D, False)
- utils.set_requires_grad(self.trainable_params_G, True)
- tops.zero_grad(self.G)
- loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G)
- with torch.autograd.profiler.record_function("G_step"):
- self.scaler_G.scale(loss).backward()
- accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
- if self.broadcast_buffers:
- accumulate_buffers(self.G)
- accumulate_buffers(self.D)
- self.scaler_G.step(self.G_optim)
- self.scaler_G.update()
- utils.set_requires_grad(self.trainable_params_D, False)
- utils.set_requires_grad(self.trainable_params_G, False)
- return to_log
\ No newline at end of file
diff --git a/dp2/generator/__init__.py b/dp2/generator/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/dp2/generator/base.py b/dp2/generator/base.py
deleted file mode 100644
index 851f00ee2fed4f8e0601405ffd635338280ec993..0000000000000000000000000000000000000000
--- a/dp2/generator/base.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import torch
-import numpy as np
-import tqdm
-import tops
-from ..layers import Module
-from ..layers.sg2_layers import FullyConnectedLayer
-from dp2 import utils
-
-
-class BaseGenerator(Module):
-
- def __init__(self, z_channels: int):
- super().__init__()
- self.z_channels = z_channels
- self.latent_space = "Z"
-
- @torch.no_grad()
- def get_z(
- self,
- x: torch.Tensor = None,
- z: torch.Tensor = None,
- truncation_value: float = None,
- batch_size: int = None,
- dtype=None, device=None) -> torch.Tensor:
- """Generates a latent variable for generator.
- """
- if z is not None:
- return z
- if x is not None:
- batch_size = x.shape[0]
- dtype = x.dtype
- device = x.device
- if device is None:
- device = utils.get_device()
- if truncation_value == 0:
- return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype)
- z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype)
- if truncation_value is None:
- return z
- while z.abs().max() > truncation_value:
- m = z.abs() > truncation_value
- z[m] = torch.rand_like(z)[m]
- return z
-
- def sample(self, truncation_value, z=None, **kwargs):
- """
- Samples by interpolating toward the mean (0).
- """
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- if z is None:
- z = self.get_z(kwargs["condition"])
- z = z * truncation_value
- return self.forward(**kwargs, z=z)
-
-
-
-class SG2StyleNet(torch.nn.Module):
- def __init__(self,
- z_dim, # Input latent (Z) dimensionality.
- w_dim, # Intermediate latent (W) dimensionality.
- num_layers = 2, # Number of mapping layers.
- lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
- w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
- ):
- super().__init__()
- self.z_dim = z_dim
- self.w_dim = w_dim
- self.num_layers = num_layers
- self.w_avg_beta = w_avg_beta
- # Construct layers.
- features = [self.z_dim] + [self.w_dim] * self.num_layers
- for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
- layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
- setattr(self, f'fc{idx}', layer)
- self.register_buffer('w_avg', torch.zeros([w_dim]))
-
- def forward(self, z, update_emas=False, y=None):
- tops.assert_shape(z, [None, self.z_dim])
-
- # Embed, normalize, and concatenate inputs.
- x = z.to(torch.float32)
- x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
- # Execute layers.
- for idx in range(self.num_layers):
- x = getattr(self, f'fc{idx}')(x)
- # Update moving average of W.
- if update_emas:
- self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
-
- return x
-
- def extra_repr(self):
- return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'
-
- def update_w(self, n=int(10e3), batch_size=32):
- """
- Calculate w_ema over n iterations.
- Useful in cases where w_ema is calculated incorrectly during training.
- """
- n = n // batch_size
- for i in tqdm.trange(n, desc="Updating w"):
- z = torch.randn((batch_size, self.z_dim), device=tops.get_device())
- self(z, update_emas=True)
-
-
-class BaseStyleGAN(BaseGenerator):
-
- def __init__(self, z_channels: int, w_dim: int):
- super().__init__(z_channels)
- self.style_net = SG2StyleNet(z_channels, w_dim)
- self.latent_space = "W"
-
- def get_w(self, z, update_emas):
- return self.style_net(z, update_emas=update_emas)
-
- @torch.no_grad()
- def sample(self, truncation_value, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
-
- def update_w(self, *args, **kwargs):
- self.style_net.update_w(*args, **kwargs)
-
-
- @torch.no_grad()
- def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- if w_indices is None:
- w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
- w_centers = self.style_net.w_centers[w_indices].to(w.device)
- w = w_centers.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
diff --git a/dp2/generator/dummy_generators.py b/dp2/generator/dummy_generators.py
deleted file mode 100644
index c319e65c8191bf3f77d9d348698bff10c44e0b21..0000000000000000000000000000000000000000
--- a/dp2/generator/dummy_generators.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-import torch.nn as nn
-from .base import BaseGenerator
-
-
-class PixelationGenerator(BaseGenerator):
-
- def __init__(self, pixelation_size, **kwargs):
- super().__init__(z_channels=0)
- self.pixelation_size = pixelation_size
- self.z_channels = 0
- self.latent_space=None
-
- def forward(self, img, condition, mask, **kwargs):
- old_shape = img.shape[-2:]
- img = nn.functional.interpolate(img, size=(self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True)
- img = nn.functional.interpolate(img, size=old_shape, mode="bilinear", align_corners=True)
- out = img*(1-mask) + condition*mask
- return {"img": out}
-
-
-class MaskOutGenerator(BaseGenerator):
-
- def __init__(self, noise: str, **kwargs):
- super().__init__(z_channels=0)
- self.noise = noise
- self.z_channels = 0
- assert self.noise in ["rand", "constant"]
- self.latent_space = None
-
- def forward(self, img, condition, mask, **kwargs):
-
- if self.noise == "constant":
- img = torch.zeros_like(img)
- elif self.noise == "rand":
- img = torch.rand_like(img)
- out = img*(1-mask) + condition*mask
- return {"img": out}
-
-
-class IdentityGenerator(BaseGenerator):
-
- def __init__(self):
- super().__init__(z_channels=0)
-
- def forward(self, img, condition, mask, **kwargs):
- return dict(img=img)
\ No newline at end of file
diff --git a/dp2/generator/imagen3_old.py b/dp2/generator/imagen3_old.py
deleted file mode 100644
index 87f35d7595dd018f1d0db2baddcfeb3789596ed0..0000000000000000000000000000000000000000
--- a/dp2/generator/imagen3_old.py
+++ /dev/null
@@ -1,1210 +0,0 @@
-# What is missing from this implementation:
-# 1. Global context in the res block.
-# 2. Cross-attention of conditional information in the resnet block.
-#
-from functools import partial
-import tops
-from tops.config import instantiate
-import warnings
-from typing import Iterable, List, Tuple
-import numpy as np
-import torch
-import torch.nn as nn
-from torch import einsum
-from einops import rearrange
-from dp2 import infer, utils
-from .base import BaseGenerator
-from sg3_torch_utils.ops import bias_act
-from dp2.layers import Sequential
-import torch.nn.functional as F
-from torchvision.transforms.functional import resize, InterpolationMode
-from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d
-
-
-
-
-class Upfirdn2d(torch.nn.Module):
-
-
- def __init__(self, down=1, up=1, fix_gain=True):
- super().__init__()
- self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1]))
- fw, fh = upfirdn2d._get_filter_size(self.resample_filter)
- px0, px1, py0, py1 = upfirdn2d._parse_padding(0)
- self.down = down
- self.up = up
- if up > 1:
- px0 += (fw + up - 1) // 2
- px1 += (fw - up) // 2
- py0 += (fh + up - 1) // 2
- py1 += (fh - up) // 2
- if down > 1:
- px0 += (fw - down + 1) // 2
- px1 += (fw - down) // 2
- py0 += (fh - down + 1) // 2
- py1 += (fh - down) // 2
- self.padding = [px0,px1,py0,py1]
- self.gain = up**2 if fix_gain else 1
-
- def forward(self, x, *args):
- if isinstance(x, dict):
- x = {k: v for k, v in x.items()}
- x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
- return x
- x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
- if len(args) == 0:
- return x
- return (x, *args)
-@torch.no_grad()
-def spatial_embed_keypoints(keypoints: torch.Tensor, x):
- tops.assert_shape(keypoints, (None, None, 3))
- B, N_K, _ = keypoints.shape
- H, W = x.shape[-2:]
- x, y, visible = keypoints.chunk(3, dim=2)
- x = (x * W).round().long().clamp(0, W-1)
- y = (y * H).round().long().clamp(0, H-1)
- kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
- pos = (kp_idx*(H*W) + y*W + x + 1)
- # Offset all by 1 to index invisible keypoints as 0
- pos = (pos * visible.round().long()).squeeze(dim=-1)
- keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
- keypoint_spatial.scatter_(1, pos, 1)
- keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
- return keypoint_spatial
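-
-# Worked example of the flat-index arithmetic above (illustrative numbers):
-# with H = W = 4 and keypoint k at (x, y) = (1, 2), pos = k*16 + 2*4 + 1 + 1 = 16*k + 10.
-# Invisible keypoints are multiplied to 0, scattered into the extra leading slot,
-# and discarded by the final [:, 1:] slice.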
-
-
-def modulated_conv2d(
- x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
- weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
- styles, # Modulation coefficients of shape [batch_size, in_channels].
- noise = None, # Optional noise tensor to add to the output activations.
- up = 1, # Integer upsampling factor.
- down = 1, # Integer downsampling factor.
- padding = 0, # Padding with respect to the upsampled image.
- resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
- demodulate = True, # Apply weight demodulation?
- flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
- fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
-):
- batch_size = x.shape[0]
- out_channels, in_channels, kh, kw = weight.shape
- tops.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
- tops.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
- tops.assert_shape(styles, [batch_size, in_channels]) # [NI]
-
- # Pre-normalize inputs to avoid FP16 overflow.
- if x.dtype == torch.float16 and demodulate:
- weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
- styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
-
- # Calculate per-sample weights and demodulation coefficients.
- w = None
- dcoefs = None
- if demodulate or fused_modconv:
- w = weight.unsqueeze(0) # [NOIkk]
- w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
- if demodulate:
- dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
- if demodulate and fused_modconv:
- w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
-
- # Execute by scaling the activations before and after the convolution.
- if not fused_modconv:
- x = x * styles.reshape(batch_size, -1, 1, 1)
- x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
- if demodulate and noise is not None:
- x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
- elif demodulate:
- x = x * dcoefs.reshape(batch_size, -1, 1, 1)
- elif noise is not None:
- x = x.add_(noise.to(x.dtype))
- return x
-
- with tops.suppress_tracer_warnings(): # this value will be treated as a constant
- batch_size = int(batch_size)
- # Execute as one fused op using grouped convolution.
- tops.assert_shape(x, [batch_size, in_channels, None, None])
- x = x.reshape(1, -1, *x.shape[2:])
- w = w.reshape(-1, in_channels, kh, kw)
- x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
- x = x.reshape(batch_size, -1, *x.shape[2:])
- if noise is not None:
- x = x.add_(noise)
- return x
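-
-# Shape sketch for modulated_conv2d (illustrative sizes):
-# x: [N, I, H, W], weight: [O, I, k, k], styles: [N, I].
-# Non-fused path: scale x by styles, convolve with the shared weight, then rescale
-# by the demodulation coefficients dcoefs: [N, O]. Fused path: fold styles (and
-# dcoefs) into per-sample weights and run one grouped convolution with groups=N.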
-
-
-class Identity(nn.Module):
-
- def __init__(self) -> None:
- super().__init__()
-
- def forward(self, x, *args, **kwargs):
- return x
-
-
-class LayerNorm(nn.Module):
- def __init__(self, dim, stable=False):
- super().__init__()
- self.stable = stable
- self.g = nn.Parameter(torch.ones(dim))
-
- def forward(self, x):
- if self.stable:
- x = x / x.amax(dim=-1, keepdim=True).detach()
-
- eps = 1e-5 if x.dtype == torch.float32 else 1e-3
- var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
- mean = torch.mean(x, dim=-1, keepdim=True)
- return (x - mean) * (var + eps).rsqrt() * self.g
-
-
-class FullyConnectedLayer(torch.nn.Module):
- def __init__(self,
- in_features, # Number of input features.
- out_features, # Number of output features.
- bias = True, # Apply additive bias before the activation function?
- activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
- lr_multiplier = 1, # Learning rate multiplier.
- bias_init = 0, # Initial value for the additive bias.
- ):
- super().__init__()
- self.repr = dict(
- in_features=in_features, out_features=out_features, bias=bias,
- activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
- self.activation = activation
- self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
- self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
- self.weight_gain = lr_multiplier / np.sqrt(in_features)
- self.bias_gain = lr_multiplier
- self.in_features = in_features
- self.out_features = out_features
-
- def forward(self, x):
- w = self.weight * self.weight_gain
- b = self.bias
- if b is not None:
- if self.bias_gain != 1:
- b = b * self.bias_gain
- x = F.linear(x, w)
- x = bias_act.bias_act(x, b, act=self.activation)
- return x
-
- def extra_repr(self) -> str:
- return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
-
-
-
-def checkpoint_fn(fn, *args, **kwargs):
- warnings.simplefilter("ignore")
- return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
-
-class Conv2d(torch.nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size=3,
- activation='lrelu',
- conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
- bias=True,
- norm=None,
- lr_multiplier=1,
- bias_init=0,
- w_dim=None,
- gradient_checkpoint_norm=False,
- gain=1,
- ):
- super().__init__()
- self.fused_modconv = False
- if norm == torch.nn.InstanceNorm2d:
- self.norm = torch.nn.InstanceNorm2d(None)
- elif isinstance(norm, torch.nn.Module):
- self.norm = norm
- elif norm == "fused_modconv":
- self.fused_modconv = True
- elif norm:
- self.norm = torch.nn.InstanceNorm2d(None)
- elif norm is not None and norm is not False:
- raise ValueError(f"norm not supported: {norm}")
- self.activation = activation
- self.conv_clamp = conv_clamp
- self.out_channels = out_channels
- self.in_channels = in_channels
- self.padding = kernel_size // 2
- self.repr = dict(
- in_channels=in_channels, out_channels=out_channels,
- kernel_size=kernel_size,
- activation=activation, conv_clamp=conv_clamp, bias=bias,
- fused_modconv=self.fused_modconv
- )
- self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
- self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
- self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
- self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None
- self.bias_gain = lr_multiplier
- if w_dim is not None:
- if self.fused_modconv:
- self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
- else:
- self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
- self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
- self.gradient_checkpoint_norm = gradient_checkpoint_norm
-
- def forward(self, x, w=None, gain=1, **kwargs):
- if self.fused_modconv:
- styles = self.affine(w)
- with torch.cuda.amp.autocast(enabled=False):
- x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None,
- padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype)
- else:
- if hasattr(self, "affine"):
- gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
- beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
- x = fma.fma(x, gamma, beta)
- w = self.weight * self.weight_gain
- x = F.conv2d(input=x, weight=w, padding=self.padding,)
-
- if hasattr(self, "norm"):
- if self.gradient_checkpoint_norm:
- x = checkpoint_fn(self.norm, x)
- else:
- x = self.norm(x)
- act_gain = self.act_gain * gain
- act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
- b = self.bias * self.bias_gain if self.bias is not None else None
- x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
- return x
-
- def extra_repr(self) -> str:
- return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
-
-
-class CrossAttention(nn.Module):
- def __init__(
- self,
- dim,
- context_dim,
- dim_head=64,
- heads=8,
- norm_context=False,
- ):
- super().__init__()
- self.scale = dim_head ** -0.5
-
- self.heads = heads
- inner_dim = dim_head * heads
-
- self.norm = nn.InstanceNorm1d(dim)
- self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity()
-
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
- self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
-
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, dim, bias=False),
- nn.InstanceNorm1d(None)
- )
-
- def forward(self, x, w):
- x = self.norm(x)
- w = self.norm_context(w)
-
- q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim = -1))
- q = rearrange(q, "b n (h d) -> b h n d", h = self.heads)
- k = rearrange(k, "b n (h d) -> b h n d", h = self.heads)
- v = rearrange(v, "b n (h d) -> b h n d", h = self.heads)
- q = q * self.scale
- # similarities
- sim = einsum('b h i d, b h j d -> b h i j', q, k)
- attn = sim.softmax(dim = -1, dtype = torch.float32)
-
- out = einsum('b h i j, b h j d -> b h i d', attn, v)
- out = rearrange(out, 'b h n d -> b n (h d)')
- return self.to_out(out)
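-
- # Shape sketch (illustrative): x: [B, H*W, dim], w: [B, n, context_dim];
- # q: [B, heads, H*W, dim_head], k/v: [B, heads, n, dim_head];
- # the attended output returned by to_out is [B, H*W, dim].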
-
-
-class SG2ResidualBlock(torch.nn.Module):
- def __init__(
- self,
- in_channels, # Number of input channels, 0 = first block.
- out_channels, # Number of output channels.
- conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
- skip_gain=np.sqrt(.5),
- cross_attention: bool = False,
- cross_attention_len: int = None,
- use_adain: bool = True,
- **layer_kwargs, # Arguments for conv layer.
- ):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None
- if use_adain:
- layer_kwargs["w_dim"] = w_dim
-
- self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs)
- self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain)
-
- self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain)
- if cross_attention and w_dim is not None:
- self.cross_attention_len = cross_attention_len
- self.cross_attn = CrossAttention(
- dim=out_channels, context_dim=w_dim // self.cross_attention_len)
-
- def forward(self, x, w=None, **layer_kwargs):
- y = self.skip(x)
- x = self.conv0(x, w, **layer_kwargs)
- x = self.conv1(x, w, **layer_kwargs)
- if hasattr(self, "cross_attn"):
- h = x.shape[-2]
- x = rearrange(x, "b c h w -> b (h w) c")
- w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len)
- x = self.cross_attn(x, w=w) + x
- x = rearrange(x, "b (h w) c -> b c h w", h=h)
- return y + x
-
-
-def default(val, d):
- if val is not None:
- return val
- return d() if callable(d) else d
-
-
-def cast_tuple(val, length=None):
- if isinstance(val, Iterable) and not isinstance(val, str):
- val = tuple(val)
- output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
- if length is not None:
- assert len(output) == length, (output, length)
- return output
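-
-# cast_tuple examples (illustrative):
-# cast_tuple(True, 3) -> (True, True, True)
-# cast_tuple([1, 2, 3]) -> (1, 2, 3)
-# cast_tuple((1, 2), 3) -> AssertionError (length mismatch)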
-
-
-class Attention(nn.Module):
- # A version of Multi-Query Attention, introduced in
- # "Fast Transformer Decoding: One Write-Head is All You Need".
- # Ablated in https://arxiv.org/pdf/2203.07814.pdf
- # and https://arxiv.org/pdf/2204.02311.pdf
- def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None):
- super().__init__()
- self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0
- self.cosine_sim_attn = cosine_sim_attn
- self.cosine_sim_scale = 16 if cosine_sim_attn else 1
- self.gradient_checkpoint = gradient_checkpoint
- self.heads = heads
- self.dim = dim
- self.fix_attention_again = fix_attention_again
- inner_dim = dim_head * heads
- if norm == "LN":
- self.norm = LayerNorm(dim)
- elif norm == "IN":
- self.norm = nn.InstanceNorm1d(dim)
- elif norm is None:
- self.norm = nn.Identity()
- else:
- raise ValueError(f"Norm not supported: {norm}")
-
- self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False)
- self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False)
-
- self.to_out = nn.Sequential(
- FullyConnectedLayer(inner_dim, dim, bias=False),
- LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim)
- )
- if fix_attention_again:
- assert gain is not None
- self.gain = gain
- else:
- self.gain = np.sqrt(.5) if attn_fix_gain else 1
-
- def run_function(self, x, attn_bias):
- b, c, h, w = x.shape
- x = rearrange(x, "b c h w -> b (h w) c")
- in_ = x
- b, n, device = *x.shape[:2], x.device
- x = self.norm(x)
- q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
-
- q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
- q = q * self.scale
-
- # calculate query / key similarities
- sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
-
- if attn_bias is not None:
- attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
- sim = sim + attn_bias
-
- attn = sim.softmax(dim=-1)
-
- out = einsum("b h i j, b j d -> b h i d", attn, v)
-
- out = rearrange(out, "b h n d -> b n (h d)")
- if self.fix_attention_again:
- out = self.to_out(out)*self.gain + in_
- else:
- out = (self.to_out(out) + in_) * self.gain
- out = rearrange(out, "b (h w) c -> b c h w", h=h)
- return out
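-
- # Multi-query detail: to_kv projects to a single head (dim_head*2 features), so
- # one k/v head is shared by all query heads; the einsum "b h i d, b j d -> b h i j"
- # broadcasts that shared head over the h query heads.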
-
- def forward(self, x, *args, attn_bias=None, **kwargs):
- if self.gradient_checkpoint:
- return checkpoint_fn(self.run_function, x, attn_bias)
- return self.run_function(x, attn_bias)
-
- def get_attention(self, x, attn_bias=None):
- b, c, h, w = x.shape
- x = rearrange(x, "b c h w -> b (h w) c")
- in_ = x
- b, n, device = *x.shape[:2], x.device
- x = self.norm(x)
- q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
-
- q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
- q = q * self.scale
-
- # calculate query / key similarities
- sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
-
- if attn_bias is not None:
- attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
- sim = sim + attn_bias
-
- attn = sim.softmax(dim=-1)
- return attn, None
-
-
-class BiasedAttention(Attention):
-
- def __init__(self, *args, head_wise: bool=True, **kwargs):
- super().__init__(*args, **kwargs)
- out_ch = self.heads if head_wise else 1
- self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0)
- nn.init.zeros_(self.conv.weight.data)
-
- def forward(self, x, mask):
- mask = resize(mask, size=x.shape[-2:])
- bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
- return super().forward(x=x, attn_bias=bias)
-
- def get_attention(self, x, mask):
- mask = resize(mask, size=x.shape[-2:])
- bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
- return super().get_attention(x, bias)[0], bias
-
-class UNet(BaseGenerator):
-
- def __init__(
- self,
- im_channels: int,
- dim: int,
- dim_mults: tuple,
- num_resnet_blocks, # Number of resnet blocks per resolution
- n_middle_blocks: int,
- z_channels: int,
- conv_clamp: int,
- layer_attn,
- w_dim: int,
- norm_enc: bool,
- norm_dec: str,
- stylenet: nn.Module,
- enc_style: bool, # Toggle style injection in encoder
- use_maskrcnn_mask: bool,
- skip_all_unets: bool,
- fix_resize:bool,
- comodulate: bool,
- comod_net: nn.Module,
- lr_comod: float,
- dec_style: bool,
- input_keypoints: bool,
- n_keypoints: int,
- input_keypoint_indices: Tuple[int],
- use_adain: bool,
- cross_attention: bool,
- cross_attention_len: int,
- gradient_checkpoint_norm: bool,
- attn_cls: partial,
- mask_out_train: bool,
- fix_gain_again: bool,
- ) -> None:
- super().__init__(z_channels)
- self.enc_style = enc_style
- self.n_keypoints = n_keypoints
- self.input_keypoint_indices = list(input_keypoint_indices)
- self.input_keypoints = input_keypoints
- self.mask_out_train = mask_out_train
- n_layers = len(dim_mults)
- self.n_layers = n_layers
- layer_attn = cast_tuple(layer_attn, n_layers)
- num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers)
- self._cnum = dim
- self._image_channels = im_channels
- self._z_channels = z_channels
- encoder_layers = []
- condition_ch = im_channels
- self.from_rgb = Conv2d(
- condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices)
- , dim, 7)
-
- self.use_maskrcnn_mask = use_maskrcnn_mask
- self.skip_all_unets = skip_all_unets
- dims = [dim*m for m in dim_mults]
- enc_blk = partial(
- SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc,
- use_adain=use_adain and self.enc_style,
- w_dim=w_dim,
- cross_attention=cross_attention,
- cross_attention_len=cross_attention_len,
- gradient_checkpoint_norm=gradient_checkpoint_norm
- )
- dec_blk = partial(
- SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec,
- use_adain=use_adain and dec_style,
- w_dim=w_dim,
- cross_attention=cross_attention,
- cross_attention_len=cross_attention_len,
- gradient_checkpoint_norm=gradient_checkpoint_norm
- )
- # Currently up/down sampling is done by bilinear upsampling.
- # This can be simplified by replacing it with a strided upsampling layer...
- self.encoder_attns = nn.ModuleList()
- for lidx in range(n_layers):
- gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5)
- dim_in = dims[lidx]
- dim_out = dims[min(lidx+1, n_layers-1)]
- res_blocks = nn.ModuleList()
- for i in range(num_resnet_blocks[lidx]):
- is_last = num_resnet_blocks[lidx] - 1 == i
- cur_dim = dim_out if is_last else dim_in
- block = enc_blk(dim_in, cur_dim, skip_gain=gain)
- res_blocks.append(block)
- if layer_attn[lidx]:
- self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
- else:
- self.encoder_attns.append(Identity())
- encoder_layers.append(res_blocks)
- self.encoder = torch.nn.ModuleList(encoder_layers)
-
- # initialize decoder
- decoder_layers = []
- self.unet_layers = torch.nn.ModuleList()
- self.decoder_attns = torch.nn.ModuleList()
- for lidx in range(n_layers):
- dim_in = dims[min(-lidx, -1)]
- dim_out = dims[-1-lidx]
- res_blocks = nn.ModuleList()
- unet_skips = nn.ModuleList()
- for i in range(num_resnet_blocks[-lidx-1]):
- is_first = i == 0
- has_unet = is_first or skip_all_unets
- is_last = i == num_resnet_blocks[-lidx-1] - 1
- cur_dim = dim_in if is_first else dim_out
- if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again: # x + residual + unet + layer attn
- gain = np.sqrt(1/4)
- elif has_unet: # x + residual + unet
- gain = np.sqrt(1/3)
- elif layer_attn[-lidx-1] and fix_gain_again: # x + residual + attention
- gain = np.sqrt(1/3)
- else: # x + residual
- gain = np.sqrt(1/2) # Only residual block
- block = dec_blk(cur_dim, dim_out, skip_gain=gain)
- res_blocks.append(block)
- if has_unet:
- unet_block = Conv2d(
- cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp,
- norm=nn.InstanceNorm2d(None),
- gradient_checkpoint_norm=gradient_checkpoint_norm,
- gain=gain)
- unet_skips.append(unet_block)
- else:
- unet_skips.append(torch.nn.Identity())
- if layer_attn[-lidx-1]:
- self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
- else:
- self.decoder_attns.append(Identity())
-
- decoder_layers.append(res_blocks)
- self.unet_layers.append(unet_skips)
-
- middle_blocks = []
- for i in range(n_middle_blocks):
- block = dec_blk(dims[-1], dims[-1])
- middle_blocks.append(block)
- if n_middle_blocks != 0:
- self.middle_blocks = Sequential(*middle_blocks)
- self.decoder = torch.nn.ModuleList(decoder_layers)
- self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
- self.stylenet = stylenet
- self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize)
- self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize)
- self.comodulate = comodulate
- if comodulate:
- assert not self.enc_style
- self.to_y = nn.Sequential(
- Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm),
- nn.AdaptiveAvgPool2d(1),
- nn.Flatten(),
- FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod)
- )
- self.comod_net = comod_net
-
-
- def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs):
- if z is None:
- z = self.get_z(condition)
- if w is None:
- w = self.stylenet(z, update_emas=update_emas)
- if self.use_maskrcnn_mask:
- x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
- else:
- x = torch.cat((condition, mask, 1-mask), dim=1)
-
- if self.input_keypoints:
- keypoints = keypoints[:, self.input_keypoint_indices]
- one_hot_pose = spatial_embed_keypoints(keypoints, x)
- x = torch.cat((x, one_hot_pose), dim=1)
- x = self.from_rgb(x)
- x, unet_features = self.forward_enc(x, mask, w)
- x, decoder_features = self.forward_dec(x, mask, w, unet_features)
- x = self.to_rgb(x)
- unmasked = x
- if self.mask_out_train:
- x = mask * condition + (1-mask) * x
- out = dict(img=x, unmasked=unmasked)
- if return_decoder_features:
- out["decoder_features"] = decoder_features
- return out
-
- def forward_enc(self, x, mask, w):
- unet_features = []
- for i, res_blocks in enumerate(self.encoder):
- is_last = i == len(self.encoder) - 1
- for block in res_blocks:
- x = block(x, w=w)
- unet_features.append(x)
- x = self.encoder_attns[i](x, mask=mask)
- if not is_last:
- x = self.downsample(x)
- if self.comodulate:
- y = self.to_y(x)
- y = torch.cat((w, y), dim=-1)
- w = self.comod_net(y)
- return x, unet_features
-
- def forward_dec(self, x, mask, w, unet_features):
- if hasattr(self, "middle_blocks"):
- x = self.middle_blocks(x, w=w)
- features = []
- unet_features = iter(reversed(unet_features))
- for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
- is_last = i == len(self.decoder) - 1
- for skip, block in zip(unet_skip, res_blocks):
- skip_x = next(unet_features)
- if not isinstance(skip, torch.nn.Identity):
- skip_x = skip(skip_x)
- x = x + skip_x
- x = block(x, w=w)
- x = self.decoder_attns[i](x, mask=mask)
- features.append(x)
- if not is_last:
- x = self.upsample(x)
- return x, features
-
- def get_w(self, z, update_emas):
- return self.stylenet(z, update_emas=update_emas)
-
- @torch.no_grad()
- def sample(self, truncation_value, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
-
- def update_w(self, *args, **kwargs):
- self.style_net.update_w(*args, **kwargs)
-
- @property
- def style_net(self):
- return self.stylenet
-
- @torch.no_grad()
- def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- if w_indices is None:
- w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
- w_centers = self.style_net.w_centers[w_indices].to(w.device)
- w = w_centers.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
-
-
-def get_stem_unet_kwargs(cfg):
- if "stem_cfg" in cfg.generator: # If the stem has another stem, recursively apply get_stem_unet_kwargs
- return get_stem_unet_kwargs(cfg.generator.stem_cfg)
- return dict(cfg.generator)
-
-
-class GrowingUnet(BaseGenerator):
-
- def __init__(
- self,
- coarse_stem_cfg: str, # This can be a coarse generator or None
- sr_cfg: str, # Can be a previous progressive u-net, Unet or None
- residual: bool,
- new_dataset: bool, # The "new dataset" pipeline builds the condition image at full resolution first, then resizes it for the stem
- **unet_kwargs):
- kwargs = dict()
- if coarse_stem_cfg is not None:
- coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
- kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
- if sr_cfg is not None:
- sr_cfg = utils.load_config(sr_cfg)
- sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg)
- kwargs.update(sr_stem_unet_kwargs)
- kwargs.update(unet_kwargs)
- kwargs["stylenet"] = None
- kwargs.pop("_target_")
- if "sr_cfg" in kwargs: # Unet kwargs are inherited, do not pass this to the new u-net
- del kwargs["sr_cfg"]
- if "coarse_stem_cfg" in kwargs:
- del kwargs["coarse_stem_cfg"]
- super().__init__(z_channels=kwargs["z_channels"])
- if coarse_stem_cfg is not None:
- z_channels = coarse_stem_cfg.generator.z_channels
- super().__init__(z_channels)
- self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
- self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
- utils.set_requires_grad(self.coarse_stem, False)
- else:
- assert not residual
-
- if sr_cfg is not None:
- self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval()
- del self.sr_stem.from_rgb
- del self.sr_stem.to_rgb
- if hasattr(self.sr_stem, "coarse_stem"):
- del self.sr_stem.coarse_stem
- if isinstance(self.sr_stem, UNet):
- del self.sr_stem.encoder[0][0] # Delete first residual block
- del self.sr_stem.decoder[-1][-1] # Delete last residual block
- else:
- assert isinstance(self.sr_stem, GrowingUnet)
- del self.sr_stem.unet.encoder[0][0] # Delete first residual block
- del self.sr_stem.unet.decoder[-1][-1] # Delete last residual block
- utils.set_requires_grad(self.sr_stem, False)
-
-
- args = kwargs.pop("_args_")
- if hasattr(self, "sr_stem"): # Growing the SR stem - Add a new layer to match sr
- n_layers = len(kwargs["dim_mults"])
- dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"]))
- kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)]
- kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False]
- kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1]
- self.unet = UNet(
- *args,
- **kwargs
- )
- self.from_rgb = self.unet.from_rgb
- self.to_rgb = self.unet.to_rgb
- self.residual = residual
- self.new_dataset = new_dataset
- if residual:
- nn.init.zeros_(self.to_rgb.weight.data)
- del self.unet.from_rgb, self.unet.to_rgb
-
- def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs):
- # Downsample for stem
- if z is None:
- z = self.get_z(img)
- if w is None:
- w = self.style_net(z)
- if hasattr(self, "coarse_stem"):
- with torch.no_grad():
- if self.new_dataset:
- img_stem = utils.denormalize_img(img)*255
- condition_stem = img_stem * mask + (1-mask)*127
- condition_stem = condition_stem.round()
- condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
- condition_stem = condition_stem / 255 *2 - 1
- mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
- maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
- else:
- mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float()
- maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float()
- img_stem = utils.denormalize_img(img)*255
- img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round()
- img_stem = img_stem / 255 * 2 - 1
- condition_stem = img_stem * mask_stem
- stem_out = self.coarse_stem(
- condition=condition_stem, mask=mask_stem,
- maskrcnn_mask=maskrcnn_stem, w=w,
- keypoints=keypoints)
- x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
- condition = condition*mask + (1-mask) * x_lr
- if self.unet.use_maskrcnn_mask:
- x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
- else:
- x = torch.cat((condition, mask, 1-mask), dim=1)
- if self.unet.input_keypoints:
- keypoints = keypoints[:, self.unet.input_keypoint_indices]
- one_hot_pose = spatial_embed_keypoints(keypoints, x)
- x = torch.cat((x, one_hot_pose), dim=1)
- x = self.from_rgb(x)
- x, unet_features = self.forward_enc(x, mask, w)
- x = self.forward_dec(x, mask, w, unet_features)
- if self.residual:
- x = self.to_rgb(x) + condition
- else:
- x = self.to_rgb(x)
- return dict(
- img=condition * mask + (1-mask) * x,
- unmasked=x,
- x_lowres=[condition]
- )
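-
- # Cascade sketch: the frozen coarse stem runs on a condition resized to its own
- # (lower) resolution, its output is upsampled and pasted into the masked region,
- # and the refinement U-Net then inpaints at full resolution (optionally adding
- # its output as a residual on top of the stem-filled condition).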
-
- def forward_enc(self, x, mask, w):
- x, unet_features = self.unet.forward_enc(x, mask, w)
- if hasattr(self, "sr_stem"):
- x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w)
- else:
- unet_features_stem = None
- return x, [unet_features, unet_features_stem]
-
- def forward_dec(self, x, mask, w, unet_features):
- unet_features, unet_features_stem = unet_features
- if hasattr(self, "sr_stem"):
- x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem)
- x, unet_features = self.unet.forward_dec(x, mask, w, unet_features)
- return x
-
- def get_z(self, *args, **kwargs):
- if hasattr(self, "coarse_stem"):
- return self.coarse_stem.get_z(*args, **kwargs)
- if hasattr(self, "sr_stem"):
- return self.sr_stem.get_z(*args, **kwargs)
- raise AttributeError()
-
- @property
- def style_net(self):
- if hasattr(self, "coarse_stem"):
- return self.coarse_stem.style_net
- if hasattr(self, "sr_stem"):
- return self.sr_stem.style_net
- raise AttributeError()
-
- def update_w(self, *args, **kwargs):
- self.style_net.update_w(*args, **kwargs)
-
- def get_w(self, z, update_emas):
- return self.style_net(z, update_emas=update_emas)
-
- @torch.no_grad()
- def sample(self, truncation_value, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
-
- @torch.no_grad()
- def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- if w_indices is None:
- w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
- w_centers = self.style_net.w_centers[w_indices].to(w.device)
- w = w_centers.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
-
-
-class CascadedUnet(BaseGenerator):
-
- def __init__(
- self,
- coarse_stem_cfg: str, # This can be a coarse generator or None
- residual: bool,
- new_dataset: bool, # The "new dataset" pipeline builds the condition image at full resolution first, then resizes it for the stem
- imsize: tuple,
- cascade:bool,
- **unet_kwargs):
- kwargs = dict()
- coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
- kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
- kwargs.update(unet_kwargs)
- super().__init__(z_channels=kwargs["z_channels"])
-
- self.input_keypoints = kwargs["input_keypoints"]
- self.input_keypoint_indices = kwargs["input_keypoint_indices"]
- self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"]
- self.imsize = imsize
- self.residual = residual
- self.new_dataset = new_dataset
-
-
- # Setup coarse stem
- stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults]
- self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
- self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
- utils.set_requires_grad(self.coarse_stem, False)
-
- self.stem_res_to_layer_idx = {
- self.coarse_stem.imsize[0] // 2**i: i
- for i in range(len(stem_dims))
- }
-
- dim = kwargs["dim"]
- dim_mults = kwargs["dim_mults"]
- n_layers = len(dim_mults)
- dims = [dim*s for s in dim_mults]
- layer_attn = cast_tuple(kwargs["layer_attn"], n_layers)
- num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers)
- attn_cls = kwargs["attn_cls"]
- if not isinstance(attn_cls, partial):
- attn_cls = instantiate(attn_cls)
-
- dec_blk = partial(
- SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"],
- use_adain=kwargs["use_adain"] and kwargs["dec_style"],
- w_dim=kwargs["w_dim"],
- cross_attention=kwargs["cross_attention"],
- cross_attention_len=kwargs["cross_attention_len"],
- gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
- )
- enc_blk = partial(
- SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"],
- use_adain=kwargs["use_adain"] and kwargs["enc_style"],
- w_dim=kwargs["w_dim"],
- cross_attention=kwargs["cross_attention"],
- cross_attention_len=kwargs["cross_attention_len"],
- gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
- )
-
- # Currently up/down sampling is done by bilinear upsampling.
- # This can be simplified by replacing it with a strided upsampling layer...
- self.encoder_attns = nn.ModuleList()
- self.encoder_unet_skips = nn.ModuleDict()
- self.encoder = nn.ModuleList()
- for lidx in range(n_layers):
- has_stem_feature = imsize[0]//2**lidx in self.stem_res_to_layer_idx and cascade
- next_layer_has_stem_features = lidx+1 < n_layers and imsize[0]//2**(lidx+1) in self.stem_res_to_layer_idx and cascade
-
- dim_in = dims[lidx]
- dim_out = dims[min(lidx+1, n_layers-1)]
- res_blocks = nn.ModuleList()
- if has_stem_feature:
- prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1]
- stem_lidx = self.stem_res_to_layer_idx[imsize[0]//2**lidx]
- self.encoder_unet_skips.add_module(
- str(imsize[0]//2**lidx),
- Conv2d(
- stem_dims[stem_lidx], dim_in, kernel_size=1,
- conv_clamp=kwargs["conv_clamp"],
- norm=nn.InstanceNorm2d(None),
- gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
- gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3) # This + previous residual + attention
- )
- )
- for i in range(num_resnet_blocks[lidx]):
- is_last = num_resnet_blocks[lidx] - 1 == i
- cur_dim = dim_out if is_last else dim_in
- if not is_last:
- gain = np.sqrt(.5)
- elif next_layer_has_stem_features and layer_attn[lidx]:
- gain = np.sqrt(1/4)
- elif layer_attn[lidx] or next_layer_has_stem_features:
- gain = np.sqrt(1/3)
- else:
- gain = np.sqrt(.5)
- block = enc_blk(dim_in, cur_dim, skip_gain=gain)
- res_blocks.append(block)
- if layer_attn[lidx]:
- self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True))
- else:
- self.encoder_attns.append(Identity())
- self.encoder.append(res_blocks)
-
- # initialize decoder
- self.decoder = torch.nn.ModuleList()
- self.unet_layers = torch.nn.ModuleList()
- self.decoder_attns = torch.nn.ModuleList()
- for lidx in range(n_layers):
- dim_in = dims[min(-lidx, -1)]
- dim_out = dims[-1-lidx]
- res_blocks = nn.ModuleList()
- unet_skips = nn.ModuleList()
- for i in range(num_resnet_blocks[-lidx-1]):
- is_first = i == 0
- has_unet = is_first or kwargs["skip_all_unets"]
- is_last = i == num_resnet_blocks[-lidx-1] - 1
- cur_dim = dim_in if is_first else dim_out
- if has_unet and is_last and layer_attn[-lidx-1]: # x + residual + unet + layer attn
- gain = np.sqrt(1/4)
- elif has_unet: # x + residual + unet
- gain = np.sqrt(1/3)
- elif layer_attn[-lidx-1]: # x + residual + attention
- gain = np.sqrt(1/3)
- else: # x + residual
- gain = np.sqrt(1/2) # Only residual block
- block = dec_blk(cur_dim, dim_out, skip_gain=gain)
- res_blocks.append(block)
- if kwargs["skip_all_unets"] or is_first:
- unet_block = Conv2d(
- cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"],
- norm=nn.InstanceNorm2d(None),
- gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
- gain=gain)
- unet_skips.append(unet_block)
- else:
- unet_skips.append(torch.nn.Identity())
- if layer_attn[-lidx-1]:
- self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain))
- else:
- self.decoder_attns.append(Identity())
-
- self.decoder.append(res_blocks)
- self.unet_layers.append(unet_skips)
-
- self.from_rgb = Conv2d(
- 3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"])
- , dim, 7)
- self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"])
-
- self.downsample = Upfirdn2d(down=2, fix_gain=True)
- self.upsample = Upfirdn2d(up=2, fix_gain=True)
- self.cascade = cascade
- if residual:
- nn.init.zeros_(self.to_rgb.weight.data)
-
- def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, return_decoder_features=False, **kwargs):
- # Downsample for stem
- if z is None:
- z = self.get_z(img)
-
- with torch.no_grad(): # Forward pass stem
- if w is None:
- w = self.style_net(z)
- img_stem = utils.denormalize_img(img)*255
- condition_stem = img_stem * mask + (1-mask)*127
- condition_stem = condition_stem.round()
- condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
- condition_stem = condition_stem / 255 *2 - 1
- mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
- maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
- stem_out = self.coarse_stem(
- condition=condition_stem, mask=mask_stem,
- maskrcnn_mask=maskrcnn_stem, w=w,
- keypoints=keypoints,
- return_decoder_features=True)
- stem_features = stem_out["decoder_features"]
- x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
- condition = condition*mask + (1-mask) * x_lr
-
- if self.use_maskrcnn_mask:
- x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
- else:
- x = torch.cat((condition, mask, 1-mask), dim=1)
- if self.input_keypoints:
- keypoints = keypoints[:, self.input_keypoint_indices]
- one_hot_pose = spatial_embed_keypoints(keypoints, x)
- x = torch.cat((x, one_hot_pose), dim=1)
- x = self.from_rgb(x)
- x, unet_features = self.forward_enc(x, mask, w, stem_features)
- x, decoder_features = self.forward_dec(x, mask, w, unet_features)
- if self.residual:
- x = self.to_rgb(x) + condition
- else:
- x = self.to_rgb(x)
- out = dict(
- img=condition * mask + (1-mask) * x, # TODO: Probably do not want masked here... or ??
- unmasked=x,
- x_lowres=[condition]
- )
- if return_decoder_features:
- out["decoder_features"] = decoder_features
- return out
-
- def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]):
- unet_features = []
- stem_features.reverse()
- for i, res_blocks in enumerate(self.encoder):
- is_last = i == len(self.encoder) - 1
- res = self.imsize[0]//2**i
- if str(res) in self.encoder_unet_skips.keys() and self.cascade:
- y = stem_features[self.stem_res_to_layer_idx[res]]
- y = self.encoder_unet_skips[str(res)](y)
- x = y + x
- for block in res_blocks:
- x = block(x, w=w)
- unet_features.append(x)
- x = self.encoder_attns[i](x, mask)
- if not is_last:
- x = self.downsample(x)
- return x, unet_features
-
- def forward_dec(self, x, mask, w, unet_features):
- features = []
- unet_features = iter(reversed(unet_features))
- for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
- is_last = i == len(self.decoder) - 1
- for skip, block in zip(unet_skip, res_blocks):
- skip_x = next(unet_features)
- if not isinstance(skip, torch.nn.Identity):
- skip_x = skip(skip_x)
- x = x + skip_x
- x = block(x, w=w)
- x = self.decoder_attns[i](x, mask)
- features.append(x)
- if not is_last:
- x = self.upsample(x)
- return x, features
-
- def get_z(self, *args, **kwargs):
- return self.coarse_stem.get_z(*args, **kwargs)
-
- @property
- def style_net(self):
- return self.coarse_stem.style_net
-
- def update_w(self, *args, **kwargs):
- self.style_net.update_w(*args, **kwargs)
-
- def get_w(self, z, update_emas):
- return self.style_net(z, update_emas=update_emas)
-
- @torch.no_grad()
- def sample(self, truncation_value, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
-
- @torch.no_grad()
- def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
- if truncation_value is None:
- return self.forward(**kwargs)
- truncation_value = max(0, truncation_value)
- truncation_value = min(truncation_value, 1)
- w = self.get_w(self.get_z(kwargs["condition"]), False)
- if w_indices is None:
- w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
- w_centers = self.style_net.w_centers[w_indices].to(w.device)
- w = w_centers.to(w.dtype).lerp(w, truncation_value)
- return self.forward(**kwargs, w=w)
diff --git a/dp2/generator/stylegan_unet.py b/dp2/generator/stylegan_unet.py
deleted file mode 100644
index 68c7f7706d601e4ae2eb80a4b3fd03bc127b2164..0000000000000000000000000000000000000000
--- a/dp2/generator/stylegan_unet.py
+++ /dev/null
@@ -1,208 +0,0 @@
-import torch
-import numpy as np
-from dp2.layers import Sequential
-from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock
-from .base import BaseStyleGAN
-from typing import List, Tuple
-from .utils import spatial_embed_keypoints, mask_output
-
-
-def get_chsize(imsize, cnum, max_imsize, max_cnum_mul):
- n = int(np.log2(max_imsize) - np.log2(imsize))
- mul = min(2**n, max_cnum_mul)
- ch = cnum * mul
- return int(ch)
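-
-# Worked example (illustrative numbers): with cnum=64, max_imsize=256, max_cnum_mul=8,
-# get_chsize(256, ...) = 64, get_chsize(64, ...) = 256, and resolutions of 32 and
-# below saturate at 64*8 = 512 channels.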
-
-class StyleGANUnet(BaseStyleGAN):
- def __init__(
- self,
- scale_grad: bool,
- im_channels: int,
- min_fmap_resolution: int,
- imsize: List[int],
- cnum: int,
- max_cnum_mul: int,
- mask_output: bool,
- conv_clamp: int,
- input_cse: bool,
- cse_nc: int,
- n_middle_blocks: int,
- input_keypoints: bool,
- n_keypoints: int,
- input_keypoint_indices: Tuple[int],
- fix_errors: bool,
- **kwargs
- ) -> None:
- super().__init__(**kwargs)
- self.n_keypoints = n_keypoints
- self.input_keypoint_indices = list(input_keypoint_indices)
- self.input_keypoints = input_keypoints
- assert not (input_cse and input_keypoints)
- cse_nc = 0 if cse_nc is None else cse_nc
- self.imsize = imsize
- self._cnum = cnum
- self._max_cnum_mul = max_cnum_mul
- self._min_fmap_resolution = min_fmap_resolution
- self._image_channels = im_channels
- self._max_imsize = max(imsize)
- self.input_cse = input_cse
- self.gain_unet = np.sqrt(1/3)
- n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
- encoder_layers = []
- self.from_rgb = Conv2d(
- im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices),
- cnum, 1
- )
- for i in range(n_levels): # Encoder layers
- resolution = [x//2**i for x in imsize]
- in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
- second_ch = in_ch
- out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
- down = 2
-
- if i == 0: # first (lowest) block. Downsampling is performed at the start of the block
- down = 1
- if i == n_levels - 1:
- out_ch = second_ch
- block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors)
- encoder_layers.append(block)
- self._encoder_out_shape = [
- get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul),
- *resolution]
-
- self.encoder = torch.nn.ModuleList(encoder_layers)
-
- # initialize decoder
- decoder_layers = []
- for i in range(n_levels):
- resolution = [x//2**(n_levels-1-i) for x in imsize]
- in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
- out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
- if i == 0: # first (lowest) block
- in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
-
- up = 1
- if i != n_levels - 1:
- up = 2
- block = ResidualBlock(
- in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3),
- w_dim=self.style_net.w_dim, norm=True, up=up,
- fix_residual=fix_errors
- )
- decoder_layers.append(block)
- if i != 0:
- unet_block = Conv2d(
- in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True,
- gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5))
- setattr(self, f"unet_block{i}", unet_block)
-
- # Initialize "middle blocks" that do not have down/up sample
- middle_blocks = []
- for i in range(n_middle_blocks):
- ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul)
- block = ResidualBlock(
- ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3),
- w_dim=self.style_net.w_dim, norm=True,
- )
- middle_blocks.append(block)
- if n_middle_blocks != 0:
- self.middle_blocks = Sequential(*middle_blocks)
- self.decoder = torch.nn.ModuleList(decoder_layers)
- self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
- self.scale_grad = scale_grad
- self.mask_output = mask_output
-
- def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs):
- for i, layer in enumerate(self.decoder):
- if i != 0:
- unet_layer = getattr(self, f"unet_block{i}")
- x = x + unet_layer(unet_features[-i])
- x = layer(x, w=w, s=s)
- x = self.to_rgb(x)
- if self.mask_output:
- x = mask_output(True, condition, x, mask)
- return dict(img=x)
-
- def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs):
- if self.input_cse:
- x = torch.cat((condition, mask, embedding, E_mask), dim=1)
- else:
- x = torch.cat((condition, mask), dim=1)
- if self.input_keypoints:
- keypoints = keypoints[:, self.input_keypoint_indices]
- one_hot_pose = spatial_embed_keypoints(keypoints, x)
- x = torch.cat((x, one_hot_pose), dim=1)
- x = self.from_rgb(x)
-
- unet_features = []
- for i, layer in enumerate(self.encoder):
- x = layer(x)
- if i != len(self.encoder)-1:
- unet_features.append(x)
- if hasattr(self, "middle_blocks"):
- for layer in self.middle_blocks:
- x = layer(x)
- return x, unet_features
-
- def forward(
- self, condition, mask,
- z=None, embedding=None, w=None, update_emas=False, x=None,
- s=None,
- keypoints=None,
- unet_features=None,
- E_mask=None,
- **kwargs):
- # Used to skip sampling from encoder in inference. E.g. for w projection.
- if x is not None and unet_features is not None:
- assert not self.training
- else:
- x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs)
- if w is None:
- if z is None:
- z = self.get_z(condition)
- w = self.get_w(z, update_emas=update_emas)
- return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs)
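-
- # Call sketch (hypothetical shapes for a CSE-conditioned model):
- # condition: [B, 3, H, W], mask/E_mask: [B, 1, H, W], embedding: [B, cse_nc, H, W];
- # forward() returns {"img": [B, 3, H, W]}, with the known region copied from
- # condition when mask_output is enabled.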
-
-class ComodStyleUNet(StyleGANUnet):
-
- def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None:
- super().__init__(**kwargs)
- min_fmap = min(self._encoder_out_shape[1:])
- enc_out_ch = self._encoder_out_shape[0]
- n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res)))
- comod_layers = []
- in_ch = enc_out_ch
- for i in range(n_down):
- comod_layers.append(Conv2d(in_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod))
- in_ch = 256
- if n_down == 0:
- comod_layers = [Conv2d(in_ch, 256, kernel_size=3)]
- comod_layers.append(torch.nn.Flatten())
- out_res = [x//2**n_down for x in self._encoder_out_shape[1:]]
- in_ch_fc = np.prod(out_res) * 256
- comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod))
- self.comod_block = Sequential(*comod_layers)
- self.comod_fc = FullyConnectedLayer(512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod)
-
- def forward_dec(self, x, w, unet_features, condition, mask, **kwargs):
- y = self.comod_block(x)
- y = torch.cat((y, w), dim=1)
- y = self.comod_fc(y)
- for i, layer in enumerate(self.decoder):
- if i != 0:
- unet_layer = getattr(self, f"unet_block{i}")
- x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5))
- x = layer(x, w=y)
- x = self.to_rgb(x)
- if self.mask_output:
- x = mask_output(True, condition, x, mask)
- return dict(img=x)
-
- def get_comod_y(self, batch, w):
- x, unet_features = self.forward_enc(**batch)
- y = self.comod_block(x)
- y = torch.cat((y, w), dim=1)
- y = self.comod_fc(y)
- return y
diff --git a/dp2/generator/utils.py b/dp2/generator/utils.py
deleted file mode 100644
index 5732b2c511a42f4bffd4b512244cf790bec96ef0..0000000000000000000000000000000000000000
--- a/dp2/generator/utils.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import torch
-import tops
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-
-@torch.no_grad()
-def spatial_embed_keypoints(keypoints: torch.Tensor, x):
- tops.assert_shape(keypoints, (None, None, 3))
- B, N_K, _ = keypoints.shape
- H, W = x.shape[-2:]
- x, y, visible = keypoints.chunk(3, dim=2)
- x = (x * W).round().long().clamp(0, W-1)
- y = (y * H).round().long().clamp(0, H-1)
- kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
- pos = (kp_idx*(H*W) + y*W + x + 1)
- # Offset all by 1 to index invisible keypoints as 0
- pos = (pos * visible.round().long()).squeeze(dim=-1)
- keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
- keypoint_spatial.scatter_(1, pos, 1)
- keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
- return keypoint_spatial
-
-class MaskOutput(torch.autograd.Function):
-
- @staticmethod
- @custom_fwd
- def forward(ctx, x_real, x_fake, mask):
- ctx.save_for_backward(mask)
- out = x_real * mask + (1-mask) * x_fake
- return out
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
- fake_grad = grad_output
- mask, = ctx.saved_tensors
- fake_grad = fake_grad * (1 - mask)
- known_percentage = mask.view(mask.shape[0], -1).mean(dim=1)
- fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1)
- return None, fake_grad, None
-
-
-def mask_output(scale_grad, x_real, x_fake, mask):
- if scale_grad:
- return MaskOutput.apply(x_real, x_fake, mask)
- return x_real * mask + (1-mask) * x_fake
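-
-# MaskOutput rationale: the backward pass zeroes gradients over the known region and
-# divides by the unknown-pixel fraction (1 - mean(mask)), so the per-pixel gradient
-# magnitude in the generated region does not shrink when only a small part of the
-# image is actually generated.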
diff --git a/dp2/infer.py b/dp2/infer.py
deleted file mode 100644
index ddbac8dbfd0914979575a0799c10b4eeb00aff73..0000000000000000000000000000000000000000
--- a/dp2/infer.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import tops
-import torch
-from tops import checkpointer
-from tops.config import instantiate
-from tops.logger import warn
-
-def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
- state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"]
- if ckpt_mapper is not None:
- state = ckpt_mapper(state)
- load_state_dict(G, state)
- tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
- print(ckpt.keys())
- if "w_centers" in ckpt:
- print("Has w_centers!")
- G.style_net.w_centers = ckpt["w_centers"]
- tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
-
-
-def build_trained_generator(cfg, map_location=None):
- map_location = map_location if map_location is not None else tops.get_device()
- G = instantiate(cfg.generator).to(map_location)
- G.eval()
- G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
- if hasattr(cfg, "ckpt_mapper"):
- ckpt_mapper = instantiate(cfg.ckpt_mapper)
- else:
- ckpt_mapper = None
- if "model_url" in cfg.common:
- ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum, map_location=torch.device("cpu"))
- load_generator_state(ckpt, G, ckpt_mapper)
- return G
- try:
- ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
- load_generator_state(ckpt, G, ckpt_mapper)
- except FileNotFoundError as e:
- tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
- return G
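-
-# Typical use (sketch; the import and config path are illustrative):
-# from dp2 import utils
-# cfg = utils.load_config("configs/my_model.py") # hypothetical config
-# G = build_trained_generator(cfg, map_location="cpu")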
-
-
-def build_trained_discriminator(cfg, map_location=None):
- map_location = map_location if map_location is not None else tops.get_device()
- D = instantiate(cfg.discriminator).to(map_location)
- D.eval()
- try:
- ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
- if hasattr(cfg, "ckpt_mapper_D"):
- ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"])
- D.load_state_dict(ckpt["discriminator"])
- except FileNotFoundError as e:
- tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
- return D
-
-
-def load_state_dict(module: torch.nn.Module, state_dict: dict):
- module_sd = module.state_dict()
- to_remove = []
- for key, item in state_dict.items():
- if key not in module_sd:
- continue
- if item.shape != module_sd[key].shape:
- to_remove.append(key)
- warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
- for key in to_remove:
- state_dict.pop(key)
- for key, item in state_dict.items():
- if key not in module_sd:
- warn(f"Did not find key in model state dict: {key}")
- for key, item in module_sd.items():
- if key not in state_dict:
- warn(f"Did not find key in state dict: {key}")
- module.load_state_dict(state_dict, strict=False)
\ No newline at end of file
diff --git a/dp2/layers/__init__.py b/dp2/layers/__init__.py
deleted file mode 100644
index 4f54bed21f71c7facaa8129e1b689f16c5f9d13d..0000000000000000000000000000000000000000
--- a/dp2/layers/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from typing import Dict
-import torch
-import tops
-import torch.nn as nn
-
-class Sequential(nn.Sequential):
-
- def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
- for module in self:
- x = module(x, **kwargs)
- return x
-
-class Module(nn.Module):
-
- def __init__(self, *args, **kwargs):
- super().__init__()
-
- def extra_repr(self):
- num_params = tops.num_parameters(self) / 10**6
- return f"Num params: {num_params:.3f}M"
diff --git a/dp2/layers/sg2_layers.py b/dp2/layers/sg2_layers.py
deleted file mode 100644
index 3aac03935d0cb6ec172f2cd0aabefbdc74d0141e..0000000000000000000000000000000000000000
--- a/dp2/layers/sg2_layers.py
+++ /dev/null
@@ -1,227 +0,0 @@
-from typing import List
-import numpy as np
-import torch
-import tops
-import torch.nn.functional as F
-from sg3_torch_utils.ops import conv2d_resample
-from sg3_torch_utils.ops import upfirdn2d
-from sg3_torch_utils.ops import bias_act
-from sg3_torch_utils.ops.fma import fma
-
-
-class FullyConnectedLayer(torch.nn.Module):
- def __init__(self,
- in_features, # Number of input features.
- out_features, # Number of output features.
- bias = True, # Apply additive bias before the activation function?
- activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
- lr_multiplier = 1, # Learning rate multiplier.
- bias_init = 0, # Initial value for the additive bias.
- ):
- super().__init__()
- self.repr = dict(
- in_features=in_features, out_features=out_features, bias=bias,
- activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
- self.activation = activation
- self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
- self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
- self.weight_gain = lr_multiplier / np.sqrt(in_features)
- self.bias_gain = lr_multiplier
- self.in_features = in_features
- self.out_features = out_features
-
- def forward(self, x):
- w = self.weight * self.weight_gain
- b = self.bias
- if b is not None and self.bias_gain != 1:
- b = b * self.bias_gain
- x = F.linear(x, w)
- x = bias_act.bias_act(x, b, act=self.activation)
- return x
-
- def extra_repr(self) -> str:
- return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
-
-
-class Conv2d(torch.nn.Module):
- def __init__(self,
- in_channels, # Number of input channels.
- out_channels, # Number of output channels.
- kernel_size = 3, # Convolution kernel size.
- up = 1, # Integer upsampling factor.
- down = 1, # Integer downsampling factor
- activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
- resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
- bias = True,
- norm = False,
- lr_multiplier=1,
- bias_init=0,
- w_dim=None,
- gain=1,
- ):
- super().__init__()
- if norm:
- self.norm = torch.nn.InstanceNorm2d(None)
- assert norm in [True, False]
- self.up = up
- self.down = down
- self.activation = activation
- self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
- self.out_channels = out_channels
- self.in_channels = in_channels
- self.padding = kernel_size // 2
-
- self.repr = dict(
- in_channels=in_channels, out_channels=out_channels,
- kernel_size=kernel_size, up=up, down=down,
- activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
- )
-
- if self.up == 1 and self.down == 1:
- self.resample_filter = None
- else:
- self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
-
- self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
- self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
- self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
- self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None
- self.bias_gain = lr_multiplier
- if w_dim is not None:
- self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
- self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
-
- def forward(self, x, w=None, s=None):
- tops.assert_shape(x, [None, self.weight.shape[1], None, None])
- if s is not None:
- s = s[..., :self.in_channels*2]
- gamma, beta = s.view(-1, self.in_channels*2, 1, 1).chunk(2, dim=1)
- x = fma(x, gamma, beta)
- elif hasattr(self, "affine"):
- gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
- beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
- x = fma(x, gamma, beta)
- w = self.weight * self.weight_gain
- # Removing flip weight is not safe.
- x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up, self.down, self.padding, flip_weight=self.up==1)
- if hasattr(self, "norm"):
- x = self.norm(x)
- b = self.bias * self.bias_gain if self.bias is not None else None
- x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
- return x
-
- def extra_repr(self) -> str:
- return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
-
-
-class Block(torch.nn.Module):
- def __init__(self,
- in_channels, # Number of input channels, 0 = first block.
- out_channels, # Number of output channels.
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
- up = 1,
- down = 1,
- **layer_kwargs, # Arguments for SynthesisLayer.
- ):
- super().__init__()
- self.in_channels = in_channels
- self.down = down
- self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
- self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)
-
- def forward(self, x, **layer_kwargs):
- x = self.conv0(x, **layer_kwargs)
- x = self.conv1(x, **layer_kwargs)
- return x
-
-
-class ResidualBlock(torch.nn.Module):
- def __init__(self,
- in_channels, # Number of input channels, 0 = first block.
- out_channels, # Number of output channels.
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
- up = 1,
- down = 1,
- gain_out=np.sqrt(0.5),
- fix_residual: bool = False,
- **layer_kwargs, # Arguments for conv layer.
- ):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.down = down
- self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
-
- self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out,**layer_kwargs)
-
- self.skip = Conv2d(
- in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
- activation="linear" if fix_residual else "lrelu",
- gain=gain_out
- )
- self.gain_out = gain_out
-
- def forward(self, x, w=None, s=None, **layer_kwargs):
- y = self.skip(x)
- s_ = next(s) if s is not None else None
- x = self.conv0(x, w, s=s_, **layer_kwargs)
- s_ = next(s) if s is not None else None
- x = self.conv1(x, w, s=s_, **layer_kwargs)
- x = y + x
- return x
-
-
-class MinibatchStdLayer(torch.nn.Module):
- def __init__(self, group_size, num_channels=1):
- super().__init__()
- self.group_size = group_size
- self.num_channels = num_channels
-
- def forward(self, x):
- N, C, H, W = x.shape
- with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants
- G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
- F = self.num_channels
- c = C // F
-
- y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
- y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
- y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
- y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
- y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
- y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
- y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
- x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
- return x
-
-#----------------------------------------------------------------------------
-
-class DiscriminatorEpilogue(torch.nn.Module):
- def __init__(self,
- in_channels, # Number of input channels.
- resolution: List[int], # Resolution of this block.
- mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
- mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
- activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
- ):
- super().__init__()
- self.in_channels = in_channels
- self.resolution = resolution
- self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
- self.conv = Conv2d(
- in_channels + mbstd_num_channels, in_channels,
- kernel_size=3, activation=activation, conv_clamp=conv_clamp)
- self.fc = FullyConnectedLayer(in_channels * resolution[0]*resolution[1], in_channels, activation=activation)
- self.out = FullyConnectedLayer(in_channels, 1)
-
- def forward(self, x):
- tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW]
- # Main layers.
- if self.mbstd is not None:
- x = self.mbstd(x)
- x = self.conv(x)
- x = self.fc(x.flatten(1))
- x = self.out(x)
- return x
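To make the minibatch-std trick above concrete, here is a small standalone sketch (plain PyTorch, assuming the group size divides the batch) that reproduces the group-wise standard-deviation feature and shows the extra channel it appends:

# Sketch: minibatch standard deviation as an extra feature map (assumption: group_size divides N).
import torch

def minibatch_std(x: torch.Tensor, group_size: int = 4, num_channels: int = 1) -> torch.Tensor:
    N, C, H, W = x.shape
    G = min(group_size, N)
    F_ = num_channels
    c = C // F_
    y = x.reshape(G, -1, F_, c, H, W)                 # split the batch into groups of size G
    y = y - y.mean(dim=0)                             # subtract the per-group mean
    y = y.square().mean(dim=0).add(1e-8).sqrt()       # stddev over each group
    y = y.mean(dim=[2, 3, 4]).reshape(-1, F_, 1, 1)   # average over channels and pixels
    y = y.repeat(G, 1, H, W)                          # broadcast back to every sample
    return torch.cat([x, y], dim=1)                   # append as new channels

x = torch.randn(8, 16, 4, 4)
print(minibatch_std(x).shape)  # torch.Size([8, 17, 4, 4])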
diff --git a/dp2/loss/__init__.py b/dp2/loss/__init__.py
deleted file mode 100644
index 79165e22915aeeebabbcddebcccb281deedb40c2..0000000000000000000000000000000000000000
--- a/dp2/loss/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .sg2_loss import StyleGAN2Loss
\ No newline at end of file
diff --git a/dp2/loss/pl_regularization.py b/dp2/loss/pl_regularization.py
deleted file mode 100644
index 3a557a8b92cbcb2572e8f25633e7c89e996ef906..0000000000000000000000000000000000000000
--- a/dp2/loss/pl_regularization.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import torch
-import tops
-import numpy as np
-from sg3_torch_utils.ops import conv2d_gradfix
-
-pl_mean_total = torch.zeros([])
-
-class PLRegularization:
-
- def __init__(self, weight: float, batch_shrink: int, pl_decay:float, scale_by_mask: bool,**kwargs):
- self.pl_mean = torch.zeros([], device=tops.get_device())
- self.pl_weight = weight
- self.batch_shrink = batch_shrink
- self.pl_decay = pl_decay
- self.scale_by_mask = scale_by_mask
-
- def __call__(self, G, batch, grad_scaler):
- batch_size = batch["img"].shape[0] // self.batch_shrink
- # Shrink per-sample tensors, but keep the full embed_map (it is shared across the batch).
- embed_map = batch.get("embed_map")
- batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
- if embed_map is not None:
- batch["embed_map"] = embed_map
- z = G.get_z(batch["img"])
-
- with torch.cuda.amp.autocast(tops.AMP()):
- gen_ws = G.style_net(z)
- gen_img = G(**batch, w=gen_ws)["img"].float()
- pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
- with conv2d_gradfix.no_weight_gradients():
- # Sums over HWC
- pl_grads = torch.autograd.grad(
- outputs=[grad_scaler.scale(gen_img * pl_noise)],
- inputs=[gen_ws],
- create_graph=True,
- grad_outputs=torch.ones_like(gen_img),
- only_inputs=True)[0]
-
- pl_grads = pl_grads.float() / grad_scaler.get_scale()
- if self.scale_by_mask:
- # Percentage of pixels known
- scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
- pl_grads = pl_grads / scaling
- pl_lengths = pl_grads.square().sum(1).sqrt()
- pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
- if not torch.isnan(pl_mean).any():
- self.pl_mean.copy_(pl_mean.detach())
- pl_penalty = (pl_lengths - pl_mean).square()
- to_log = dict(pl_penalty=pl_penalty.mean().detach())
- return pl_penalty.view(-1) * self.pl_weight, to_log
\ No newline at end of file
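The path-length regularizer above follows StyleGAN2. The standalone sketch below runs the same Jacobian-vector-product trick on a toy generator, without AMP or gradient scaling; the toy `G`, shapes and values are illustrative assumptions, not the removed dp2 code.

# Sketch: path-length penalty on a toy generator G: w -> image (assumption: no AMP, no running mean state).
import torch
import numpy as np

torch.manual_seed(0)
G = torch.nn.Linear(32, 3 * 16 * 16)          # stand-in generator mapping w to a flat "image"
w = torch.randn(4, 32, requires_grad=True)
img = G(w).view(4, 3, 16, 16)

pl_noise = torch.randn_like(img) / np.sqrt(img.shape[2] * img.shape[3])
pl_grads = torch.autograd.grad(
    outputs=(img * pl_noise).sum(), inputs=w, create_graph=True)[0]
pl_lengths = pl_grads.square().sum(1).sqrt()  # ||J^T y|| per sample
pl_mean = pl_lengths.mean().detach()          # kept as an exponential running mean in the real code
pl_penalty = (pl_lengths - pl_mean).square()
print(pl_penalty.shape)  # torch.Size([4])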
diff --git a/dp2/loss/r1_regularization.py b/dp2/loss/r1_regularization.py
deleted file mode 100644
index 2098d792a591e4652259703085abe9d0bc55489b..0000000000000000000000000000000000000000
--- a/dp2/loss/r1_regularization.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import torch
-import tops
-
-def r1_regularization(
- real_img, real_score, mask, lambd: float, lazy_reg_interval: int,
- lazy_regularization: bool,
- scaler: torch.cuda.amp.GradScaler, mask_out: bool,
- mask_out_scale: bool,
- **kwargs
- ):
- grad = torch.autograd.grad(
- outputs=scaler.scale(real_score),
- inputs=real_img,
- grad_outputs=torch.ones_like(real_score),
- create_graph=True,
- only_inputs=True,
- )[0]
- inv_scale = 1.0 / scaler.get_scale()
- grad = grad * inv_scale
- with torch.cuda.amp.autocast(tops.AMP()):
- if mask_out:
- grad = grad * (1 - mask)
- grad = grad.square().sum(dim=[1, 2, 3])
- if mask_out and mask_out_scale:
- total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3]
- n_fake = (1-mask).sum(dim=[1, 2, 3])
- scaling = total_pixels / n_fake
- grad = grad * scaling
- lambd_ = lambd
- if lazy_regularization:
- # Scale up the penalty to compensate for only applying it every lazy_reg_interval steps (StyleGAN2).
- lambd_ = lambd * lazy_reg_interval / 2
- return grad * lambd_, grad.detach()
\ No newline at end of file
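With AMP and masking stripped away, the R1 penalty above is a plain gradient-norm penalty on real images. A minimal sketch on a toy discriminator (all names and values are illustrative):

# Sketch: R1 gradient penalty on a toy discriminator (assumption: no AMP, no mask).
import torch

D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 1))
real_img = torch.randn(4, 3, 8, 8, requires_grad=True)
real_score = D(real_img)

grad = torch.autograd.grad(
    outputs=real_score.sum(), inputs=real_img, create_graph=True)[0]
r1 = grad.square().sum(dim=[1, 2, 3])           # ||grad_x D(x)||^2 per sample

lambd, lazy_reg_interval = 5.0, 16
penalty = r1 * (lambd * lazy_reg_interval / 2)  # scaled up because it is only applied every 16th step
print(penalty.mean().item())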
diff --git a/dp2/loss/sg2_loss.py b/dp2/loss/sg2_loss.py
deleted file mode 100644
index 4a0d1150a66bbff4d7f408d9b68612adac179cdb..0000000000000000000000000000000000000000
--- a/dp2/loss/sg2_loss.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import functools
-import torch
-import tops
-from tops import logger
-from dp2.utils import forward_D_fake
-from .utils import nsgan_d_loss, nsgan_g_loss
-from .r1_regularization import r1_regularization
-from .pl_regularization import PLRegularization
-
-class StyleGAN2Loss:
-
- def __init__(
- self,
- D,
- G,
- r1_opts: dict,
- EP_lambd: float,
- lazy_reg_interval: int,
- lazy_regularization: bool,
- pl_reg_opts: dict,
- ) -> None:
- self.gradient_step_D = 0
- self._lazy_reg_interval = lazy_reg_interval
- self.D = D
- self.G = G
- self.EP_lambd = EP_lambd
- self.lazy_regularization = lazy_regularization
- self.r1_reg = functools.partial(
- r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval,
- lazy_regularization=lazy_regularization)
- self.do_PL_Reg = False
- if pl_reg_opts.weight > 0:
- self.pl_reg = PLRegularization(**pl_reg_opts)
- self.do_PL_Reg = True
- self.pl_start_nimg = pl_reg_opts.start_nimg
-
- def D_loss(self, batch: dict, grad_scaler):
- to_log = {}
- # Forward through G and D
- do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0
- if do_GP:
- batch["img"] = batch["img"].detach().requires_grad_(True)
- with torch.cuda.amp.autocast(enabled=tops.AMP()):
- with torch.no_grad():
- G_fake = self.G(**batch, update_emas=True)
- D_out_real = self.D(**batch)
-
- D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
-
- # Non saturating loss
- nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"])
- tops.assert_shape(nsgan_loss, (batch["img"].shape[0], ))
- to_log["d_loss"] = nsgan_loss.mean()
- total_loss = nsgan_loss
- epsilon_penalty = D_out_real["score"].pow(2).view(-1)
- to_log["epsilon_penalty"] = epsilon_penalty.mean()
- tops.assert_shape(epsilon_penalty, total_loss.shape)
- total_loss = total_loss + epsilon_penalty * self.EP_lambd
-
- # Improved gradient penalty with lazy regularization
- # Gradient penalty applies specialized autocast.
- if do_GP:
- gradient_pen, grad_unscaled = self.r1_reg(batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler)
- to_log["r1_gradient_penalty"] = grad_unscaled.mean()
- tops.assert_shape(gradient_pen, total_loss.shape)
- total_loss = total_loss + gradient_pen
-
- batch["img"] = batch["img"].detach().requires_grad_(False)
- if "score" in D_out_real:
- to_log["real_scores"] = D_out_real["score"]
- to_log["real_logits_sign"] = D_out_real["score"].sign()
- to_log["fake_logits_sign"] = D_out_fake["score"].sign()
- to_log["fake_scores"] = D_out_fake["score"]
- to_log = {key: item.mean().detach() for key, item in to_log.items()}
- self.gradient_step_D += 1
- return total_loss.mean(), to_log
-
- def G_loss(self, batch: dict, grad_scaler):
- with torch.cuda.amp.autocast(enabled=tops.AMP()):
- to_log = {}
- # Forward through G and D
- G_fake = self.G(**batch)
- D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
- # Adversarial Loss
- total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
- to_log["g_loss"] = total_loss.mean()
- tops.assert_shape(total_loss, (batch["img"].shape[0], ))
-
- if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
- pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
- total_loss = total_loss + pl_reg.mean()
- to_log.update(to_log_)
- to_log = {key: item.mean().detach() for key, item in to_log.items()}
- return total_loss.mean(), to_log
diff --git a/dp2/loss/utils.py b/dp2/loss/utils.py
deleted file mode 100644
index 1c9fe96156f40ffcc08ad84652d50327263db2fd..0000000000000000000000000000000000000000
--- a/dp2/loss/utils.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-def nsgan_g_loss(fake_score):
- """
- Non-saturating criterion from Goodfellow et al. 2014
- """
- return torch.nn.functional.softplus(-fake_score)
-
-
-def nsgan_d_loss(real_score, fake_score):
- """
- Non-saturating criterion from Goodfellow et al. 2014
- """
- d_loss = F.softplus(-real_score) + F.softplus(fake_score)
- return d_loss.view(-1)
-
-
-def smooth_masked_l1_loss(x, target, mask):
- """
- Pixel-wise l1 loss for the area indicated by mask
- """
- # Beta=.1 <-> square loss if pixel difference <= 12.8
- l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1,2,3]) / mask.sum(dim=[1, 2, 3])
- return l1
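The softplus form used above is the usual non-saturating GAN loss written compactly; the following standalone check confirms the identity softplus(-s) = -log(sigmoid(s)):

# Sketch: softplus(-s) equals -log(sigmoid(s)), the non-saturating GAN loss.
import torch
import torch.nn.functional as F

s = torch.randn(5)
lhs = F.softplus(-s)
rhs = -torch.log(torch.sigmoid(s))
print(torch.allclose(lhs, rhs, atol=1e-6))  # True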
diff --git a/dp2/metrics/__init__.py b/dp2/metrics/__init__.py
deleted file mode 100644
index fc224b42cc5ceeaf2e68d6371ab9da1dade557ed..0000000000000000000000000000000000000000
--- a/dp2/metrics/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .torch_metrics import compute_metrics_iteratively
-from .fid import compute_fid
-from .ppl import calculate_ppl
\ No newline at end of file
diff --git a/dp2/metrics/fid.py b/dp2/metrics/fid.py
deleted file mode 100644
index 66eb5e0060d60294c4cdf80254e583cf8fad8bc2..0000000000000000000000000000000000000000
--- a/dp2/metrics/fid.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import tops
-from dp2 import utils
-from pathlib import Path
-from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper
-import torch
-import torch_fidelity
-
-
-class GeneratorIteratorWrapper(GenerativeModelModuleWrapper):
-
- def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int):
- if isinstance(generator, utils.EMA):
- generator = generator.generator
- z_size = generator.z_channels
- super().__init__(generator, z_size, "normal", 0)
- self.zero_z = zero_z
- self.dataloader = iter(dataloader)
- self.n_diverse = n_diverse
- self.cur_div_idx = 0
-
- @torch.no_grad()
- def forward(self, z, **kwargs):
- if self.cur_div_idx == 0:
- self.batch = next(self.dataloader)
- if self.zero_z:
- z = z.zero_()
- self.cur_div_idx += 1
- self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx
- with torch.cuda.amp.autocast(enabled=tops.AMP()):
- img = self.module(**self.batch)["img"]
- img = (utils.denormalize_img(img)*255).byte()
- return img
-
-
-def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse):
- generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse)
- batch_size = dataloader.batch_size
- num_samples = (n_source * n_diverse) // batch_size * batch_size
- assert n_diverse >= 1
- assert (not zero_z) or n_diverse == 1
- assert num_samples % batch_size == 0
- assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse)
- metrics = torch_fidelity.calculate_metrics(
- input1=generator,
- input2=real_directory,
- cuda=torch.cuda.is_available(),
- fid=True,
- input2_cache_name="_".join(Path(real_directory).parts) + "_cached",
- input1_model_num_samples=int(num_samples),
- batch_size=dataloader.batch_size
- )
- return metrics["frechet_inception_distance"]
-
-
-if __name__ == "__main__":
- import click
- from dp2.config import Config
- from dp2.data import build_dataloader_val
- from dp2.infer import build_trained_generator
- @click.command()
- @click.argument("config_path")
- @click.option("--n_source", default=200, type=int)
- @click.option("--n_diverse", default=5, type=int)
- @click.option("--zero_z", default=False, is_flag=True)
- def run(config_path, n_source: int, n_diverse: int, zero_z: bool):
- cfg = Config.fromfile(config_path)
- dataloader = build_dataloader_val(cfg)
- generator = build_trained_generator(cfg)
- print(compute_fid(
- generator, dataloader, cfg.fid_real_directory, n_source, zero_z, n_diverse))
-
- run()
\ No newline at end of file
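Outside the dataloader-driven generator wrapper above, torch_fidelity can also compute FID directly between two folders of images. A minimal sketch; both directory paths are placeholders, not paths that exist in this repo:

# Sketch: FID between two image directories with torch_fidelity (paths are hypothetical).
import torch
import torch_fidelity

metrics = torch_fidelity.calculate_metrics(
    input1="path/to/generated_images",
    input2="path/to/real_images",
    cuda=torch.cuda.is_available(),
    fid=True,
)
print(metrics["frechet_inception_distance"])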
diff --git a/dp2/metrics/fid_clip.py b/dp2/metrics/fid_clip.py
deleted file mode 100644
index 6712fd48503b787bc8e2197a6a737bcd73546b35..0000000000000000000000000000000000000000
--- a/dp2/metrics/fid_clip.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import pickle
-import torch
-import torchvision
-from pathlib import Path
-from dp2 import utils
-import tops
-try:
- import clip
-except ImportError:
- print("Could not import clip.")
-from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
-clip_model = None
-clip_preprocess = None
-
-
-@torch.no_grad()
-def compute_fid_clip(
- dataloader, generator,
- cache_directory,
- data_len=None,
- **kwargs
- ) -> dict:
- """
- FID CLIP following the description in The Role of ImageNet Classes in Frechet Inception Distance, Thomas Kynkaamniemi et al.
- Args:
- n_samples (int): Creates N samples from same image to calculate stats
- """
- global clip_model, clip_preprocess
- if clip_model is None:
- clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
- normalize_fn = preprocess.transforms[-1]
- img_mean = normalize_fn.mean
- img_std = normalize_fn.std
- clip_model = tops.to_cuda(clip_model.visual)
- clip_preprocess = tops.to_cuda(torch.nn.Sequential(
- torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
- torchvision.transforms.Normalize(img_mean, img_std)
- ))
- cache_directory = Path(cache_directory)
- if data_len is None:
- data_len = len(dataloader)*dataloader.batch_size
- fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl")
- has_fid_cache = fid_cache_path.is_file()
- if not has_fid_cache:
- fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
- fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
- eidx = 0
- n_samples_seen = 0
- for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."):
- sidx = eidx
- eidx = sidx + batch["img"].shape[0]
- n_samples_seen += batch["img"].shape[0]
- with torch.cuda.amp.autocast(tops.AMP()):
- fakes = generator(**batch)["img"]
- real_data = batch["img"]
- fakes = utils.denormalize_img(fakes)
- real_data = utils.denormalize_img(real_data)
- if not has_fid_cache:
- real_data = clip_preprocess(real_data)
- fid_features_real[sidx:eidx] = clip_model(real_data)
- fakes = clip_preprocess(fakes)
- fid_features_fake[sidx:eidx] = clip_model(fakes)
- fid_features_fake = fid_features_fake[:n_samples_seen]
- fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
- if has_fid_cache:
- if tops.rank() == 0:
- with open(fid_cache_path, "rb") as fp:
- fid_stat_real = pickle.load(fp)
- else:
- fid_features_real = fid_features_real[:n_samples_seen]
- fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
- assert fid_features_real.shape == fid_features_fake.shape
- if tops.rank() == 0:
- fid_stat_real = fid_features_to_statistics(fid_features_real)
- cache_directory.mkdir(exist_ok=True, parents=True)
- with open(fid_cache_path, "wb") as fp:
- pickle.dump(fid_stat_real, fp)
-
- if tops.rank() == 0:
- print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
- fid_stat_fake = fid_features_to_statistics(fid_features_fake)
- fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
- return dict(fid_clip=fid_)
- return dict(fid_clip=-1)
diff --git a/dp2/metrics/lpips.py b/dp2/metrics/lpips.py
deleted file mode 100644
index cfda315da83fab7bb1bf5e8d1b09f3a721ded064..0000000000000000000000000000000000000000
--- a/dp2/metrics/lpips.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import torch
-import tops
-import sys
-from contextlib import redirect_stdout
-from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average
-
-class SampleSimilarityLPIPS(torch.nn.Module):
- SUPPORTED_DTYPES = {
- 'uint8': torch.uint8,
- 'float32': torch.float32,
- }
-
- def __init__(self):
-
- super().__init__()
- self.chns = [64, 128, 256, 512, 512]
- self.L = len(self.chns)
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
- self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
- with redirect_stdout(sys.stderr):
- fp = tops.download_file(URL_VGG16_LPIPS)
- state_dict = torch.load(fp, map_location="cpu")
- self.load_state_dict(state_dict)
- self.net = VGG16features()
- self.eval()
- for param in self.parameters():
- param.requires_grad = False
- mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
- inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
- self.register_buffer("mean", mean_rescaled)
- self.register_buffer("std", inv_std_rescaled)
-
- def normalize(self, x):
- # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
- x = (x.float() - self.mean) * self.std
- return x
-
- @staticmethod
- def resize(x, size):
- if x.shape[-1] > size and x.shape[-2] > size:
- x = torch.nn.functional.interpolate(x, (size, size), mode='area')
- else:
- x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
- return x
-
- def lpips_from_feats(self, feats0, feats1):
- diffs = {}
- for kk in range(self.L):
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
-
- res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
- val = sum(res)
- return val
-
- def get_feats(self, x):
- assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
- if x.shape[-2] < 16: # Resize images < 16x16
- f = 16 / x.shape[-2]
- size = tuple([int(f*_) for _ in x.shape[-2:]])
- x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
- in0_input = self.normalize(x)
- outs0 = self.net.forward(in0_input)
-
- feats = {}
- for kk in range(self.L):
- feats[kk] = normalize_tensor(outs0[kk])
- return feats
-
- def forward(self, in0, in1):
- feats0 = self.get_feats(in0)
- feats1 = self.get_feats(in1)
- return self.lpips_from_feats(feats0, feats1), feats0, feats1
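The registered `mean`/`std` buffers above fold the LPIPS shift and scale (normally applied to images in [-1, 1]) into a single affine transform on 0-255 inputs. A quick standalone check of that algebra:

# Sketch: the rescaled buffers equal the usual shift/scale applied to [-1, 1] inputs.
import torch

shift = torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
scale = torch.tensor([.458, .448, .450]).view(1, 3, 1, 1)

x = torch.randint(0, 256, (1, 3, 4, 4)).float()   # uint8-range input
mean_rescaled = (1 + shift) * 255 / 2
inv_std_rescaled = 2 / (scale * 255)

a = (x - mean_rescaled) * inv_std_rescaled         # buffer form used in the module
b = ((x / 255 * 2 - 1) - shift) / scale            # shift/scale on [-1, 1] inputs
print(torch.allclose(a, b, atol=1e-5))  # True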
diff --git a/dp2/metrics/ppl.py b/dp2/metrics/ppl.py
deleted file mode 100644
index 3d30b220546bcd8e44c36eede56361868e26879a..0000000000000000000000000000000000000000
--- a/dp2/metrics/ppl.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import numpy as np
-import torch
-import tops
-from dp2 import utils
-from torch_fidelity.helpers import get_kwarg, vassert
-from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS
-from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity
-
-
-def slerp(a, b, t):
- a = a / a.norm(dim=-1, keepdim=True)
- b = b / b.norm(dim=-1, keepdim=True)
- d = (a * b).sum(dim=-1, keepdim=True)
- p = t * torch.acos(d)
- c = b - d * a
- c = c / c.norm(dim=-1, keepdim=True)
- d = a * torch.cos(p) + c * torch.sin(p)
- d = d / d.norm(dim=-1, keepdim=True)
- return d
-
-
-@torch.no_grad()
-def calculate_ppl(
- dataloader,
- generator,
- latent_space=None,
- data_len=None,
- **kwargs) -> dict:
- """
- Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
- """
- if latent_space is None:
- latent_space = generator.latent_space
- assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
- epsilon = PPL_DEFAULTS["ppl_epsilon"]
- interp = PPL_DEFAULTS['ppl_z_interp_mode']
- similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
- sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
- sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
- discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
- discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']
-
- vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
- vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
- vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
- if discard_percentile_lower is not None and discard_percentile_higher is not None:
- vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')
-
- sample_similarity = create_sample_similarity(
- similarity_name,
- sample_similarity_resize=sample_similarity_resize,
- sample_similarity_dtype=sample_similarity_dtype,
- cuda=False,
- **kwargs
- )
- sample_similarity = tops.to_cuda(sample_similarity)
- rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
- distances = []
- if data_len is None:
- data_len = len(dataloader) * dataloader.batch_size
- z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
- z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
- if latent_space == "Z":
- z1 = batch_interp(z0, z1, epsilon, interp)
-
- distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
- print(distances.shape)
- end = 0
- n_samples = 0
- for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")):
- start = end
- end = start + batch["img"].shape[0]
- n_samples += batch["img"].shape[0]
- batch_lat_e0 = tops.to_cuda(z0[start:end])
- batch_lat_e1 = tops.to_cuda(z1[start:end])
- if latent_space == "W":
- w0 = generator.get_w(batch_lat_e0, update_emas=False)
- w1 = generator.get_w(batch_lat_e1, update_emas=False)
- w1 = w0.lerp(w1, epsilon) # PPL end
- rgb1 = generator(**batch, w=w0)["img"]
- rgb2 = generator(**batch, w=w1)["img"]
- else:
- rgb1 = generator(**batch, z=batch_lat_e0)["img"]
- rgb2 = generator(**batch, z=batch_lat_e1)["img"]
- rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
- rgb2 = utils.denormalize_img(rgb2).mul(255).byte()
-
- sim = sample_similarity(rgb1, rgb2)
- dist_lat_e01 = sim / (epsilon ** 2)
- distances[start:end] = dist_lat_e01.view(-1)
- distances = distances[:n_samples]
- distances = tops.all_gather_uneven(distances).cpu().numpy()
- if tops.rank() != 0:
- return {"ppl/mean": -1, "ppl/std": -1}
- cond, lo, hi = None, None, None
- if discard_percentile_lower is not None:
- lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
- cond = lo <= distances
- if discard_percentile_higher is not None:
- hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
- cond = np.logical_and(cond, distances <= hi)
- if cond is not None:
- distances = np.extract(cond, distances)
- return {
- "ppl/mean": float(np.mean(distances)),
- "ppl/std": float(np.std(distances)),
- }
diff --git a/dp2/metrics/torch_metrics.py b/dp2/metrics/torch_metrics.py
deleted file mode 100644
index a6682afbbbe9e5205d48743ac39db8c647607666..0000000000000000000000000000000000000000
--- a/dp2/metrics/torch_metrics.py
+++ /dev/null
@@ -1,176 +0,0 @@
-import pickle
-import numpy as np
-import torch
-import time
-from pathlib import Path
-from dp2 import utils
-import tops
-from .lpips import SampleSimilarityLPIPS
-from torch_fidelity.defaults import DEFAULTS as trf_defaults
-from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
-from torch_fidelity.utils import create_feature_extractor
-lpips_model = None
-fid_model = None
-
-@torch.no_grad()
-def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
- se = (images1 - images2) ** 2
- se = se.view(images1.shape[0], -1).mean(dim=1)
- return se
-
-@torch.no_grad()
-def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
- mse_ = mse(images1, images2)
- psnr = 10 * torch.log10(1 / mse_)
- return psnr
-
-@torch.no_grad()
-def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
- return _lpips_w_grad(images1, images2)
-
-
-def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
- global lpips_model
- if lpips_model is None:
- lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
-
- images1 = images1.mul(255)
- images2 = images2.mul(255)
- with torch.cuda.amp.autocast(tops.AMP()):
- dists = lpips_model(images1, images2)[0].view(-1)
- return dists
-
-
-
-
-@torch.no_grad()
-def compute_metrics_iteratively(
- dataloader, generator,
- cache_directory,
- data_len=None,
- truncation_value: float=None,
- ) -> dict:
- """
- Args:
- n_samples (int): Creates N samples from same image to calculate stats
- dataset_percentage (float): The percentage of the dataset to compute metrics on.
- """
-
- global lpips_model, fid_model
- if lpips_model is None:
- lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
- if fid_model is None:
- fid_model = create_feature_extractor(
- trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False)
- fid_model = tops.to_cuda(fid_model)
- cache_directory = Path(cache_directory)
- start_time = time.time()
- lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
- diversity_total = torch.zeros_like(lpips_total)
- fid_cache_path = cache_directory.joinpath("fid_stats.pkl")
- has_fid_cache = fid_cache_path.is_file()
- if data_len is None:
- data_len = len(dataloader)*dataloader.batch_size
- if not has_fid_cache:
- fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
- fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
- n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
- eidx = 0
- for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"):
- sidx = eidx
- eidx = sidx + batch["img"].shape[0]
- n_samples_seen += batch["img"].shape[0]
- with torch.cuda.amp.autocast(tops.AMP()):
- fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
- fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
- fakes1 = utils.denormalize_img(fakes1).mul(255)
- fakes2 = utils.denormalize_img(fakes2).mul(255)
- real_data = utils.denormalize_img(batch["img"]).mul(255)
- lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
- fake2_lpips_feats = lpips_model.get_feats(fakes2)
- lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
-
- lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
- diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
- if not has_fid_cache:
- fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0]
- fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0]
- fid_features_fake = fid_features_fake[:n_samples_seen]
- if has_fid_cache:
- if tops.rank() == 0:
- with open(fid_cache_path, "rb") as fp:
- fid_stat_real = pickle.load(fp)
- else:
- fid_features_real = fid_features_real[:n_samples_seen]
- fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
- if tops.rank() == 0:
- fid_stat_real = fid_features_to_statistics(fid_features_real)
- cache_directory.mkdir(exist_ok=True, parents=True)
- with open(fid_cache_path, "wb") as fp:
- pickle.dump(fid_stat_real, fp)
- fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
- if tops.rank() == 0:
- print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
- fid_stat_fake = fid_features_to_statistics(fid_features_fake)
- fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
- tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
- tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
- tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
- lpips_total = lpips_total / n_samples_seen
- diversity_total = diversity_total / n_samples_seen
- to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
- if tops.rank() == 0:
- to_return["fid"] = fid_
- else:
- to_return["fid"] = -1
- to_return["validation_time_s"] = time.time() - start_time
- return to_return
-
-
-@torch.no_grad()
-def compute_lpips(
- dataloader, generator,
- truncation_value: float=None,
- data_len=None,
- ) -> dict:
- """
- Args:
- n_samples (int): Creates N samples from same image to calculate stats
- dataset_percentage (float): The percentage of the dataset to compute metrics on.
- """
- global lpips_model, fid_model
- if lpips_model is None:
- lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
- start_time = time.time()
- lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
- diversity_total = torch.zeros_like(lpips_total)
- if data_len is None:
- data_len = len(dataloader) * dataloader.batch_size
- eidx = 0
- n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
- for batch in utils.tqdm_(dataloader, desc="Validating on dataset."):
- sidx = eidx
- eidx = sidx + batch["img"].shape[0]
- n_samples_seen += batch["img"].shape[0]
- with torch.cuda.amp.autocast(tops.AMP()):
- fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
- fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
- real_data = batch["img"]
- fakes1 = utils.denormalize_img(fakes1).mul(255)
- fakes2 = utils.denormalize_img(fakes2).mul(255)
- real_data = utils.denormalize_img(real_data).mul(255)
- lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
- fake2_lpips_feats = lpips_model.get_feats(fakes2)
- lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
-
- lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
- diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
- tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
- tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
- tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
- lpips_total = lpips_total / n_samples_seen
- diversity_total = diversity_total / n_samples_seen
- to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
- to_return = {k: v.cpu().item() for k, v in to_return.items()}
- to_return["validation_time_s"] = time.time() - start_time
- return to_return
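The `psnr` helper above assumes images scaled to [0, 1]; a tiny standalone check of the dB arithmetic:

# Sketch: a uniform error of 0.1 on [0, 1] images gives MSE 0.01 and PSNR 20 dB (up to float rounding).
import torch

img1 = torch.zeros(2, 3, 8, 8)
img2 = torch.full_like(img1, 0.1)
mse = (img1 - img2).square().view(img1.shape[0], -1).mean(dim=1)
psnr = 10 * torch.log10(1 / mse)
print(mse, psnr)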
diff --git a/dp2/utils/__init__.py b/dp2/utils/__init__.py
deleted file mode 100644
index a28c97be96580185208e5c6e56b1307a0fd9b9be..0000000000000000000000000000000000000000
--- a/dp2/utils/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import pathlib
-from tops.config import LazyConfig
-from .torch_utils import (
- im2torch, im2numpy, denormalize_img, set_requires_grad, forward_D_fake,
- binary_dilation, crop_box, remove_pad
-)
-from .ema import EMA
-from .utils import init_tops, tqdm_, print_config, config_to_str, trange_
-from .cse import from_E_to_vertex
-
-
-def load_config(config_path):
- config_path = pathlib.Path(config_path)
- assert config_path.is_file(), config_path
- cfg = LazyConfig.load(str(config_path))
- cfg.output_dir = pathlib.Path(str(config_path).replace("configs", str(cfg.common.output_dir)).replace(".py", ""))
- if cfg.common.experiment_name is None:
- cfg.experiment_name = str(config_path)
- else:
- cfg.experiment_name = cfg.common.experiment_name
- cfg.checkpoint_dir = cfg.output_dir.joinpath("checkpoints")
- print("Saving outputs to:", cfg.output_dir)
- return cfg
diff --git a/dp2/utils/bufferless_video_capture.py b/dp2/utils/bufferless_video_capture.py
deleted file mode 100644
index b071c1c4316ad48127c86c4f52ca40f66530edf7..0000000000000000000000000000000000000000
--- a/dp2/utils/bufferless_video_capture.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import queue
-import threading
-import cv2
-
-
-class BufferlessVideoCapture:
-
- def __init__(self, name, width=None, height=None):
- self.cap = cv2.VideoCapture(name)
- if width is not None and height is not None:
- self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
- self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
- self.q = queue.Queue()
- t = threading.Thread(target=self._reader)
- t.daemon = True
- t.start()
-
- # read frames as soon as they are available, keeping only most recent one
- def _reader(self):
- while True:
- ret, frame = self.cap.read()
- if not ret:
- break
- if not self.q.empty():
- try:
- self.q.get_nowait() # discard previous (unprocessed) frame
- except queue.Empty:
- pass
- self.q.put((ret, frame))
-
- def read(self):
- return self.q.get()
\ No newline at end of file
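A short usage sketch for the capture class above: the reader thread keeps only the newest frame, so `read()` never returns stale video. The import path below assumes the class now lives in the deep_privacy2 submodule; that path, the camera index and window name are assumptions.

# Sketch: always process the most recent webcam frame (import path and camera index are assumptions).
import cv2
from dp2.utils.bufferless_video_capture import BufferlessVideoCapture

cap = BufferlessVideoCapture(0, width=1280, height=720)
while True:
    ret, frame = cap.read()              # blocks until a frame is available, stale frames are discarded
    if not ret:
        break
    cv2.imshow("latest frame", frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
cv2.destroyAllWindows()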
diff --git a/dp2/utils/cse.py b/dp2/utils/cse.py
deleted file mode 100644
index f1b5a2b55cb912df0a50961eafb260eb297b0d4f..0000000000000000000000000000000000000000
--- a/dp2/utils/cse.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import warnings
-import torch
-from densepose.modeling.cse.utils import get_closest_vertices_mask_from_ES
-
-
-def from_E_to_vertex(E, M, embed_map):
- """
- M is 1 for unknown regions
- """
- assert len(E.shape) == 4
- assert len(E.shape) == len(M.shape), (E.shape, M.shape)
- assert E.shape[0] == 1
- M = M.float()
- M = torch.cat([M, 1-M], dim=1)
- with warnings.catch_warnings(): # Ignore UserWarning raised by torch interpolate inside detectron2
- warnings.filterwarnings("ignore")
- vertices, _ = get_closest_vertices_mask_from_ES(
- E, M, E.shape[2], E.shape[3],
- embed_map, device=E.device)
-
- return vertices.long()
diff --git a/dp2/utils/ema.py b/dp2/utils/ema.py
deleted file mode 100644
index 0c508213d4c445e2417607a7d3957d3ec953eb1f..0000000000000000000000000000000000000000
--- a/dp2/utils/ema.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import torch
-import copy
-import tops
-from tops import logger
-from .torch_utils import set_requires_grad
-
-class EMA:
- """
- Exponential moving average of generator weights.
- See:
- Yazici, Y. et al. The Unusual Effectiveness of Averaging in GAN Training. ICLR 2019
-
- """
-
- def __init__(
- self,
- generator: torch.nn.Module,
- batch_size: int,
- rampup: float,
- ):
- self.rampup = rampup
- self._nimg_half_time = batch_size * 10 / 32 * 1000
- self._batch_size = batch_size
- with torch.no_grad():
- self.generator = copy.deepcopy(generator.cpu()).eval()
- self.generator = tops.to_cuda(self.generator)
- self.old_ra_beta = 0
- set_requires_grad(self.generator, False)
-
- def update_beta(self):
- y = self._nimg_half_time
- global_step = logger.global_step()
- if self.rampup is not None:
- y = min(y, global_step*self.rampup)
- self.ra_beta = 0.5 ** (self._batch_size/max(y, 1e-8))
- if self.ra_beta != self.old_ra_beta:
- logger.add_scalar("stats/EMA_beta", self.ra_beta)
- self.old_ra_beta = self.ra_beta
-
- @torch.no_grad()
- def update(self, normal_G):
- with torch.autograd.profiler.record_function("EMA_update"):
- for ema_p, p in zip(self.generator.parameters(),
- normal_G.parameters()):
- ema_p.copy_(p.lerp(ema_p, self.ra_beta))
- for ema_buf, buff in zip(self.generator.buffers(),
- normal_G.buffers()):
- ema_buf.copy_(buff)
-
- def __call__(self, *args, **kwargs):
- return self.generator(*args, **kwargs)
-
- def __getattr__(self, name: str):
- if hasattr(self.generator, name):
- return getattr(self.generator, name)
- raise AttributeError(f"Generator object has no attribute {name}")
-
- def cuda(self, *args, **kwargs):
- self.generator = self.generator.cuda()
- return self
-
- def state_dict(self, *args, **kwargs):
- return self.generator.state_dict(*args, **kwargs)
-
- def load_state_dict(self, *args, **kwargs):
- return self.generator.load_state_dict(*args, **kwargs)
-
- def eval(self):
- self.generator.eval()
-
- def train(self):
- self.generator.train()
-
- @property
- def module(self):
- return self.generator.module
-
- def sample(self, *args, **kwargs):
- return self.generator.sample(*args, **kwargs)
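The `update_beta` schedule above derives the EMA decay from a half-life measured in images; a standalone sketch of the same arithmetic (the rampup value is illustrative, not taken from a config in this repo):

# Sketch: EMA decay from a half-life measured in images (same formula as update_beta above).
batch_size = 32
rampup = 0.05                                   # illustrative rampup factor
nimg_half_time = batch_size * 10 / 32 * 1000    # 10k images at batch size 32

for global_step in [1_000, 100_000, 1_000_000]:
    y = min(nimg_half_time, global_step * rampup)
    beta = 0.5 ** (batch_size / max(y, 1e-8))
    print(global_step, round(beta, 5))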
diff --git a/dp2/utils/torch_utils.py b/dp2/utils/torch_utils.py
deleted file mode 100644
index 46defb854b11a615aad3b918dce4a650b0b80889..0000000000000000000000000000000000000000
--- a/dp2/utils/torch_utils.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import torch
-import tops
-
-
-def denormalize_img(image, mean=0.5, std=0.5):
- image = image * std + mean
- image = torch.clamp(image.float(), 0, 1)
- image = (image * 255)
- image = torch.round(image)
- return image / 255
-
-
-@torch.no_grad()
-def im2numpy(images, to_uint8=False, denormalize=False):
- if denormalize:
- images = denormalize_img(images)
- if images.dtype != torch.uint8:
- images = images.clamp(0, 1)
- return tops.im2numpy(images, to_uint8=to_uint8)
-
-
-@torch.no_grad()
-def im2torch(im, cuda=False, normalize=True, to_float=True):
- im = tops.im2torch(im, cuda=cuda, to_float=to_float)
- if normalize:
- assert im.min() >= 0.0 and im.max() <= 1.0
- if normalize:
- im = im * 2 - 1
- return im
-
-
-@torch.no_grad()
-def binary_dilation(im: torch.Tensor, kernel: torch.Tensor):
- assert len(im.shape) == 4
- assert len(kernel.shape) == 2
- kernel = kernel.unsqueeze(0).unsqueeze(0)
- padding = kernel.shape[-1]//2
- assert kernel.shape[-1] % 2 != 0
- if isinstance(im, torch.cuda.FloatTensor):
- im, kernel = im.half(), kernel.half()
- else:
- im, kernel = im.float(), kernel.float()
- im = torch.nn.functional.conv2d(
- im, kernel, groups=im.shape[1], padding=padding)
- im = im > 0.5
- return im
-
-
-@torch.no_grad()
-def binary_erosion(im: torch.Tensor, kernel: torch.Tensor):
- assert len(im.shape) == 4
- assert len(kernel.shape) == 2
- kernel = kernel.unsqueeze(0).unsqueeze(0)
- padding = kernel.shape[-1]//2
- assert kernel.shape[-1] % 2 != 0
- if isinstance(im, torch.cuda.FloatTensor):
- im, kernel = im.half(), kernel.half()
- else:
- im, kernel = im.float(), kernel.float()
- ksum = kernel.sum()
- padding = (padding, padding, padding, padding)
- im = torch.nn.functional.pad(im, padding, mode="reflect")
- im = torch.nn.functional.conv2d(
- im, kernel, groups=im.shape[1])
- return im.round() == ksum
-
-
-def set_requires_grad(value: torch.nn.Module, requires_grad: bool):
- if isinstance(value, (list, tuple)):
- for param in value:
- param.requires_grad = requires_grad
- return
- for p in value.parameters():
- p.requires_grad = requires_grad
-
-
-def forward_D_fake(batch, fake_img, discriminator, **kwargs):
- fake_batch = {k: v for k, v in batch.items() if k != "img"}
- fake_batch["img"] = fake_img
- return discriminator(**fake_batch, **kwargs)
-
-
-
-def remove_pad(x: torch.Tensor, bbox_XYXY, imshape):
- """
- Removes the padding introduced when bbox_XYXY extends outside the image (negative or out-of-bounds coordinates).
- """
- H, W = imshape
- x0, y0, x1, y1 = bbox_XYXY
- padding = [
- max(0, -x0),
- max(0, -y0),
- max(x1 - W, 0),
- max(y1 - H, 0)
- ]
- x0, y0 = padding[:2]
- x1 = x.shape[2] - padding[2]
- y1 = x.shape[1] - padding[3]
- return x[:, y0:y1, x0:x1]
-
-
-def crop_box(x: torch.Tensor, bbox_XYXY) -> torch.Tensor:
- """
- Crops x by bbox_XYXY.
- """
- x0, y0, x1, y1 = bbox_XYXY
- x0 = max(x0, 0)
- y0 = max(y0, 0)
- x1 = min(x1, x.shape[-1])
- y1 = min(y1, x.shape[-2])
- return x[..., y0:y1, x0:x1]
\ No newline at end of file
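A standalone check of the convolution-based dilation above: with a 3x3 kernel of ones, a single foreground pixel grows into a 3x3 block.

# Sketch: binary dilation via conv2d, as in the helper above.
import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 5, 5)
mask[0, 0, 2, 2] = 1.0
kernel = torch.ones(1, 1, 3, 3)

dilated = F.conv2d(mask, kernel, padding=1) > 0.5
print(dilated.squeeze().int())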
diff --git a/dp2/utils/utils.py b/dp2/utils/utils.py
deleted file mode 100644
index 965eb6ae987bd6b5644e50116bfa6317e4e36769..0000000000000000000000000000000000000000
--- a/dp2/utils/utils.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import tops
-import tqdm
-from tops import logger, highlight_py_str
-from tops.config import LazyConfig
-
-
-def print_config(cfg):
- logger.log("\n" + highlight_py_str(LazyConfig.to_py(cfg, prefix="")))
-
-
-def config_to_str(cfg):
- return LazyConfig.to_py(cfg, prefix=".")
-
-
-def init_tops(cfg, reinit=False):
- tops.init(
- cfg.output_dir, cfg.common.logger_backend, cfg.experiment_name,
- cfg.common.wandb_project, dict(cfg), reinit)
-
-
-def tqdm_(iterator, *args, **kwargs):
- if tops.rank() == 0:
- return tqdm.tqdm(iterator, *args, **kwargs)
- return iterator
-
-
-def trange_(*args, **kwargs):
- if tops.rank() == 0:
- return tqdm.trange(*args, **kwargs)
- return range(*args)
\ No newline at end of file
diff --git a/dp2/utils/vis_utils.py b/dp2/utils/vis_utils.py
deleted file mode 100644
index ee227f359b4ab3af3a1123e2900373fed01f6c38..0000000000000000000000000000000000000000
--- a/dp2/utils/vis_utils.py
+++ /dev/null
@@ -1,407 +0,0 @@
-import torch
-import tops
-import cv2
-import torchvision.transforms.functional as F
-from typing import Optional, List, Union, Tuple
-from .cse import from_E_to_vertex
-import numpy as np
-from tops import download_file
-from .torch_utils import (
- denormalize_img, binary_dilation, binary_erosion,
- remove_pad, crop_box)
-from torchvision.utils import _generate_color_palette
-from PIL import Image, ImageColor, ImageDraw
-
-
-def get_coco_keypoints():
- # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py
- keypoints = [
- 'nose',
- 'left_eye',
- 'right_eye',
- 'left_ear',
- 'right_ear',
- 'left_shoulder',
- 'right_shoulder',
- 'left_elbow',
- 'right_elbow',
- 'left_wrist',
- 'right_wrist',
- 'left_hip',
- 'right_hip',
- 'left_knee',
- 'right_knee',
- 'left_ankle',
- 'right_ankle'
- ]
- keypoint_flip_map = {
- 'left_eye': 'right_eye',
- 'left_ear': 'right_ear',
- 'left_shoulder': 'right_shoulder',
- 'left_elbow': 'right_elbow',
- 'left_wrist': 'right_wrist',
- 'left_hip': 'right_hip',
- 'left_knee': 'right_knee',
- 'left_ankle': 'right_ankle'
- }
- connectivity = {
- "nose": "left_eye",
- "left_eye": "right_eye",
- "right_eye": "nose",
- "left_ear": "left_eye",
- "right_ear": "right_eye",
- "left_shoulder": "nose",
- "right_shoulder": "nose",
- "left_elbow": "left_shoulder",
- "right_elbow": "right_shoulder",
- "left_wrist": "left_elbow",
- "right_wrist": "right_elbow",
- "left_hip": "left_shoulder",
- "right_hip": "right_shoulder",
- "left_knee": "left_hip",
- "right_knee": "right_hip",
- "left_ankle": "left_knee",
- "right_ankle": "right_knee"
- }
- connectivity_indices = [
- (sidx, keypoints.index(connectivity[kp]))
- for sidx, kp in enumerate(keypoints)
- ]
- return keypoints, keypoint_flip_map, connectivity_indices
-
-
-@torch.no_grad()
-def draw_keypoints(
- image: torch.Tensor,
- keypoints: torch.Tensor,
- connectivity: Optional[List[Tuple[int, int]]] = None,
- colors: Optional[Union[str, Tuple[int, int, int]]] = None,
- radius: int = 1,
- width: int = 1,
-) -> torch.Tensor:
-
-
- """
- Function taken from torchvision source code. Added in torchvision 0.12
-
- Draws Keypoints on given RGB image.
- The values of the input image should be uint8 between 0 and 255.
-
- Args:
- image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
- keypoints (Tensor): Tensor of shape (num_instances, K, 2), the K keypoint locations for each of the N instances,
- in the format [x, y].
- connectivity (List[Tuple[int, int]]): A list of tuples, where each tuple contains a pair of keypoint
- indices to be connected.
- colors (str, Tuple): The color can be represented as
- PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
- radius (int): Integer denoting radius of keypoint.
- width (int): Integer denoting width of line connecting keypoints.
-
- Returns:
- img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
- """
-
- if not isinstance(image, torch.Tensor):
- raise TypeError(f"The image must be a tensor, got {type(image)}")
- elif image.dtype != torch.uint8:
- raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
- elif image.dim() != 3:
- raise ValueError("Pass individual images, not batches")
- elif image.size()[0] != 3:
- raise ValueError("Pass an RGB image. Other Image formats are not supported")
-
- if keypoints.ndim != 3:
- raise ValueError("keypoints must be of shape (num_instances, K, 2)")
-
- ndarr = image.permute(1, 2, 0).cpu().numpy()
- img_to_draw = Image.fromarray(ndarr)
- draw = ImageDraw.Draw(img_to_draw)
- img_kpts = keypoints.to(torch.int64).tolist()
-
- for kpt_id, kpt_inst in enumerate(img_kpts):
- for inst_id, kpt in enumerate(kpt_inst):
- x1 = kpt[0] - radius
- x2 = kpt[0] + radius
- y1 = kpt[1] - radius
- y2 = kpt[1] + radius
- draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
-
- if connectivity:
- for connection in connectivity:
- if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst):
- continue
- start_pt_x = kpt_inst[connection[0]][0]
- start_pt_y = kpt_inst[connection[0]][1]
-
- end_pt_x = kpt_inst[connection[1]][0]
- end_pt_y = kpt_inst[connection[1]][1]
-
- draw.line(
- ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
- width=width,
- )
-
- return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
-
-def visualize_batch(
- img: torch.Tensor, mask: torch.Tensor,
- vertices: torch.Tensor=None,
- E_mask: torch.Tensor=None,
- embed_map: torch.Tensor=None,
- semantic_mask: torch.Tensor=None,
- embedding: torch.Tensor=None,
- keypoints: torch.Tensor=None,
- maskrcnn_mask: torch.Tensor=None,
- **kwargs) -> torch.ByteTensor:
- img = denormalize_img(img).mul(255).byte()
- img = draw_mask(img, mask)
- if maskrcnn_mask is not None:
- img = draw_mask(img, maskrcnn_mask)
- if vertices is not None or embedding is not None:
- assert E_mask is not None
- assert embed_map is not None
- img = draw_cse(img, E_mask, embedding, embed_map, vertices)
- elif semantic_mask is not None:
- img = draw_segmentation_masks(img, semantic_mask)
- if keypoints is not None:
- keypoints = keypoints.clone()
- keypoints[:, :, 0] *= img.shape[-1]
- keypoints[:, :, 1] *= img.shape[-2]
- _, _, connectivity = get_coco_keypoints()
- connectivity = np.array(connectivity)
- for idx in range(img.shape[0]):
- if keypoints.shape[-1] == 3:
- visible = (keypoints[idx:idx+1, :, 2] > 0 ).view(-1)
- else:
- visible = torch.ones(keypoints.shape[1], device=keypoints.device, dtype=torch.bool)
-
- if keypoints.shape[1] == 17: # COCO Connectivity
- c = connectivity[visible.cpu().numpy()].tolist()
- else:
- c = None
-
- kp = keypoints[idx:idx+1, visible].long()
- img[idx] = draw_keypoints(img[idx], kp, colors="red", connectivity=c)
- return img
-
-
-@torch.no_grad()
-def draw_cse(
- img: torch.Tensor, E_seg: torch.Tensor,
- embedding: torch.Tensor = None,
- embed_map: torch.Tensor = None,
- vertices: torch.Tensor = None, t=0.7
- ):
- """
- E_seg: 1 for areas with embedding
- """
- assert img.dtype == torch.uint8
- img = img.view(-1, *img.shape[-3:])
- E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:])
- if vertices is None:
- assert embedding is not None
- assert embed_map is not None
- embedding = embedding.view(-1, *embedding.shape[-3:])
- vertices = torch.stack(
- [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map)
- for e, e_seg in zip(embedding, E_seg)])
-
- i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1)
- colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0])
- color_embed_map, _ = np.load(download_file("https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True)
- color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0]
- color_embed_map -= color_embed_map.min()
- color_embed_map /= color_embed_map.max()
- vertx2idx = (color_embed_map*255).long()
- vertx2colormap = colormap_JET[vertx2idx]
-
- vertices = vertices.view(-1, *vertices.shape[-2:])
- E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:])
- # This operation might be good to do on cpu...
- E_color = vertx2colormap[vertices.long()]
- E_color = E_color.to(E_seg.device)
- E_color = E_color.permute(0, 3, 1, 2)
- E_color = E_color*E_seg.byte()
-
- m = E_seg.bool().repeat(1, 3, 1, 1)
- img[m] = (img[m] * (1-t) + t * E_color[m]).byte()
- return img
-
-
-def draw_cse_all(
- embedding: List[torch.Tensor], E_mask: List[torch.Tensor],
- im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7):
- """
- E_seg: 1 for areas with embedding
- """
- assert len(im.shape) == 3, im.shape
- assert im.dtype == torch.uint8
-
- N = len(E_mask)
- im = im.clone()
- for i in range(N):
- assert len(E_mask[i].shape) == 2
- assert len(embedding[i].shape) == 3
- assert embed_map.shape[1] == embedding[i].shape[0]
- assert len(boxes_XYXY[i]) == 4
- E = embedding[i]
- x0, y0, x1, y1 = boxes_XYXY[i]
- E = F.resize(E, (y1-y0, x1-x0), antialias=True)
- s = E_mask[i].float()
- s = (F.resize(s.squeeze()[None], (y1-y0, x1-x0), antialias=True) > 0).float()
- box = boxes_XYXY[i]
-
- im_ = crop_box(im, box)
- s = remove_pad(s, box, im.shape[1:])
- E = remove_pad(E, box, im.shape[1:])
- E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None],embed_map=embed_map)[0]
- E_color = E_color.to(im.device)
- s = s.bool().repeat(3, 1, 1)
- crop_box(im, box)[s] = (im_[s] * (1-t) + t * E_color[s]).byte()
- return im
-
-
-
-@torch.no_grad()
-def draw_segmentation_masks(
- image: torch.Tensor,
- masks: torch.Tensor,
- alpha: float = 0.8,
- colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
-) -> torch.Tensor:
-
- """
- Draws segmentation masks on given RGB image.
- The values of the input image should be uint8 between 0 and 255.
-
- Args:
- image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
- masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
- alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
- 0 means full transparency, 1 means no transparency.
- colors (list or None): List containing the colors of the masks. The colors can
- be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
- When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
- with one element. By default, random colors are generated for each mask.
-
- Returns:
- img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
- """
-
- if not isinstance(image, torch.Tensor):
- raise TypeError(f"The image must be a tensor, got {type(image)}")
- elif image.dtype != torch.uint8:
- raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
- elif image.dim() != 3:
- raise ValueError("Pass individual images, not batches")
- elif image.size()[0] != 3:
- raise ValueError("Pass an RGB image. Other Image formats are not supported")
- if masks.ndim == 2:
- masks = masks[None, :, :]
- if masks.ndim != 3:
- raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
- if masks.dtype != torch.bool:
- raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
- if masks.shape[-2:] != image.shape[-2:]:
- raise ValueError("The image and the masks must have the same height and width")
- num_masks = masks.size()[0]
- if num_masks == 0:
- return image
- if colors is None:
- colors = _generate_color_palette(num_masks)
- if not isinstance(colors[0], (Tuple, List)):
- colors = [colors for i in range(num_masks)]
- if colors is not None and num_masks > len(colors):
- raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
-
-
- if not isinstance(colors, list):
- colors = [colors]
- if not isinstance(colors[0], (tuple, str)):
- raise ValueError("colors must be a tuple or a string, or a list thereof")
- if isinstance(colors[0], tuple) and len(colors[0]) != 3:
- raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
-
- out_dtype = torch.uint8
-
- colors_ = []
- for color in colors:
- if isinstance(color, str):
- color = ImageColor.getrgb(color)
- color = torch.tensor(color, dtype=out_dtype, device=masks.device)
- colors_.append(color)
- img_to_draw = image.detach().clone()
- # TODO: There might be a way to vectorize this
- for mask, color in zip(masks, colors_):
- img_to_draw[:, mask] = color[:, None]
-
- out = image * (1 - alpha) + img_to_draw * alpha
- return out.to(out_dtype)
-
-
-def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True):
- """
- Visualize mask where mask = 0.
- Supports multiple instances.
- mask shape: [N, C, H, W], where C is different instances in same image.
- """
- orig_imshape = im.shape
- if mask.numel() == 0: return im
- assert len(mask.shape) in (3, 4), mask.shape
- mask = mask.view(-1, *mask.shape[-3:])
- im = im.view(-1, *im.shape[-3:])
- assert im.dtype == torch.uint8, im.dtype
- assert 0 <= t <= 1
- if not visualize_instances:
- mask = mask.any(dim=1, keepdim=True)
- mask = mask.bool()
- kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device)
- outer_border = binary_dilation(mask, kernel).logical_xor(mask)
- outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0
- inner_border = binary_erosion(mask, kernel).logical_xor(mask)
- inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0
- mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1)
- color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1)#.repeat(1, *im.shape[1:])
- color = color.repeat(im.shape[0], 1, *im.shape[-2:])
- im[mask] = (im[mask] * (1-t) + t * color[mask]).byte()
- im[outer_border] = 255
- im[inner_border] = 0
- return im.view(*orig_imshape)
-
-
-def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs):
- for i, box in enumerate(boxes):
- x0, y0, x1, y1 = boxes[i]
- orig_shape = (y1-y0, x1-x0)
- m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None]
- m = remove_pad(m, boxes[i], im.shape[-2:])
- crop_box(im, boxes[i]).set_(draw_mask(crop_box(im, boxes[i]), m))
- return im
-
-
-def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs):
- n_boxes = boxes.shape[0]
- tops.assert_shape(all_keypoints, (n_boxes, 17, 3))
- im = im.clone()
- for i, box in enumerate(boxes):
-
- x0, y0, x1, y1 = boxes[i]
- orig_shape = (y1-y0, x1-x0)
- keypoints = all_keypoints[i].clone()
- keypoints[:, 0] *= orig_shape[1]
- keypoints[:, 1] *= orig_shape[0]
- keypoints = keypoints.long()
- _, _, connectivity = get_coco_keypoints()
- connectivity = np.array(connectivity)
- visible = (keypoints[:, 2] > 0)
- if keypoints.shape[0] == 17: # COCO Connectivity
- c = connectivity[visible.cpu().numpy()].tolist()
- else:
- c = None
- # Remove padding from keypoints before visualization
- keypoints[:, 0] += min(x0, 0)
- keypoints[:, 1] += min(y0, 0)
- im_with_kp = draw_keypoints(crop_box(im, box), keypoints[None, visible, :2], colors="red", connectivity=c)
- crop_box(im, box).copy_(im_with_kp)
- return im
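The visualization helpers deleted above all follow the same overlay pattern: select the masked pixels of a uint8 image tensor and alpha-blend a colour into them. A minimal standalone sketch of that pattern, assuming nothing beyond PyTorch itself (the name `overlay_mask` is illustrative, not from dp2):

```python
# Minimal sketch of the blend used by the deleted draw_mask-style helpers.
import torch

def overlay_mask(im: torch.Tensor, mask: torch.Tensor, t: float = 0.2,
                 color=(255, 255, 255)) -> torch.Tensor:
    """im: uint8 [3, H, W]; mask: bool [H, W]; t: blend factor in [0, 1]."""
    assert im.dtype == torch.uint8 and 0.0 <= t <= 1.0
    out = im.clone()
    c = torch.tensor(color, dtype=torch.uint8).view(3, 1, 1).expand_as(im)
    m = mask.unsqueeze(0).expand_as(im)          # broadcast mask to all 3 channels
    out[m] = (im[m].float() * (1 - t) + t * c[m].float()).byte()
    return out

# Example: highlight the lower half of a random image.
img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
mask = torch.zeros(64, 64, dtype=torch.bool)
mask[32:] = True
blended = overlay_mask(img, mask, t=0.5)
```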
diff --git a/media b/media
new file mode 120000
index 0000000000000000000000000000000000000000..4de87a74095d5613c3ca247cb84bd26f286e1ada
--- /dev/null
+++ b/media
@@ -0,0 +1 @@
+deep_privacy2/media
\ No newline at end of file
diff --git a/sg3_torch_utils/LICENSE.txt b/sg3_torch_utils/LICENSE.txt
deleted file mode 100644
index 6b5ee9bf994cc9441cb659c3527160b4ee5bcb33..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/LICENSE.txt
+++ /dev/null
@@ -1,97 +0,0 @@
-Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
-
-
-NVIDIA Source Code License for StyleGAN3
-
-
-=======================================================================
-
-1. Definitions
-
-"Licensor" means any person or entity that distributes its Work.
-
-"Software" means the original work of authorship made available under
-this License.
-
-"Work" means the Software and any additions to or derivative works of
-the Software that are made available under this License.
-
-The terms "reproduce," "reproduction," "derivative works," and
-"distribution" have the meaning as provided under U.S. copyright law;
-provided, however, that for the purposes of this License, derivative
-works shall not include works that remain separable from, or merely
-link (or bind by name) to the interfaces of, the Work.
-
-Works, including the Software, are "made available" under this License
-by including in or with the Work either (a) a copyright notice
-referencing the applicability of this License to the Work, or (b) a
-copy of this License.
-
-2. License Grants
-
- 2.1 Copyright Grant. Subject to the terms and conditions of this
- License, each Licensor grants to you a perpetual, worldwide,
- non-exclusive, royalty-free, copyright license to reproduce,
- prepare derivative works of, publicly display, publicly perform,
- sublicense and distribute its Work and any resulting derivative
- works in any form.
-
-3. Limitations
-
- 3.1 Redistribution. You may reproduce or distribute the Work only
- if (a) you do so under this License, (b) you include a complete
- copy of this License with your distribution, and (c) you retain
- without modification any copyright, patent, trademark, or
- attribution notices that are present in the Work.
-
- 3.2 Derivative Works. You may specify that additional or different
- terms apply to the use, reproduction, and distribution of your
- derivative works of the Work ("Your Terms") only if (a) Your Terms
- provide that the use limitation in Section 3.3 applies to your
- derivative works, and (b) you identify the specific derivative
- works that are subject to Your Terms. Notwithstanding Your Terms,
- this License (including the redistribution requirements in Section
- 3.1) will continue to apply to the Work itself.
-
- 3.3 Use Limitation. The Work and any derivative works thereof only
- may be used or intended for use non-commercially. Notwithstanding
- the foregoing, NVIDIA and its affiliates may use the Work and any
- derivative works commercially. As used herein, "non-commercially"
- means for research or evaluation purposes only.
-
- 3.4 Patent Claims. If you bring or threaten to bring a patent claim
- against any Licensor (including any claim, cross-claim or
- counterclaim in a lawsuit) to enforce any patents that you allege
- are infringed by any Work, then your rights under this License from
- such Licensor (including the grant in Section 2.1) will terminate
- immediately.
-
- 3.5 Trademarks. This License does not grant any rights to use any
- Licensor’s or its affiliates’ names, logos, or trademarks, except
- as necessary to reproduce the notices described in this License.
-
- 3.6 Termination. If you violate any term of this License, then your
- rights under this License (including the grant in Section 2.1) will
- terminate immediately.
-
-4. Disclaimer of Warranty.
-
-THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
-NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
-THIS LICENSE.
-
-5. Limitation of Liability.
-
-EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
-THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
-SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
-INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
-OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
-(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
-LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
-COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
-THE POSSIBILITY OF SUCH DAMAGES.
-
-=======================================================================
diff --git a/sg3_torch_utils/__init__.py b/sg3_torch_utils/__init__.py
deleted file mode 100755
index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-# empty
diff --git a/sg3_torch_utils/custom_ops.py b/sg3_torch_utils/custom_ops.py
deleted file mode 100755
index 4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/custom_ops.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import os
-import glob
-import torch
-import torch.utils.cpp_extension
-import importlib
-import hashlib
-import shutil
-from pathlib import Path
-
-from torch.utils.file_baton import FileBaton
-
-#----------------------------------------------------------------------------
-# Global options.
-
-verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
-
-#----------------------------------------------------------------------------
-# Internal helper funcs.
-
-def _find_compiler_bindir():
- patterns = [
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
- 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
- ]
- for pattern in patterns:
- matches = sorted(glob.glob(pattern))
- if len(matches):
- return matches[-1]
- return None
-
-#----------------------------------------------------------------------------
-# Main entry point for compiling and loading C++/CUDA plugins.
-
-_cached_plugins = dict()
-
-def get_plugin(module_name, sources, **build_kwargs):
- assert verbosity in ['none', 'brief', 'full']
-
- # Already cached?
- if module_name in _cached_plugins:
- return _cached_plugins[module_name]
-
- # Print status.
- if verbosity == 'full':
- print(f'Setting up PyTorch plugin "{module_name}"...')
- elif verbosity == 'brief':
- print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
-
- try: # pylint: disable=too-many-nested-blocks
- # Make sure we can find the necessary compiler binaries.
- if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
- compiler_bindir = _find_compiler_bindir()
- if compiler_bindir is None:
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
- os.environ['PATH'] += ';' + compiler_bindir
-
- # Compile and load.
- verbose_build = (verbosity == 'full')
-
- # Incremental build md5sum trickery. Copies all the input source files
- # into a cached build directory under a combined md5 digest of the input
- # source files. Copying is done only if the combined digest has changed.
- # This keeps input file timestamps and filenames the same as in previous
- # extension builds, allowing for fast incremental rebuilds.
- #
- # This optimization is done only in case all the source files reside in
- # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
- # environment variable is set (we take this as a signal that the user
- # actually cares about this.)
- source_dirs_set = set(os.path.dirname(source) for source in sources)
- if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
- all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
-
- # Compute a combined hash digest for all source files in the same
- # custom op directory (usually .cu, .cpp, .py and .h files).
- hash_md5 = hashlib.md5()
- for src in all_source_files:
- with open(src, 'rb') as f:
- hash_md5.update(f.read())
- build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
- digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
-
- if not os.path.isdir(digest_build_dir):
- os.makedirs(digest_build_dir, exist_ok=True)
- baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
- if baton.try_acquire():
- try:
- for src in all_source_files:
- shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
- finally:
- baton.release()
- else:
- # Someone else is copying source files under the digest dir,
- # wait until done and continue.
- baton.wait()
- digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
- torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
- verbose=verbose_build, sources=digest_sources, **build_kwargs)
- else:
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
- module = importlib.import_module(module_name)
-
- except:
- if verbosity == 'brief':
- print('Failed!')
- raise
-
- # Print status and add to cache.
- if verbosity == 'full':
- print(f'Done setting up PyTorch plugin "{module_name}".')
- elif verbosity == 'brief':
- print('Done.')
- _cached_plugins[module_name] = module
- return module
-
-#----------------------------------------------------------------------------
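The comment block above describes the incremental-build trick in the deleted `get_plugin()`: the cached build directory is keyed on a combined MD5 digest of all source files, so unchanged sources reuse a previous compilation. A small sketch of that digest computation; the helper name is assumed, not part of the module:

```python
# Illustrative sketch of the "combined md5 over all sources" cache key.
import hashlib
from pathlib import Path

def combined_digest(source_dir: str) -> str:
    """Return one MD5 digest over every file in source_dir, in sorted order."""
    h = hashlib.md5()
    for src in sorted(p for p in Path(source_dir).iterdir() if p.is_file()):
        h.update(src.read_bytes())
    return h.hexdigest()

# A build directory such as f"{build_root}/{combined_digest(src_dir)}" only
# changes when the sources themselves change, enabling fast incremental rebuilds.
```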
diff --git a/sg3_torch_utils/misc.py b/sg3_torch_utils/misc.py
deleted file mode 100755
index 10d8e31880affdd185580b6f5b98e92c79597dc3..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/misc.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import re
-import contextlib
-import numpy as np
-import torch
-import warnings
-
-#----------------------------------------------------------------------------
-# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
-# same constant is used multiple times.
-
-_constant_cache = dict()
-
-def constant(value, shape=None, dtype=None, device=None, memory_format=None):
- value = np.asarray(value)
- if shape is not None:
- shape = tuple(shape)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device('cpu')
- if memory_format is None:
- memory_format = torch.contiguous_format
-
- key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
- tensor = _constant_cache.get(key, None)
- if tensor is None:
- tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
- if shape is not None:
- tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
- tensor = tensor.contiguous(memory_format=memory_format)
- _constant_cache[key] = tensor
- return tensor
-
-#----------------------------------------------------------------------------
-# Replace NaN/Inf with specified numerical values.
-
-try:
- nan_to_num = torch.nan_to_num # 1.8.0a0
-except AttributeError:
- def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
- assert isinstance(input, torch.Tensor)
- if posinf is None:
- posinf = torch.finfo(input.dtype).max
- if neginf is None:
- neginf = torch.finfo(input.dtype).min
- assert nan == 0
- return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
-
-#----------------------------------------------------------------------------
-# Symbolic assert.
-
-try:
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
-except AttributeError:
- symbolic_assert = torch.Assert # 1.7.0
-
-#----------------------------------------------------------------------------
-# Context manager to suppress known warnings in torch.jit.trace().
-
-class suppress_tracer_warnings(warnings.catch_warnings):
- def __enter__(self):
- super().__enter__()
- warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
- return self
-
-#----------------------------------------------------------------------------
-# Assert that the shape of a tensor matches the given list of integers.
-# None indicates that the size of a dimension is allowed to vary.
-# Performs symbolic assertion when used in torch.jit.trace().
-
-def assert_shape(tensor, ref_shape):
- if tensor.ndim != len(ref_shape):
- raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
- for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
- if ref_size is None:
- pass
- elif isinstance(ref_size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
- elif isinstance(size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
- elif size != ref_size:
- raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
-
-#----------------------------------------------------------------------------
-# Function decorator that calls torch.autograd.profiler.record_function().
-
-def profiled_function(fn):
- def decorator(*args, **kwargs):
- with torch.autograd.profiler.record_function(fn.__name__):
- return fn(*args, **kwargs)
- decorator.__name__ = fn.__name__
- return decorator
-
-#----------------------------------------------------------------------------
-# Sampler for torch.utils.data.DataLoader that loops over the dataset
-# indefinitely, shuffling items as it goes.
-
-class InfiniteSampler(torch.utils.data.Sampler):
- def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
- assert len(dataset) > 0
- assert num_replicas > 0
- assert 0 <= rank < num_replicas
- assert 0 <= window_size <= 1
- super().__init__(dataset)
- self.dataset = dataset
- self.rank = rank
- self.num_replicas = num_replicas
- self.shuffle = shuffle
- self.seed = seed
- self.window_size = window_size
-
- def __iter__(self):
- order = np.arange(len(self.dataset))
- rnd = None
- window = 0
- if self.shuffle:
- rnd = np.random.RandomState(self.seed)
- rnd.shuffle(order)
- window = int(np.rint(order.size * self.window_size))
-
- idx = 0
- while True:
- i = idx % order.size
- if idx % self.num_replicas == self.rank:
- yield order[i]
- if window >= 2:
- j = (i - rnd.randint(window)) % order.size
- order[i], order[j] = order[j], order[i]
- idx += 1
-
-#----------------------------------------------------------------------------
-# Utilities for operating with torch.nn.Module parameters and buffers.
-
-def params_and_buffers(module):
- assert isinstance(module, torch.nn.Module)
- return list(module.parameters()) + list(module.buffers())
-
-def named_params_and_buffers(module):
- assert isinstance(module, torch.nn.Module)
- return list(module.named_parameters()) + list(module.named_buffers())
-
-def copy_params_and_buffers(src_module, dst_module, require_all=False):
- assert isinstance(src_module, torch.nn.Module)
- assert isinstance(dst_module, torch.nn.Module)
- src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
- for name, tensor in named_params_and_buffers(dst_module):
- assert (name in src_tensors) or (not require_all)
- if name in src_tensors:
- tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
-
-#----------------------------------------------------------------------------
-# Context manager for easily enabling/disabling DistributedDataParallel
-# synchronization.
-
-@contextlib.contextmanager
-def ddp_sync(module, sync):
- assert isinstance(module, torch.nn.Module)
- if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
- yield
- else:
- with module.no_sync():
- yield
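The deleted `InfiniteSampler` loops over a shuffled index order indefinitely, occasionally swapping nearby entries so the order keeps drifting instead of repeating identical epochs. A self-contained sketch of that idea for the single-replica case (the generator name is assumed):

```python
# Sketch of the windowed infinite shuffle behind InfiniteSampler.
import numpy as np

def infinite_indices(n: int, seed: int = 0, window_size: float = 0.5):
    order = np.arange(n)
    rnd = np.random.RandomState(seed)
    rnd.shuffle(order)
    window = int(np.rint(n * window_size))
    idx = 0
    while True:
        i = idx % n
        yield int(order[i])
        if window >= 2:
            # Swap the just-yielded index with a nearby one so the order drifts.
            j = (i - rnd.randint(window)) % n
            order[i], order[j] = order[j], order[i]
        idx += 1

gen = infinite_indices(10)
first_20 = [next(gen) for _ in range(20)]  # indices keep coming past one "epoch"
```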
diff --git a/sg3_torch_utils/ops/__init__.py b/sg3_torch_utils/ops/__init__.py
deleted file mode 100755
index ece0ea08fe2e939cc260a1dafc0ab5b391b773d9..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-# empty
diff --git a/sg3_torch_utils/ops/bias_act.cpp b/sg3_torch_utils/ops/bias_act.cpp
deleted file mode 100755
index 5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/bias_act.cpp
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include "bias_act.h"
-
-//------------------------------------------------------------------------
-
-static bool has_same_layout(torch::Tensor x, torch::Tensor y)
-{
- if (x.dim() != y.dim())
- return false;
- for (int64_t i = 0; i < x.dim(); i++)
- {
- if (x.size(i) != y.size(i))
- return false;
- if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
- return false;
- }
- return true;
-}
-
-//------------------------------------------------------------------------
-
-static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
-{
- // Validate arguments.
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
- TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
- TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
- TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
- TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
- TORCH_CHECK(b.dim() == 1, "b must have rank 1");
- TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
- TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
- TORCH_CHECK(grad >= 0, "grad must be non-negative");
-
- // Validate layout.
- TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
- TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
- TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
- TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
- TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
-
- // Create output tensor.
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
- torch::Tensor y = torch::empty_like(x);
- TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
-
- // Initialize CUDA kernel parameters.
- bias_act_kernel_params p;
- p.x = x.data_ptr();
- p.b = (b.numel()) ? b.data_ptr() : NULL;
- p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
- p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
- p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
- p.y = y.data_ptr();
- p.grad = grad;
- p.act = act;
- p.alpha = alpha;
- p.gain = gain;
- p.clamp = clamp;
- p.sizeX = (int)x.numel();
- p.sizeB = (int)b.numel();
- p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
-
- // Choose CUDA kernel.
- void* kernel;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
- {
- kernel = choose_bias_act_kernel<scalar_t>(p);
- });
- TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
-
- // Launch CUDA kernel.
- p.loopX = 4;
- int blockSize = 4 * 32;
- int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
- void* args[] = {&p};
- AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
- return y;
-}
-
-//------------------------------------------------------------------------
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
- m.def("bias_act", &bias_act);
-}
-
-//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/bias_act.cu b/sg3_torch_utils/ops/bias_act.cu
deleted file mode 100755
index dd8fc4756d7d94727f94af738665b68d9c518880..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/bias_act.cu
+++ /dev/null
@@ -1,173 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <c10/util/Half.h>
-#include "bias_act.h"
-
-//------------------------------------------------------------------------
-// Helpers.
-
-template <class T> struct InternalType;
-template <> struct InternalType<double> { typedef double scalar_t; };
-template <> struct InternalType<float> { typedef float scalar_t; };
-template <> struct InternalType<c10::Half> { typedef float scalar_t; };
-
-//------------------------------------------------------------------------
-// CUDA kernel.
-
-template <class T, int A>
-__global__ void bias_act_kernel(bias_act_kernel_params p)
-{
- typedef typename InternalType<T>::scalar_t scalar_t;
- int G = p.grad;
- scalar_t alpha = (scalar_t)p.alpha;
- scalar_t gain = (scalar_t)p.gain;
- scalar_t clamp = (scalar_t)p.clamp;
- scalar_t one = (scalar_t)1;
- scalar_t two = (scalar_t)2;
- scalar_t expRange = (scalar_t)80;
- scalar_t halfExpRange = (scalar_t)40;
- scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
- scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
-
- // Loop over elements.
- int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
- for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
- {
- // Load.
- scalar_t x = (scalar_t)((const T*)p.x)[xi];
- scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
- scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
- scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
- scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
- scalar_t yy = (gain != 0) ? yref / gain : 0;
- scalar_t y = 0;
-
- // Apply bias.
- ((G == 0) ? x : xref) += b;
-
- // linear
- if (A == 1)
- {
- if (G == 0) y = x;
- if (G == 1) y = x;
- }
-
- // relu
- if (A == 2)
- {
- if (G == 0) y = (x > 0) ? x : 0;
- if (G == 1) y = (yy > 0) ? x : 0;
- }
-
- // lrelu
- if (A == 3)
- {
- if (G == 0) y = (x > 0) ? x : x * alpha;
- if (G == 1) y = (yy > 0) ? x : x * alpha;
- }
-
- // tanh
- if (A == 4)
- {
- if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
- if (G == 1) y = x * (one - yy * yy);
- if (G == 2) y = x * (one - yy * yy) * (-two * yy);
- }
-
- // sigmoid
- if (A == 5)
- {
- if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
- if (G == 1) y = x * yy * (one - yy);
- if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
- }
-
- // elu
- if (A == 6)
- {
- if (G == 0) y = (x >= 0) ? x : exp(x) - one;
- if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
- }
-
- // selu
- if (A == 7)
- {
- if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
- if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
- }
-
- // softplus
- if (A == 8)
- {
- if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
- if (G == 1) y = x * (one - exp(-yy));
- if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
- }
-
- // swish
- if (A == 9)
- {
- if (G == 0)
- y = (x < -expRange) ? 0 : x / (exp(-x) + one);
- else
- {
- scalar_t c = exp(xref);
- scalar_t d = c + one;
- if (G == 1)
- y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
- else
- y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
- yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
- }
- }
-
- // Apply gain.
- y *= gain * dy;
-
- // Clamp.
- if (clamp >= 0)
- {
- if (G == 0)
- y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
- else
- y = (yref > -clamp & yref < clamp) ? y : 0;
- }
-
- // Store.
- ((T*)p.y)[xi] = (T)y;
- }
-}
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
-{
- if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
- if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
- if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
- if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
- if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
- if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
- if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
- if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
- if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
- return NULL;
-}
-
-//------------------------------------------------------------------------
-// Template specializations.
-
-template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
-template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
-template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
-
-//------------------------------------------------------------------------
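For reference, here are the `A == 3` (lrelu) branches of the kernel above written out in PyTorch, to show what the per-element forward pass (`G == 0`) and first-order gradient (`G == 1`) compute. This is an illustrative re-derivation, not the kernel itself; function names and default values mirror the lrelu entry in `activation_funcs` below.

```python
# Sketch: fused bias + lrelu + gain, and the matching first-order gradient.
import torch

def lrelu_forward(x, b, alpha=0.2, gain=2 ** 0.5):
    y = x + b                               # "Apply bias."
    y = torch.where(y > 0, y, y * alpha)    # A == 3, G == 0 branch
    return y * gain                         # "Apply gain."

def lrelu_grad(dy, y_ref, alpha=0.2, gain=2 ** 0.5):
    yy = y_ref / gain                       # recover the sign of the pre-gain activation
    return torch.where(yy > 0, dy, dy * alpha) * gain   # A == 3, G == 1 branch

x = torch.randn(4, requires_grad=True)
b = torch.zeros(4)
y = lrelu_forward(x, b)
(dx,) = torch.autograd.grad(y.sum(), x)
assert torch.allclose(dx, lrelu_grad(torch.ones(4), y.detach()))
```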
diff --git a/sg3_torch_utils/ops/bias_act.h b/sg3_torch_utils/ops/bias_act.h
deleted file mode 100755
index a32187e1fb7e3bae509d4eceaf900866866875a4..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/bias_act.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-//------------------------------------------------------------------------
-// CUDA kernel parameters.
-
-struct bias_act_kernel_params
-{
- const void* x; // [sizeX]
- const void* b; // [sizeB] or NULL
- const void* xref; // [sizeX] or NULL
- const void* yref; // [sizeX] or NULL
- const void* dy; // [sizeX] or NULL
- void* y; // [sizeX]
-
- int grad;
- int act;
- float alpha;
- float gain;
- float clamp;
-
- int sizeX;
- int sizeB;
- int stepB;
- int loopX;
-};
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
-
-//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/bias_act.py b/sg3_torch_utils/ops/bias_act.py
deleted file mode 100755
index 7c39717268055fafe737419486cf96f1f93f4fb5..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/bias_act.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom PyTorch ops for efficient bias and activation."""
-
-import os
-import warnings
-import numpy as np
-import torch
-import traceback
-
-from .. import custom_ops
-from easydict import EasyDict
-from torch.cuda.amp import custom_bwd, custom_fwd
-#----------------------------------------------------------------------------
-
-activation_funcs = {
- 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
- 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
- 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
- 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
- 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
- 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
- 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
- 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
- 'swish': EasyDict(func=lambda x, **_: torch.nn.functional.silu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
-}
-
-#----------------------------------------------------------------------------
-
-_inited = False
-_plugin = None
-enabled = False
-_null_tensor = torch.empty([0])
-
-def _init():
- global _inited, _plugin
- if not _inited:
- _inited = True
- sources = ['bias_act.cpp', 'bias_act.cu']
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
- try:
- _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
- except:
- warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
- return _plugin is not None
-
-#----------------------------------------------------------------------------
-
-def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
- r"""Fused bias and activation function.
-
- Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
- and scales the result by `gain`. Each of the steps is optional. In most cases,
- the fused op is considerably more efficient than performing the same calculation
- using standard PyTorch ops. It supports first and second order gradients,
- but not third order gradients.
-
- Args:
- x: Input activation tensor. Can be of any shape.
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
- as `x`. The shape must be known, and it must match the dimension of `x`
- corresponding to `dim`.
- dim: The dimension in `x` corresponding to the elements of `b`.
- The value of `dim` is ignored if `b` is not specified.
- act: Name of the activation function to evaluate, or `"linear"` to disable.
- Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
- See `activation_funcs` for a full list. `None` is not allowed.
- alpha: Shape parameter for the activation function, or `None` to use the default.
- gain: Scaling factor for the output tensor, or `None` to use default.
- See `activation_funcs` for the default scaling of each activation function.
- If unsure, consider specifying 1.
- clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
- the clamping (default).
- impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
-
- Returns:
- Tensor of the same shape and datatype as `x`.
- """
- assert isinstance(x, torch.Tensor)
- assert impl in ['ref', 'cuda']
- if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
- return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
- return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
-
-#----------------------------------------------------------------------------
-
-def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
- """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
- """
- assert isinstance(x, torch.Tensor)
- assert clamp is None or clamp >= 0
- spec = activation_funcs[act]
- alpha = float(alpha if alpha is not None else spec.def_alpha)
- gain = float(gain if gain is not None else spec.def_gain)
- clamp = float(clamp if clamp is not None else -1)
-
- # Add bias.
- if b is not None:
- assert isinstance(b, torch.Tensor) and b.ndim == 1
- assert 0 <= dim < x.ndim
- assert b.shape[0] == x.shape[dim]
- x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
-
- # Evaluate activation function.
- alpha = float(alpha)
- x = spec.func(x, alpha=alpha)
-
- # Scale by gain.
- gain = float(gain)
- if gain != 1:
- x = x * gain
-
- # Clamp.
- if clamp >= 0:
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
- return x
-
-#----------------------------------------------------------------------------
-
-_bias_act_cuda_cache = dict()
-
-def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
- """Fast CUDA implementation of `bias_act()` using custom ops.
- """
- # Parse arguments.
- assert clamp is None or clamp >= 0
- spec = activation_funcs[act]
- alpha = float(alpha if alpha is not None else spec.def_alpha)
- gain = float(gain if gain is not None else spec.def_gain)
- clamp = float(clamp if clamp is not None else -1)
-
- # Lookup from cache.
- key = (dim, act, alpha, gain, clamp)
- if key in _bias_act_cuda_cache:
- return _bias_act_cuda_cache[key]
-
- # Forward op.
- class BiasActCuda(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, x, b): # pylint: disable=arguments-differ
- ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
- x = x.contiguous(memory_format=ctx.memory_format)
- b = b.contiguous() if b is not None else _null_tensor
- y = x
- if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
- y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
- ctx.save_for_backward(
- x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
- b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
- y if 'y' in spec.ref else _null_tensor)
- return y
-
- @staticmethod
- @custom_bwd
- def backward(ctx, dy): # pylint: disable=arguments-differ
- dy = dy.contiguous(memory_format=ctx.memory_format)
- x, b, y = ctx.saved_tensors
- dx = None
- db = None
-
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
- dx = dy
- if act != 'linear' or gain != 1 or clamp >= 0:
- dx = BiasActCudaGrad.apply(dy, x, b, y)
-
- if ctx.needs_input_grad[1]:
- db = dx.sum([i for i in range(dx.ndim) if i != dim])
-
- return dx, db
-
- # Backward op.
- class BiasActCudaGrad(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
- ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
- dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
- ctx.save_for_backward(
- dy if spec.has_2nd_grad else _null_tensor,
- x, b, y)
- return dx
-
- @staticmethod
- @custom_bwd
- def backward(ctx, d_dx): # pylint: disable=arguments-differ
- d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
- dy, x, b, y = ctx.saved_tensors
- d_dy = None
- d_x = None
- d_b = None
- d_y = None
-
- if ctx.needs_input_grad[0]:
- d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
-
- if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
- d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
-
- if spec.has_2nd_grad and ctx.needs_input_grad[2]:
- d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
-
- return d_dy, d_x, d_b, d_y
-
- # Add to cache.
- _bias_act_cuda_cache[key] = BiasActCuda
- return BiasActCuda
-
-#----------------------------------------------------------------------------
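A usage sketch for the deleted `bias_act()` (reference path, `impl='ref'`): the fused add-bias, activation, gain, clamp chain matches the same steps written with standard PyTorch ops. The call into the removed module is commented out, since this patch deletes it:

```python
# Sketch: bias_act(..., impl='ref') versus the equivalent manual chain of ops.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)

# from sg3_torch_utils.ops.bias_act import bias_act   # module removed by this diff
# y = bias_act(x, b, dim=1, act='lrelu', gain=2 ** 0.5, clamp=256, impl='ref')

y_manual = F.leaky_relu(x + b.view(1, -1, 1, 1), 0.2) * (2 ** 0.5)
y_manual = y_manual.clamp(-256, 256)
# torch.testing.assert_close(y, y_manual)             # the two paths agree
```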
diff --git a/sg3_torch_utils/ops/conv2d_gradfix.py b/sg3_torch_utils/ops/conv2d_gradfix.py
deleted file mode 100755
index e66591f19fad68760d3df7c9737a14574b70ee83..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/conv2d_gradfix.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom replacement for `torch.nn.functional.conv2d` that supports
-arbitrarily high order gradients with zero performance penalty."""
-
-import warnings
-import contextlib
-import torch
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-# pylint: disable=redefined-builtin
-# pylint: disable=arguments-differ
-# pylint: disable=protected-access
-
-#----------------------------------------------------------------------------
-
-enabled = False # Enable the custom op by setting this to true.
-weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
-
-@contextlib.contextmanager
-def no_weight_gradients():
- global weight_gradients_disabled
- old = weight_gradients_disabled
- weight_gradients_disabled = True
- yield
- weight_gradients_disabled = old
-
-#----------------------------------------------------------------------------
-
-def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
- if _should_use_custom_op(input):
- return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
- return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
-
-def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
- if _should_use_custom_op(input):
- return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
- return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
-
-#----------------------------------------------------------------------------
-
-def _should_use_custom_op(input):
- assert isinstance(input, torch.Tensor)
- if (not enabled) or (not torch.backends.cudnn.enabled):
- return False
- if input.device.type != 'cuda':
- return False
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']):
- return True
- warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
- return False
-
-def _tuple_of_ints(xs, ndim):
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
- assert len(xs) == ndim
- assert all(isinstance(x, int) for x in xs)
- return xs
-
-#----------------------------------------------------------------------------
-
-_conv2d_gradfix_cache = dict()
-
-def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
- # Parse arguments.
- ndim = 2
- weight_shape = tuple(weight_shape)
- stride = _tuple_of_ints(stride, ndim)
- padding = _tuple_of_ints(padding, ndim)
- output_padding = _tuple_of_ints(output_padding, ndim)
- dilation = _tuple_of_ints(dilation, ndim)
-
- # Lookup from cache.
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
- if key in _conv2d_gradfix_cache:
- return _conv2d_gradfix_cache[key]
-
- # Validate arguments.
- assert groups >= 1
- assert len(weight_shape) == ndim + 2
- assert all(stride[i] >= 1 for i in range(ndim))
- assert all(padding[i] >= 0 for i in range(ndim))
- assert all(dilation[i] >= 0 for i in range(ndim))
- if not transpose:
- assert all(output_padding[i] == 0 for i in range(ndim))
- else: # transpose
- assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
-
- # Helpers.
- common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
- def calc_output_padding(input_shape, output_shape):
- if transpose:
- return [0, 0]
- return [
- input_shape[i + 2]
- - (output_shape[i + 2] - 1) * stride[i]
- - (1 - 2 * padding[i])
- - dilation[i] * (weight_shape[i + 2] - 1)
- for i in range(ndim)
- ]
-
- # Forward & backward.
- class Conv2d(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, input, weight, bias):
- assert weight.shape == weight_shape
- if not transpose:
- output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
- else: # transpose
- output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
- ctx.save_for_backward(input, weight)
- return output
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
- input, weight = ctx.saved_tensors
- grad_input = None
- grad_weight = None
- grad_bias = None
-
- if ctx.needs_input_grad[0]:
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
- grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output.float(), weight.float(), None)
- assert grad_input.shape == input.shape
-
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
- grad_weight = Conv2dGradWeight.apply(grad_output.float(), input.float())
- assert grad_weight.shape == weight_shape
-
- if ctx.needs_input_grad[2]:
- grad_bias = grad_output.float().sum([0, 2, 3])
-
- return grad_input, grad_weight, grad_bias
-
- # Gradient with respect to the weights.
- class Conv2dGradWeight(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, grad_output, input):
- op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
- flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
- grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
- assert grad_weight.shape == weight_shape
- ctx.save_for_backward(grad_output, input)
- return grad_weight
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad2_grad_weight):
- grad_output, input = ctx.saved_tensors
- grad2_grad_output = None
- grad2_input = None
-
- if ctx.needs_input_grad[0]:
- grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
- assert grad2_grad_output.shape == grad_output.shape
-
- if ctx.needs_input_grad[1]:
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
- grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
- assert grad2_input.shape == input.shape
-
- return grad2_grad_output, grad2_input
-
- _conv2d_gradfix_cache[key] = Conv2d
- return Conv2d
-
-#----------------------------------------------------------------------------
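A sketch of how the deleted `conv2d_gradfix` module is used: `conv2d()` silently falls back to `torch.nn.functional.conv2d` unless the custom op is enabled and the input is on CUDA, and a backward pass run inside `no_weight_gradients()` skips gradients with respect to the weights. Calls into the removed module are shown commented out; only the fallback path runs here:

```python
# Sketch of the conv2d_gradfix entry points; illustrative, not the module itself.
import torch

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)

# from sg3_torch_utils.ops import conv2d_gradfix       # removed by this diff
# conv2d_gradfix.enabled = True                         # opt in on supported PyTorch/CUDA
# y = conv2d_gradfix.conv2d(x, w, padding=1)
# with conv2d_gradfix.no_weight_gradients():            # backward run inside the context
#     (gx,) = torch.autograd.grad(y.sum(), [x], create_graph=True)

y = torch.nn.functional.conv2d(x, w, padding=1)         # the fallback path used on CPU
```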
diff --git a/sg3_torch_utils/ops/conv2d_resample.py b/sg3_torch_utils/ops/conv2d_resample.py
deleted file mode 100755
index 4a999b58b36a5da53752024e86a9ebdf9c031d97..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/conv2d_resample.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""2D convolution with optional up/downsampling."""
-
-import torch
-
-from .. import misc
-from . import conv2d_gradfix
-from . import upfirdn2d
-from .upfirdn2d import _parse_padding
-from .upfirdn2d import _get_filter_size
-
-#----------------------------------------------------------------------------
-
-def _get_weight_shape(w):
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
- shape = [int(sz) for sz in w.shape]
- misc.assert_shape(w, shape)
- return shape
-
-#----------------------------------------------------------------------------
-
-def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
- """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
- """
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
-
- # Flip weight if requested.
- if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
- w = w.flip([2, 3])
-
- # Otherwise => execute using conv2d_gradfix.
- op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
- return op(x, w, stride=stride, padding=padding, groups=groups)
-
-#----------------------------------------------------------------------------
-
-@misc.profiled_function
-def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
- r"""2D convolution with optional up/downsampling.
-
- Padding is performed only once at the beginning, not between the operations.
-
- Args:
- x: Input tensor of shape
- `[batch_size, in_channels, in_height, in_width]`.
- w: Weight tensor of shape
- `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
- f: Low-pass filter for up/downsampling. Must be prepared beforehand by
- calling upfirdn2d.setup_filter(). None = identity (default).
- up: Integer upsampling factor (default: 1).
- down: Integer downsampling factor (default: 1).
- padding: Padding with respect to the upsampled image. Can be a single number
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- groups: Split input channels into N groups (default: 1).
- flip_weight: False = convolution, True = correlation (default: True).
- flip_filter: False = convolution, True = correlation (default: False).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- # Validate arguments.
- assert isinstance(x, torch.Tensor) and (x.ndim == 4)
- assert isinstance(w, torch.Tensor) and (w.ndim == 4)
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
- assert isinstance(up, int) and (up >= 1)
- assert isinstance(down, int) and (down >= 1)
- assert isinstance(groups, int) and (groups >= 1)
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
- fw, fh = _get_filter_size(f)
- px0, px1, py0, py1 = _parse_padding(padding)
-
- # Adjust padding to account for up/downsampling.
- if up > 1:
- px0 += (fw + up - 1) // 2
- px1 += (fw - up) // 2
- py0 += (fh + up - 1) // 2
- py1 += (fh - up) // 2
- if down > 1:
- px0 += (fw - down + 1) // 2
- px1 += (fw - down) // 2
- py0 += (fh - down + 1) // 2
- py1 += (fh - down) // 2
-
- # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
- if kw == 1 and kh == 1 and (down > 1 and up == 1):
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- return x
-
- # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
- if kw == 1 and kh == 1 and (up > 1 and down == 1):
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
- return x
-
- # Fast path: downsampling only => use strided convolution.
- if down > 1 and up == 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
- x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
- return x
-
- # Fast path: upsampling with optional downsampling => use transpose strided convolution.
- if up > 1:
- if groups == 1:
- w = w.transpose(0, 1)
- else:
- w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
- w = w.transpose(1, 2)
- w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
- px0 -= kw - 1
- px1 -= kw - up
- py0 -= kh - 1
- py1 -= kh - up
- pxt = max(min(-px0, -px1), 0)
- pyt = max(min(-py0, -py1), 0)
- x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
- if down > 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
- return x
-
- # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
- if up == 1 and down == 1:
- if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
- return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
-
- # Fallback: Generic reference implementation.
- x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- if down > 1:
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
- return x
-
-#----------------------------------------------------------------------------
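
Side note on the fast paths above: they hinge on a 1x1 convolution commuting with zero-insertion resampling, so the convolution can be moved to whichever side of the resampling is cheaper. Below is a minimal check of that property in plain PyTorch, with arbitrary shapes and the FIR filtering step omitted for brevity; it is an illustration only, not part of the removed module.

import torch
import torch.nn.functional as F

def zero_upsample(t, factor):
    # Insert factor-1 zeros after every pixel, as upfirdn2d does before filtering.
    b, c, h, w = t.shape
    y = torch.zeros(b, c, h * factor, w * factor, dtype=t.dtype)
    y[:, :, ::factor, ::factor] = t
    return y

x = torch.randn(2, 8, 16, 16)
k = torch.randn(4, 8, 1, 1)  # 1x1 kernel, no bias

a = zero_upsample(F.conv2d(x, k), 2)  # convolve first, then upsample
b = F.conv2d(zero_upsample(x, 2), k)  # upsample first, then convolve
print(torch.allclose(a, b))  # True: both orders give the same result
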
diff --git a/sg3_torch_utils/ops/fma.py b/sg3_torch_utils/ops/fma.py
deleted file mode 100755
index b4e8ef9169440d4c3bd95befae7d26e3c1e1f017..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/fma.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
-
-import torch
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-#----------------------------------------------------------------------------
-
-def fma(a, b, c): # => a * b + c
- return _FusedMultiplyAdd.apply(a, b, c)
-
-#----------------------------------------------------------------------------
-
-class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, a, b, c): # pylint: disable=arguments-differ
- out = torch.addcmul(c, a, b)
- ctx.save_for_backward(a, b)
- ctx.c_shape = c.shape
- return out
-
- @staticmethod
- @custom_bwd
- def backward(ctx, dout): # pylint: disable=arguments-differ
- a, b = ctx.saved_tensors
- c_shape = ctx.c_shape
- da = None
- db = None
- dc = None
-
- if ctx.needs_input_grad[0]:
- da = _unbroadcast(dout * b, a.shape)
-
- if ctx.needs_input_grad[1]:
- db = _unbroadcast(dout * a, b.shape)
-
- if ctx.needs_input_grad[2]:
- dc = _unbroadcast(dout, c_shape)
-
- return da, db, dc
-
-#----------------------------------------------------------------------------
-
-def _unbroadcast(x, shape):
- extra_dims = x.ndim - len(shape)
- assert extra_dims >= 0
- dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
- if len(dim):
- x = x.sum(dim=dim, keepdim=True)
- if extra_dims:
- x = x.reshape(-1, *x.shape[extra_dims+1:])
- assert x.shape == shape
- return x
-
-#----------------------------------------------------------------------------
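
Numerically, the fma() removed here is just torch.addcmul(c, a, b); the custom autograd Function only exists to provide slightly cheaper gradients and to un-broadcast them back to the operand shapes. A small illustration of that broadcasting behaviour, with made-up shapes:

import torch

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 3, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)

out = torch.addcmul(c, a, b)  # a * b + c, the same math as the removed fma()
out.sum().backward()
# Gradients are reduced back to each operand's (broadcast) shape.
print(a.grad.shape, b.grad.shape, c.grad.shape)
# torch.Size([4, 1, 8]) torch.Size([1, 3, 8]) torch.Size([8])
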
diff --git a/sg3_torch_utils/ops/grid_sample_gradfix.py b/sg3_torch_utils/ops/grid_sample_gradfix.py
deleted file mode 100755
index 87067e150c591b1ace91816e7a5c3ee3a4aeacd3..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/grid_sample_gradfix.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom replacement for `torch.nn.functional.grid_sample` that
-supports arbitrarily high order gradients between the input and output.
-Only works on 2D images and assumes
-`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
-
-import torch
-from torch.cuda.amp import custom_bwd, custom_fwd
-from pkg_resources import parse_version
-# pylint: disable=redefined-builtin
-# pylint: disable=arguments-differ
-# pylint: disable=protected-access
-_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
-
-
-#----------------------------------------------------------------------------
-
-enabled = False # Enable the custom op by setting this to true.
-
-#----------------------------------------------------------------------------
-
-def grid_sample(input, grid):
- if _should_use_custom_op():
- return _GridSample2dForward.apply(input, grid)
- return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
-
-#----------------------------------------------------------------------------
-
-def _should_use_custom_op():
- return enabled
-
-#----------------------------------------------------------------------------
-
-class _GridSample2dForward(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, input, grid):
- assert input.ndim == 4
- assert grid.ndim == 4
- output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
- ctx.save_for_backward(input, grid)
- return output
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
- input, grid = ctx.saved_tensors
- grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
- return grad_input, grad_grid
-
-#----------------------------------------------------------------------------
-
-class _GridSample2dBackward(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx, grad_output, input, grid):
- op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
- if _use_pytorch_1_11_api:
- output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
- else:
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
- ctx.save_for_backward(grid)
- return grad_input, grad_grid
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad2_grad_input, grad2_grad_grid):
- _ = grad2_grad_grid # unused
- grid, = ctx.saved_tensors
- grad2_grad_output = None
- grad2_input = None
- grad2_grid = None
-
- if ctx.needs_input_grad[0]:
- grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
-
- assert not ctx.needs_input_grad[2]
- return grad2_grad_output, grad2_input, grad2_grid
-
-#----------------------------------------------------------------------------
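
With the custom op disabled (the default above), the removed grid_sample() wrapper reduces to a torch.nn.functional.grid_sample call with the fixed arguments it assumes. A short illustrative call with an identity sampling grid; shapes here are chosen purely for illustration:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
# Identity affine grid in [-1, 1] coordinates, shape [N, H, W, 2].
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, size=(1, 3, 32, 32), align_corners=False)

y = F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
print(torch.allclose(x, y, atol=1e-5))  # identity grid reproduces the input
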
diff --git a/sg3_torch_utils/ops/upfirdn2d.cpp b/sg3_torch_utils/ops/upfirdn2d.cpp
deleted file mode 100755
index 2d7177fc60040751d20e9a8da0301fa3ab64968a..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/upfirdn2d.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include "upfirdn2d.h"
-
-//------------------------------------------------------------------------
-
-static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
-{
- // Validate arguments.
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
- TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
- TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
- TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
- TORCH_CHECK(f.dim() == 2, "f must be rank 2");
- TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
- TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
- TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
-
- // Create output tensor.
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
- int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
- int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
- TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
- torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
- TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
-
- // Initialize CUDA kernel parameters.
- upfirdn2d_kernel_params p;
- p.x = x.data_ptr();
- p.f = f.data_ptr();
- p.y = y.data_ptr();
- p.up = make_int2(upx, upy);
- p.down = make_int2(downx, downy);
- p.pad0 = make_int2(padx0, pady0);
- p.flip = (flip) ? 1 : 0;
- p.gain = gain;
- p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
- p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
- p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
- p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
- p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
- p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
- p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
- p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
-
- // Choose CUDA kernel.
- upfirdn2d_kernel_spec spec;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
- {
- spec = choose_upfirdn2d_kernel<scalar_t>(p);
- });
-
- // Set looping options.
- p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
- p.loopMinor = spec.loopMinor;
- p.loopX = spec.loopX;
- p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
- p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
-
- // Compute grid size.
- dim3 blockSize, gridSize;
- if (spec.tileOutW < 0) // large
- {
- blockSize = dim3(4, 32, 1);
- gridSize = dim3(
- ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
- (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
- p.launchMajor);
- }
- else // small
- {
- blockSize = dim3(256, 1, 1);
- gridSize = dim3(
- ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
- (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
- p.launchMajor);
- }
-
- // Launch CUDA kernel.
- void* args[] = {&p};
- AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
- return y;
-}
-
-//------------------------------------------------------------------------
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
- m.def("upfirdn2d", &upfirdn2d);
-}
-
-//------------------------------------------------------------------------
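
For reference, the outW/outH expressions above boil down to the following shape arithmetic, which is convenient when reasoning about the sizes the (now removed) op produced; the helper name is ours, for illustration only.

def upfirdn2d_out_size(in_size, up, down, pad0, pad1, taps):
    # Mirrors: out = (in * up + pad0 + pad1 - taps + down) / down
    return (in_size * up + pad0 + pad1 - taps + down) // down

# e.g. 2x upsampling of a 64-pixel axis with a 4-tap filter and padding 2/1:
print(upfirdn2d_out_size(64, up=2, down=1, pad0=2, pad1=1, taps=4))  # 128
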
diff --git a/sg3_torch_utils/ops/upfirdn2d.cu b/sg3_torch_utils/ops/upfirdn2d.cu
deleted file mode 100755
index ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/upfirdn2d.cu
+++ /dev/null
@@ -1,350 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <c10/util/Half.h>
-#include "upfirdn2d.h"
-
-//------------------------------------------------------------------------
-// Helpers.
-
-template <class T> struct InternalType;
-template <> struct InternalType<double> { typedef double scalar_t; };
-template <> struct InternalType<float> { typedef float scalar_t; };
-template <> struct InternalType<c10::Half> { typedef float scalar_t; };
-
-static __device__ __forceinline__ int floor_div(int a, int b)
-{
- int t = 1 - a / b;
- return (a + t * b) / b - t;
-}
-
-//------------------------------------------------------------------------
-// Generic CUDA implementation for large filters.
-
-template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
-{
- typedef typename InternalType<T>::scalar_t scalar_t;
-
- // Calculate thread index.
- int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
- int outY = minorBase / p.launchMinor;
- minorBase -= outY * p.launchMinor;
- int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
- int majorBase = blockIdx.z * p.loopMajor;
- if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
- return;
-
- // Setup Y receptive field.
- int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
- int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
- int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
- int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
- if (p.flip)
- filterY = p.filterSize.y - 1 - filterY;
-
- // Loop over major, minor, and X.
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
- for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
- {
- int nc = major * p.sizeMinor + minor;
- int n = nc / p.inSize.z;
- int c = nc - n * p.inSize.z;
- for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
- {
- // Setup X receptive field.
- int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
- int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
- int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
- int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
- if (p.flip)
- filterX = p.filterSize.x - 1 - filterX;
-
- // Initialize pointers.
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
- const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
- int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
- int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
-
- // Inner loop.
- scalar_t v = 0;
- for (int y = 0; y < h; y++)
- {
- for (int x = 0; x < w; x++)
- {
- v += (scalar_t)(*xp) * (scalar_t)(*fp);
- xp += p.inStride.x;
- fp += filterStepX;
- }
- xp += p.inStride.y - w * p.inStride.x;
- fp += filterStepY - w * filterStepX;
- }
-
- // Store result.
- v *= p.gain;
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
- }
- }
-}
-
-//------------------------------------------------------------------------
-// Specialized CUDA implementation for small filters.
-
-template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
-static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
-{
- typedef typename InternalType<T>::scalar_t scalar_t;
- const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
- const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
- __shared__ volatile scalar_t sf[filterH][filterW];
- __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
-
- // Calculate tile index.
- int minorBase = blockIdx.x;
- int tileOutY = minorBase / p.launchMinor;
- minorBase -= tileOutY * p.launchMinor;
- minorBase *= loopMinor;
- tileOutY *= tileOutH;
- int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
- int majorBase = blockIdx.z * p.loopMajor;
- if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
- return;
-
- // Load filter (flipped).
- for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
- {
- int fy = tapIdx / filterW;
- int fx = tapIdx - fy * filterW;
- scalar_t v = 0;
- if (fx < p.filterSize.x & fy < p.filterSize.y)
- {
- int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
- int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
- v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
- }
- sf[fy][fx] = v;
- }
-
- // Loop over major and X.
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
- {
- int baseNC = major * p.sizeMinor + minorBase;
- int n = baseNC / p.inSize.z;
- int baseC = baseNC - n * p.inSize.z;
- for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
- {
- // Load input pixels.
- int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
- int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
- int tileInX = floor_div(tileMidX, upx);
- int tileInY = floor_div(tileMidY, upy);
- __syncthreads();
- for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
- {
- int relC = inIdx;
- int relInX = relC / loopMinor;
- int relInY = relInX / tileInW;
- relC -= relInX * loopMinor;
- relInX -= relInY * tileInW;
- int c = baseC + relC;
- int inX = tileInX + relInX;
- int inY = tileInY + relInY;
- scalar_t v = 0;
- if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
- v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
- sx[relInY][relInX][relC] = v;
- }
-
- // Loop over output pixels.
- __syncthreads();
- for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
- {
- int relC = outIdx;
- int relOutX = relC / loopMinor;
- int relOutY = relOutX / tileOutW;
- relC -= relOutX * loopMinor;
- relOutX -= relOutY * tileOutW;
- int c = baseC + relC;
- int outX = tileOutX + relOutX;
- int outY = tileOutY + relOutY;
-
- // Setup receptive field.
- int midX = tileMidX + relOutX * downx;
- int midY = tileMidY + relOutY * downy;
- int inX = floor_div(midX, upx);
- int inY = floor_div(midY, upy);
- int relInX = inX - tileInX;
- int relInY = inY - tileInY;
- int filterX = (inX + 1) * upx - midX - 1; // flipped
- int filterY = (inY + 1) * upy - midY - 1; // flipped
-
- // Inner loop.
- if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
- {
- scalar_t v = 0;
- #pragma unroll
- for (int y = 0; y < filterH / upy; y++)
- #pragma unroll
- for (int x = 0; x < filterW / upx; x++)
- v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
- v *= p.gain;
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
- }
- }
- }
- }
-}
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
-{
- int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
-
- upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
- if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
-
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
- }
- if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
- }
- if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
- }
- if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
- }
- if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
- {
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
- {
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
- {
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
- {
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
- {
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
- {
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
- }
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
- {
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
- }
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
- {
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
- }
- return spec;
-}
-
-//------------------------------------------------------------------------
-// Template specializations.
-
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
-
-//------------------------------------------------------------------------
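
The shared-memory tile loaded by the small kernel follows directly from the tileInW/tileInH expressions above. A quick illustrative evaluation (helper name ours, not part of the removed code):

def tile_in_size(tile_out, down, taps, up):
    # Mirrors: tileIn = ((tileOut - 1) * down + taps - 1) / up + 1
    return ((tile_out - 1) * down + taps - 1) // up + 1

# e.g. a 64x16 output tile with a 4-tap filter and 2x upsampling:
print(tile_in_size(64, 1, 4, 2), tile_in_size(16, 1, 4, 2))  # 34 10
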
diff --git a/sg3_torch_utils/ops/upfirdn2d.h b/sg3_torch_utils/ops/upfirdn2d.h
deleted file mode 100755
index c9e2032bcac9d2abde7a75eea4d812da348afadd..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/upfirdn2d.h
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-//
-// NVIDIA CORPORATION and its licensors retain all intellectual property
-// and proprietary rights in and to this software, related documentation
-// and any modifications thereto. Any use, reproduction, disclosure or
-// distribution of this software and related documentation without an express
-// license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-#include <cuda_runtime.h>
-
-//------------------------------------------------------------------------
-// CUDA kernel parameters.
-
-struct upfirdn2d_kernel_params
-{
- const void* x;
- const float* f;
- void* y;
-
- int2 up;
- int2 down;
- int2 pad0;
- int flip;
- float gain;
-
- int4 inSize; // [width, height, channel, batch]
- int4 inStride;
- int2 filterSize; // [width, height]
- int2 filterStride;
- int4 outSize; // [width, height, channel, batch]
- int4 outStride;
- int sizeMinor;
- int sizeMajor;
-
- int loopMinor;
- int loopMajor;
- int loopX;
- int launchMinor;
- int launchMajor;
-};
-
-//------------------------------------------------------------------------
-// CUDA kernel specialization.
-
-struct upfirdn2d_kernel_spec
-{
- void* kernel;
- int tileOutW;
- int tileOutH;
- int loopMinor;
- int loopX;
-};
-
-//------------------------------------------------------------------------
-// CUDA kernel selection.
-
-template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
-
-//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/upfirdn2d.py b/sg3_torch_utils/ops/upfirdn2d.py
deleted file mode 100755
index a0bbd22d245481e7c5a19315e5cb3242b1278787..0000000000000000000000000000000000000000
--- a/sg3_torch_utils/ops/upfirdn2d.py
+++ /dev/null
@@ -1,388 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-"""Custom PyTorch ops for efficient resampling of 2D images."""
-
-import os
-import warnings
-import numpy as np
-import torch
-import traceback
-
-from .. import custom_ops
-from .. import misc
-from . import conv2d_gradfix
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-#----------------------------------------------------------------------------
-
-_inited = False
-_plugin = None
-enabled = False
-
-def _init():
- global _inited, _plugin
- if not _inited:
- sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
- try:
- _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
- except:
- warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
- return _plugin is not None
-
-def _parse_scaling(scaling):
- if isinstance(scaling, int):
- scaling = [scaling, scaling]
- assert isinstance(scaling, (list, tuple))
- assert all(isinstance(x, int) for x in scaling)
- sx, sy = scaling
- assert sx >= 1 and sy >= 1
- return sx, sy
-
-def _parse_padding(padding):
- if isinstance(padding, int):
- padding = [padding, padding]
- assert isinstance(padding, (list, tuple))
- assert all(isinstance(x, int) for x in padding)
- if len(padding) == 2:
- padx, pady = padding
- padding = [padx, padx, pady, pady]
- padx0, padx1, pady0, pady1 = padding
- return padx0, padx1, pady0, pady1
-
-def _get_filter_size(f):
- if f is None:
- return 1, 1
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- fw = f.shape[-1]
- fh = f.shape[0]
- with misc.suppress_tracer_warnings():
- fw = int(fw)
- fh = int(fh)
- misc.assert_shape(f, [fh, fw][:f.ndim])
- assert fw >= 1 and fh >= 1
- return fw, fh
-
-#----------------------------------------------------------------------------
-
-def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
- r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
-
- Args:
- f: Torch tensor, numpy array, or python list of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable),
- `[]` (impulse), or
- `None` (identity).
- device: Result device (default: cpu).
- normalize: Normalize the filter so that it retains the magnitude
- for constant input signal (DC)? (default: True).
- flip_filter: Flip the filter? (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- separable: Return a separable filter? (default: select automatically).
-
- Returns:
- Float32 tensor of the shape
- `[filter_height, filter_width]` (non-separable) or
- `[filter_taps]` (separable).
- """
- # Validate.
- if f is None:
- f = 1
- f = torch.as_tensor(f, dtype=torch.float32)
- assert f.ndim in [0, 1, 2]
- assert f.numel() > 0
- if f.ndim == 0:
- f = f[np.newaxis]
-
- # Separable?
- if separable is None:
- separable = (f.ndim == 1 and f.numel() >= 8)
- if f.ndim == 1 and not separable:
- f = f.ger(f)
- assert f.ndim == (1 if separable else 2)
-
- # Apply normalize, flip, gain, and device.
- if normalize:
- f /= f.sum()
- if flip_filter:
- f = f.flip(list(range(f.ndim)))
- f = f * (gain ** (f.ndim / 2))
- f = f.to(device=device)
- return f
-
-#----------------------------------------------------------------------------
-
-def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Pad, upsample, filter, and downsample a batch of 2D images.
-
- Performs the following sequence of operations for each channel:
-
- 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
-
- 2. Pad the image with the specified number of zeros on each side (`padding`).
- Negative padding corresponds to cropping the image.
-
- 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
- so that the footprint of all output pixels lies within the input image.
-
- 4. Downsample the image by keeping every Nth pixel (`down`).
-
- This sequence of operations bears close resemblance to scipy.signal.upfirdn().
- The fused op is considerably more efficient than performing the same calculation
- using standard PyTorch ops. It supports gradients of arbitrary order.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- up: Integer upsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- down: Integer downsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the upsampled image. Can be a single number
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- assert isinstance(x, torch.Tensor)
- assert impl in ['ref', 'cuda']
- if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
- return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
- return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
-
-#----------------------------------------------------------------------------
-
-@misc.profiled_function
-def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
- """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
- """
- # Validate arguments.
- assert isinstance(x, torch.Tensor) and x.ndim == 4
- if f is None:
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- assert f.dtype == torch.float32 and not f.requires_grad
- batch_size, num_channels, in_height, in_width = x.shape
- upx, upy = _parse_scaling(up)
- downx, downy = _parse_scaling(down)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
-
- # Upsample by inserting zeros.
- x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
- x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
- x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
-
- # Pad or crop.
- x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
- x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
-
- # Setup filter.
- f = f * (gain ** (f.ndim / 2))
- f = f.to(x.dtype)
- if not flip_filter:
- f = f.flip(list(range(f.ndim)))
-
- # Convolve with the filter.
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
- if f.ndim == 4:
- x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
- else:
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
-
- # Downsample by throwing away pixels.
- x = x[:, :, ::downy, ::downx]
- return x
-
-#----------------------------------------------------------------------------
-
-_upfirdn2d_cuda_cache = dict()
-
-def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
- """Fast CUDA implementation of `upfirdn2d()` using custom ops.
- """
- # Parse arguments.
- upx, upy = _parse_scaling(up)
- downx, downy = _parse_scaling(down)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
-
- # Lookup from cache.
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
- if key in _upfirdn2d_cuda_cache:
- return _upfirdn2d_cuda_cache[key]
-
- # Forward op.
- class Upfirdn2dCuda(torch.autograd.Function):
- @staticmethod
- @custom_fwd(cast_inputs=torch.float32)
- def forward(ctx, x, f): # pylint: disable=arguments-differ
- assert isinstance(x, torch.Tensor) and x.ndim == 4
- if f is None:
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- y = x
- if f.ndim == 2:
- y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
- else:
- y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
- y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
- ctx.save_for_backward(f)
- ctx.x_shape = x.shape
- return y
-
- @staticmethod
- @custom_bwd
- def backward(ctx, dy): # pylint: disable=arguments-differ
- f, = ctx.saved_tensors
- _, _, ih, iw = ctx.x_shape
- _, _, oh, ow = dy.shape
- fw, fh = _get_filter_size(f)
- p = [
- fw - padx0 - 1,
- iw * upx - ow * downx + padx0 - upx + 1,
- fh - pady0 - 1,
- ih * upy - oh * downy + pady0 - upy + 1,
- ]
- dx = None
- df = None
-
- if ctx.needs_input_grad[0]:
- dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
-
- assert not ctx.needs_input_grad[1]
- return dx, df
-
- # Add to cache.
- _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
- return Upfirdn2dCuda
-
-#----------------------------------------------------------------------------
-
-def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Filter a batch of 2D images using the given 2D FIR filter.
-
- By default, the result is padded so that its shape matches the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- padding: Padding with respect to the output. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + fw // 2,
- padx1 + (fw - 1) // 2,
- pady0 + fh // 2,
- pady1 + (fh - 1) // 2,
- ]
- return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
-
-#----------------------------------------------------------------------------
-
-def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Upsample a batch of 2D images using the given 2D FIR filter.
-
- By default, the result is padded so that its shape is a multiple of the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- up: Integer upsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the output. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- upx, upy = _parse_scaling(up)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + (fw + upx - 1) // 2,
- padx1 + (fw - upx) // 2,
- pady0 + (fh + upy - 1) // 2,
- pady1 + (fh - upy) // 2,
- ]
- return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
-
-#----------------------------------------------------------------------------
-
-def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
- r"""Downsample a batch of 2D images using the given 2D FIR filter.
-
- By default, the result is padded so that its shape is a fraction of the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
-
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- down: Integer downsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the input. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
-
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- downx, downy = _parse_scaling(down)
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + (fw - downx + 1) // 2,
- padx1 + (fw - downx) // 2,
- pady0 + (fh - downy + 1) // 2,
- pady1 + (fh - downy) // 2,
- ]
- return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
-
-#----------------------------------------------------------------------------
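
Finally, upsample2d() and downsample2d() above are thin wrappers: all they do is choose the padding so the output lands on exactly up times (or 1/down times) the input size. A small worked example of the upsample2d() padding arithmetic, with a helper name of our own choosing:

def upsample2d_padding(fw, fh, upx, upy):
    # Mirrors the p = [...] computation inside upsample2d() above.
    return [
        (fw + upx - 1) // 2,  # padx0
        (fw - upx) // 2,      # padx1
        (fh + upy - 1) // 2,  # pady0
        (fh - upy) // 2,      # pady1
    ]

px0, px1, py0, py1 = upsample2d_padding(fw=4, fh=4, upx=2, upy=2)
out_w = (64 * 2 + px0 + px1 - 4 + 1) // 1  # output width via the op's size formula
print(px0, px1, py0, py1, out_w)  # 2 1 2 1 128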