update demo

This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- .gitmodules +3 -0
- app.py +21 -62
- configs/anonymizers/FB_cse.py +0 -28
- configs/anonymizers/FB_cse_mask.py +0 -29
- configs/anonymizers/FB_cse_mask_face.py +0 -29
- configs/anonymizers/face.py +0 -18
- configs/anonymizers/market1501/blackout.py +0 -8
- configs/anonymizers/market1501/person.py +0 -6
- configs/anonymizers/market1501/pixelation16.py +0 -8
- configs/anonymizers/market1501/pixelation8.py +0 -8
- configs/datasets/coco_cse.py +0 -69
- configs/datasets/fdf128.py +0 -24
- configs/datasets/fdf256.py +0 -69
- configs/datasets/fdh.py +0 -89
- configs/datasets/utils.py +0 -12
- configs/defaults.py +0 -45
- configs/discriminators/sg2_discriminator.py +0 -42
- configs/fdf/stylegan.py +0 -14
- configs/fdf/stylegan_fdf128.py +0 -13
- configs/fdh/styleganL.py +0 -16
- configs/fdh/styleganL_nocse.py +0 -14
- configs/generators/stylegan_unet.py +0 -22
- deep_privacy2 +1 -0
- dp2/__init__.py +0 -0
- dp2/anonymizer/__init__.py +0 -1
- dp2/anonymizer/anonymizer.py +0 -159
- dp2/data/__init__.py +0 -0
- dp2/data/build.py +0 -148
- dp2/data/datasets/__init__.py +0 -0
- dp2/data/datasets/coco_cse.py +0 -148
- dp2/data/datasets/fdf.py +0 -129
- dp2/data/datasets/fdh.py +0 -104
- dp2/data/transforms/__init__.py +0 -2
- dp2/data/transforms/functional.py +0 -61
- dp2/data/transforms/stylegan2_transform.py +0 -394
- dp2/data/transforms/transforms.py +0 -247
- dp2/data/utils.py +0 -102
- dp2/detection/__init__.py +0 -3
- dp2/detection/base.py +0 -45
- dp2/detection/box_utils.py +0 -104
- dp2/detection/box_utils_fdf.py +0 -203
- dp2/detection/cse_mask_face_detector.py +0 -116
- dp2/detection/face_detector.py +0 -62
- dp2/detection/models/__init__.py +0 -0
- dp2/detection/models/cse.py +0 -135
- dp2/detection/models/keypoint_maskrcnn.py +0 -111
- dp2/detection/models/mask_rcnn.py +0 -78
- dp2/detection/person_detector.py +0 -135
- dp2/detection/structures.py +0 -464
.gitattributes
CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 erling.jpg filter=lfs diff=lfs merge=lfs -text
 *.jpg filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text
.gitmodules
ADDED
@@ -0,0 +1,3 @@
+[submodule "deep_privacy2"]
+path = deep_privacy2
+url = https://github.com/hukkelas/deep_privacy2
app.py
CHANGED
@@ -1,78 +1,37 @@
+import gradio
+import sys
 import os
+from pathlib import Path
+from tops.config import instantiate
+import gradio.inputs
 os.system("pip install --upgrade pip")
 os.system("pip install ftfy regex tqdm")
-os.system("pip install git+https://github.com/openai/CLIP.git")
+os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
 os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
-os.system("pip install git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
-
-
-import torch
-from PIL import Image
+os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
+sys.path.insert(0, Path(os.getcwd(), "deep_privacy2"))
+os.environ["TORCH_HOME"] = "torch_home"
 from dp2 import utils
-from tops.config import instantiate
-import tops
-import gradio.inputs
-
-
-cfg_body = utils.load_config("configs/anonymizers/FB_cse.py")
-anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False)
-anonymizer_body.initialize_tracker(fps=1)
-cfg_face = utils.load_config("configs/anonymizers/face.py")
-anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
-anonymizer_face.initialize_tracker(fps=1)
-
-
-class ExampleDemo:
+from gradio_demos.modules import ExampleDemo, WebcamDemo
 
-
-
-
-
-        input_image = gradio.Image(type="pil", label="Upload your image or try the example below!")
-        output_image = gradio.Image(type="numpy", label="Output")
-        with gradio.Row():
-            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
-            visualize_det = gradio.Checkbox(value=False, label="Show Detections")
-        visualize_det.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
-        gradio.Examples(
-            ["erling.jpg", "g7-summit-leaders-distraction.jpg"], inputs=[input_image]
-        )
-        update_btn.click(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
-        input_image.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
-        self.track = False
+cfg_face = utils.load_config("deep_privacy2/configs/anonymizers/face.py")
+for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
+    if key in cfg_face.anonymizer:
+        cfg_face.anonymizer[key] = Path("deep_privacy2", cfg_face.anonymizer[key])
 
-    def anonymize(self, img: Image, visualize_detection: bool):
-        img, cache_id = pil2torch(img)
-        img = tops.to_cuda(img)
-        if visualize_detection:
-            img = self.anonymizer.visualize_detection(img, cache_id=cache_id)
-        else:
-            img = self.anonymizer(
-                img, truncation_value=0 if self.multi_modal_truncation else 1, multi_modal_truncation=self.multi_modal_truncation, amp=True,
-                cache_id=cache_id, track=self.track)
-        img = utils.im2numpy(img)
-        return img
 
+anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
 
-
-    img = img.convert("RGB")
-    img = np.array(img)
-    img = np.rollaxis(img, 2)
-    return torch.from_numpy(img), None
+anonymizer_face.initialize_tracker(fps=1)
 
 
 with gradio.Blocks() as demo:
     gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
     gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
-    gradio.Markdown("<center> DeepPrivacy2 is a toolbox for realistic anonymization of humans, including a face and a full-body anonymizer. </center>")
     gradio.Markdown("<center> See more information at: <a href='https://github.com/hukkelas/deep_privacy2'> https://github.com/hukkelas/deep_privacy2 </a> </center>")
+    with gradio.Tab("Face Anonymization"):
+        ExampleDemo(anonymizer_face)
+    with gradio.Tab("Live Webcam"):
+        WebcamDemo(anonymizer_face)
 
-
-
-    with gradio.Tab("Full-Body Anonymization"):
-        ExampleDemo(anonymizer_body, multi_modal_truncation=True)
-    with gradio.Tab("Face Anonymization"):
-        ExampleDemo(anonymizer_face, multi_modal_truncation=False)
-
-
-demo.launch()
+demo.launch()
configs/anonymizers/FB_cse.py
DELETED
@@ -1,28 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.person_detector import CSEPersonDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-from dp2.generator.dummy_generators import MaskOutGenerator
-
-
-maskout_G = L(MaskOutGenerator)(noise="constant")
-
-detector = L(CSEPersonDetector)(
-    mask_rcnn_cfg=dict(),
-    cse_cfg=dict(),
-    cse_post_process_cfg=dict(
-        target_imsize=(288, 160),
-        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
-        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
-        iou_combine_threshold=0.4,
-        dilation_percentage=0.02,
-        normalize_embedding=False
-    ),
-    score_threshold=0.3,
-    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
-)
-
-anonymizer = L(Anonymizer)(
-    detector="${detector}",
-    cse_person_G_cfg="configs/fdh/styleganL.py",
-)
configs/anonymizers/FB_cse_mask.py
DELETED
@@ -1,29 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.person_detector import CSEPersonDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-from dp2.generator.dummy_generators import MaskOutGenerator
-
-
-maskout_G = L(MaskOutGenerator)(noise="constant")
-
-detector = L(CSEPersonDetector)(
-    mask_rcnn_cfg=dict(),
-    cse_cfg=dict(),
-    cse_post_process_cfg=dict(
-        target_imsize=(288, 160),
-        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
-        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
-        iou_combine_threshold=0.4,
-        dilation_percentage=0.02,
-        normalize_embedding=False
-    ),
-    score_threshold=0.3,
-    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
-)
-
-anonymizer = L(Anonymizer)(
-    detector="${detector}",
-    person_G_cfg="configs/fdh/styleganL_nocse.py",
-    cse_person_G_cfg="configs/fdh/styleganL.py",
-)
configs/anonymizers/FB_cse_mask_face.py
DELETED
@@ -1,29 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-
-detector = L(CSeMaskFaceDetector)(
-    mask_rcnn_cfg=dict(),
-    face_detector_cfg=dict(),
-    face_post_process_cfg=dict(target_imsize=(256, 256)),
-    cse_cfg=dict(),
-    cse_post_process_cfg=dict(
-        target_imsize=(288, 160),
-        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
-        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
-        iou_combine_threshold=0.4,
-        dilation_percentage=0.02,
-        normalize_embedding=False
-    ),
-    score_threshold=0.3,
-    cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
-)
-
-anonymizer = L(Anonymizer)(
-    detector="${detector}",
-    face_G_cfg="configs/fdf/stylegan.py",
-    person_G_cfg="configs/fdh/styleganL_nocse.py",
-    cse_person_G_cfg="configs/fdh/styleganL.py",
-    car_G_cfg="configs/generators/dummy/pixelation8.py"
-)
configs/anonymizers/face.py
DELETED
@@ -1,18 +0,0 @@
-from dp2.anonymizer import Anonymizer
-from dp2.detection.face_detector import FaceDetector
-from ..defaults import common
-from tops.config import LazyCall as L
-
-
-detector = L(FaceDetector)(
-    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
-    face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
-    score_threshold=0.3,
-    cache_directory=common.output_dir.joinpath("face_detection_cache")
-)
-
-
-anonymizer = L(Anonymizer)(
-    detector="${detector}",
-    face_G_cfg="configs/fdf/stylegan.py",
-)
configs/anonymizers/market1501/blackout.py
DELETED
@@ -1,8 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
-anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
-anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
configs/anonymizers/market1501/person.py
DELETED
@@ -1,6 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
configs/anonymizers/market1501/pixelation16.py
DELETED
@@ -1,8 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
-anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
-anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
configs/anonymizers/market1501/pixelation8.py
DELETED
@@ -1,8 +0,0 @@
-from ..FB_cse_mask_face import anonymizer, detector, common
-
-detector.score_threshold = .1
-detector.face_detector_cfg.confidence_threshold = .5
-detector.cse_cfg.score_thres = 0.3
-anonymizer.generators.face_G_cfg = None
-anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
-anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
configs/datasets/coco_cse.py
DELETED
@@ -1,69 +0,0 @@
-import os
-from pathlib import Path
-from tops.config import LazyCall as L
-import torch
-import functools
-from dp2.data.datasets import CocoCSE
-from dp2.data.build import get_dataloader
-from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
-from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from .utils import final_eval_fn
-
-
-dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
-metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
-data_dir = Path(dataset_base_dir, "coco_cse")
-data = dict(
-    imsize=(288, 160),
-    im_channels=3,
-    semantic_nc=26,
-    cse_nc=16,
-    train=dict(
-        dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
-        loader=L(get_dataloader)(
-            shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
-            batch_size="${train.batch_size}",
-            dataset="${..dataset}",
-            infinite=True,
-            gpu_transform=L(torch.nn.Sequential)(*[
-                L(ToFloat)(),
-                L(StyleGANAugmentPipe)(
-                    rotate=0.5, rotate_max=.05,
-                    xint=.5, xint_max=0.05,
-                    scale=.5, scale_std=.05,
-                    aniso=0.5, aniso_std=.05,
-                    xfrac=.5, xfrac_std=.05,
-                    brightness=.5, brightness_std=.05,
-                    contrast=.5, contrast_std=.1,
-                    hue=.5, hue_max=.05,
-                    saturation=.5, saturation_std=.5,
-                    imgfilter=.5, imgfilter_std=.1),
-                L(RandomHorizontalFlip)(p=0.5),
-                L(CreateEmbedding)(),
-                L(Resize)(size="${data.imsize}"),
-                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                L(CreateCondition)(),
-            ])
-        )
-    ),
-    val=dict(
-        dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
-        loader=L(get_dataloader)(
-            shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
-            batch_size="${train.batch_size}",
-            dataset="${..dataset}",
-            infinite=False,
-            gpu_transform=L(torch.nn.Sequential)(*[
-                L(ToFloat)(),
-                L(CreateEmbedding)(),
-                L(Resize)(size="${data.imsize}"),
-                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                L(CreateCondition)(),
-            ])
-        )
-    ),
-    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
-    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
-    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
-)
configs/datasets/fdf128.py
DELETED
@@ -1,24 +0,0 @@
-from pathlib import Path
-from functools import partial
-from dp2.data.datasets.fdf import FDFDataset
-from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn
-
-data_dir = Path(dataset_base_dir, "fdf")
-data.train.dataset.dirpath = data_dir.joinpath("train")
-data.val.dataset.dirpath = data_dir.joinpath("val")
-data.imsize = (128, 128)
-
-
-data.train_evaluation_fn = partial(
-    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
-data.evaluation_fn = partial(
-    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
-
-data.train.dataset.update(
-    _target_ = FDFDataset,
-    imsize="${data.imsize}"
-)
-data.val.dataset.update(
-    _target_ = FDFDataset,
-    imsize="${data.imsize}"
-)
configs/datasets/fdf256.py
DELETED
@@ -1,69 +0,0 @@
-import os
-from pathlib import Path
-from tops.config import LazyCall as L
-import torch
-import functools
-from dp2.data.datasets.fdf import FDF256Dataset
-from dp2.data.build import get_dataloader
-from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from dp2.metrics.fid_clip import compute_fid_clip
-from dp2.metrics.ppl import calculate_ppl
-from .utils import final_eval_fn
-
-
-def final_eval_fn(*args, **kwargs):
-    result = compute_metrics_iteratively(*args, **kwargs)
-    result2 = compute_fid_clip(*args, **kwargs)
-    assert all(key not in result for key in result2)
-    result.update(result2)
-    result3 = calculate_ppl(*args, **kwargs,)
-    assert all(key not in result for key in result3)
-    result.update(result3)
-    return result
-
-
-dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
-metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
-data_dir = Path(dataset_base_dir, "fdf256")
-data = dict(
-    imsize=(256, 256),
-    im_channels=3,
-    semantic_nc=None,
-    cse_nc=None,
-    n_keypoints=None,
-    train=dict(
-        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
-        loader=L(get_dataloader)(
-            shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
-            batch_size="${train.batch_size}",
-            dataset="${..dataset}",
-            infinite=True,
-            gpu_transform=L(torch.nn.Sequential)(*[
-                L(ToFloat)(),
-                L(RandomHorizontalFlip)(p=0.5),
-                L(Resize)(size="${data.imsize}"),
-                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                L(CreateCondition)(),
-            ])
-        )
-    ),
-    val=dict(
-        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
-        loader=L(get_dataloader)(
-            shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
-            batch_size="${train.batch_size}",
-            dataset="${..dataset}",
-            infinite=False,
-            gpu_transform=L(torch.nn.Sequential)(*[
-                L(ToFloat)(),
-                L(Resize)(size="${data.imsize}"),
-                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                L(CreateCondition)(),
-            ])
-        )
-    ),
-    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
-    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "fdf_val_train")),
-    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
-)
configs/datasets/fdh.py
DELETED
@@ -1,89 +0,0 @@
-import os
-from pathlib import Path
-from tops.config import LazyCall as L
-import torch
-import functools
-from dp2.data.datasets.fdh import get_dataloader_fdh_wds
-from dp2.data.utils import get_coco_flipmap
-from dp2.data.transforms.transforms import (
-    Normalize,
-    ToFloat,
-    CreateCondition,
-    RandomHorizontalFlip,
-    CreateEmbedding,
-)
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from dp2.metrics.fid_clip import compute_fid_clip
-from .utils import final_eval_fn
-
-
-def train_eval_fn(*args, **kwargs):
-    result = compute_metrics_iteratively(*args, **kwargs)
-    result2 = compute_fid_clip(*args, **kwargs)
-    assert all(key not in result for key in result2)
-    result.update(result2)
-    return result
-
-
-dataset_base_dir = (
-    os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
-)
-metrics_cache = (
-    os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
-)
-data_dir = Path(dataset_base_dir, "fdh")
-data = dict(
-    imsize=(288, 160),
-    im_channels=3,
-    cse_nc=16,
-    n_keypoints=17,
-    train=dict(
-        loader=L(get_dataloader_fdh_wds)(
-            path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
-            batch_size="${train.batch_size}",
-            num_workers=6,
-            transform=L(torch.nn.Sequential)(
-                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
-            ),
-            gpu_transform=L(torch.nn.Sequential)(
-                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
-                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
-                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
-                L(CreateCondition)(),
-            ),
-            infinite=True,
-            shuffle=True,
-            partial_batches=False,
-            load_embedding=True,
-        )
-    ),
-    val=dict(
-        loader=L(get_dataloader_fdh_wds)(
-            path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
-            batch_size="${train.batch_size}",
-            num_workers=6,
-            transform=None,
-            gpu_transform=L(torch.nn.Sequential)(
-                L(ToFloat)(keys=["img", "mask", "E_mask", "maskrcnn_mask"], norm=False),
-                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
-                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
-                L(CreateCondition)(),
-            ),
-            infinite=False,
-            shuffle=False,
-            partial_batches=True,
-            load_embedding=True,
-        )
-    ),
-    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
-    train_evaluation_fn=functools.partial(
-        train_eval_fn,
-        cache_directory=Path(metrics_cache, "fdh_v7_train"),
-        data_len=int(30e3),
-    ),
-    evaluation_fn=functools.partial(
-        final_eval_fn,
-        cache_directory=Path(metrics_cache, "fdh_v6_val"),
-        data_len=int(30e3),
-    ),
-)
configs/datasets/utils.py
DELETED
@@ -1,12 +0,0 @@
-from dp2.metrics.ppl import calculate_ppl
-from dp2.metrics.torch_metrics import compute_metrics_iteratively
-from dp2.metrics.fid_clip import compute_fid_clip
-
-
-def final_eval_fn(*args, **kwargs):
-    result = compute_metrics_iteratively(*args, **kwargs)
-    result2 = calculate_ppl(*args, **kwargs,)
-    result2 = compute_fid_clip(*args, **kwargs)
-    assert all(key not in result for key in result2)
-    result.update(result2)
-    return result
configs/defaults.py
DELETED
@@ -1,45 +0,0 @@
-import pathlib
-import os
-import torch
-from tops.config import LazyCall as L
-
-if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
-    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
-else:
-    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
-if "BASE_OUTPUT_DIR" in os.environ:
-    BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
-else:
-    BASE_OUTPUT_DIR = pathlib.Path("outputs")
-
-
-
-common = dict(
-    logger_backend=["wandb", "stdout", "json", "image_dumper"],
-    wandb_project="fba_test",
-    output_dir=BASE_OUTPUT_DIR,
-    experiment_name=None,  # Optional experiment name to show on wandb
-)
-
-train = dict(
-    batch_size=32,
-    seed=0,
-    ims_per_log=1024,
-    ims_per_val=int(200e3),
-    max_images_to_train=int(12e6),
-    amp=dict(
-        enabled=True,
-        scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
-        scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
-    ),
-    fp16_ddp_accumulate=False,  # All gather gradients in fp16?
-    broadcast_buffers=False,
-    bias_act_plugin_enabled=True,
-    grid_sample_gradfix_enabled=True,
-    conv2d_gradfix_enabled=False,
-    channels_last=False,
-)
-
-# exponential moving average
-EMA = dict(rampup=0.05)
-
configs/discriminators/sg2_discriminator.py
DELETED
@@ -1,42 +0,0 @@
-from tops.config import LazyCall as L
-from dp2.discriminator import SG2Discriminator
-import torch
-from dp2.loss import StyleGAN2Loss
-
-
-discriminator = L(SG2Discriminator)(
-    imsize="${data.imsize}",
-    im_channels="${data.im_channels}",
-    min_fmap_resolution=4,
-    max_cnum_mul=8,
-    cnum=80,
-    input_condition=True,
-    conv_clamp=256,
-    input_cse=False,
-    cse_nc="${data.cse_nc}"
-)
-
-
-loss_fnc = L(StyleGAN2Loss)(
-    lazy_regularization=True,
-    lazy_reg_interval=16,
-    r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
-    EP_lambd=0.001,
-    pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01)
-)
-
-def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
-    if lazy_regularization:
-        # From Analyzing and improving the image quality of stylegan, CVPR 2020
-        c = lazy_reg_interval / (lazy_reg_interval + 1)
-        betas = [beta ** c for beta in betas]
-        lr *= c
-        print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
-    return type(lr=lr, betas=betas, **kwargs)
-
-
-D_optim = L(build_D_optim)(
-    type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
-    lazy_regularization="${loss_fnc.lazy_regularization}",
-    lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
-G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
configs/fdf/stylegan.py
DELETED
@@ -1,14 +0,0 @@
-from ..generators.stylegan_unet import generator
-from ..datasets.fdf256 import data
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..defaults import train, common, EMA
-
-train.max_images_to_train = int(35e6)
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-generator.input_cse = False
-loss_fnc.r1_opts.lambd = 1
-train.ims_per_val = int(2e6)
-
-common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
-common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
configs/fdf/stylegan_fdf128.py
DELETED
@@ -1,13 +0,0 @@
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..datasets.fdf128 import data
-from ..generators.stylegan_unet import generator
-from ..defaults import train, common, EMA
-from tops.config import LazyCall as L
-
-train.max_images_to_train = int(25e6)
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-generator.cnum = 128
-generator.max_cnum_mul = 4
-generator.input_cse = False
-loss_fnc.r1_opts.lambd = .1
configs/fdh/styleganL.py
DELETED
@@ -1,16 +0,0 @@
-from tops.config import LazyCall as L
-from ..generators.stylegan_unet import generator
-from ..datasets.fdh import data
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..defaults import train, common, EMA
-
-train.max_images_to_train = int(50e6)
-train.batch_size = 64
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-data.train.loader.num_workers = 4
-train.ims_per_val = int(1e6)
-loss_fnc.r1_opts.lambd = .1
-
-common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
-common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
configs/fdh/styleganL_nocse.py
DELETED
@@ -1,14 +0,0 @@
-from tops.config import LazyCall as L
-from ..generators.stylegan_unet import generator
-from ..datasets.fdh import data
-from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
-from ..defaults import train, common, EMA
-
-train.max_images_to_train = int(50e6)
-G_optim.lr = 0.002
-D_optim.lr = 0.002
-generator.input_cse = False
-data.load_embeddings = False
-common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
-common.model_md5sum = "fda0d809741bc67487abada793975c37"
-generator.fix_errors = False
configs/generators/stylegan_unet.py
DELETED
@@ -1,22 +0,0 @@
-from dp2.generator.stylegan_unet import StyleGANUnet
-from tops.config import LazyCall as L
-
-generator = L(StyleGANUnet)(
-    imsize="${data.imsize}",
-    im_channels="${data.im_channels}",
-    min_fmap_resolution=8,
-    cnum=64,
-    max_cnum_mul=8,
-    n_middle_blocks=0,
-    z_channels=512,
-    mask_output=True,
-    conv_clamp=256,
-    input_cse=True,
-    scale_grad=True,
-    cse_nc="${data.cse_nc}",
-    w_dim=512,
-    n_keypoints="${data.n_keypoints}",
-    input_keypoints=False,
-    input_keypoint_indices=[],
-    fix_errors=True
-)
deep_privacy2
ADDED
@@ -0,0 +1 @@
+Subproject commit 37dcbeb23a1f51121d53bcd80d32d086d6822b7b
dp2/__init__.py
DELETED
File without changes
dp2/anonymizer/__init__.py
DELETED
@@ -1 +0,0 @@
-from .anonymizer import Anonymizer
dp2/anonymizer/anonymizer.py
DELETED
@@ -1,159 +0,0 @@
-from pathlib import Path
-from typing import Union, Optional
-import numpy as np
-import torch
-import tops
-import torchvision.transforms.functional as F
-from motpy import Detection, MultiObjectTracker
-from dp2.utils import load_config
-from dp2.infer import build_trained_generator
-from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
-
-
-def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
-    cfg = load_config(cfg_path)
-    G = build_trained_generator(cfg)
-    tops.logger.log(f"Loaded generator from: {cfg_path}")
-    return G
-
-
-def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs):
-    img = F.resize(img, imsize, antialias=True)
-    mask = (F.resize(mask, imsize, antialias=True) > 0.99).float()
-    maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float()
-
-    condition = img * mask
-    return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition)
-
-
-class Anonymizer:
-
-    def __init__(
-            self,
-            detector,
-            load_cache: bool,
-            person_G_cfg: Optional[Union[str, Path]] = None,
-            cse_person_G_cfg: Optional[Union[str, Path]] = None,
-            face_G_cfg: Optional[Union[str, Path]] = None,
-            car_G_cfg: Optional[Union[str, Path]] = None,
-    ) -> None:
-        self.detector = detector
-        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
-        self.load_cache = load_cache
-        if cse_person_G_cfg is not None:
-            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
-        if person_G_cfg is not None:
-            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
-        if face_G_cfg is not None:
-            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
-        if car_G_cfg is not None:
-            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
-
-    def initialize_tracker(self, fps: float):
-        self.tracker = MultiObjectTracker(dt=1/fps)
-        self.track_to_z_idx = dict()
-        self.cur_z_idx = 0
-
-    @torch.no_grad()
-    def anonymize_detections(self,
-            im, detection, truncation_value: float,
-            multi_modal_truncation: bool, amp: bool, z_idx,
-            all_styles=None,
-            update_identity=None,
-            ):
-        G = self.generators[type(detection)]
-        if G is None:
-            return im
-        C, H, W = im.shape
-        orig_im = im.clone()
-        if update_identity is None:
-            update_identity = [True for i in range(len(detection))]
-        for idx in range(len(detection)):
-            if not update_identity[idx]:
-                continue
-            batch = detection.get_crop(idx, im)
-            x0, y0, x1, y1 = batch.pop("boxes")[0]
-            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
-            batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
-            batch["img"] = batch["img"].float()
-            batch["condition"] = batch["mask"] * batch["img"]
-            orig_shape = None
-            if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]:
-                orig_shape = batch["img"].shape[-2:]
-                batch = resize_batch(**batch, imsize=G.imsize)
-            with torch.cuda.amp.autocast(amp):
-                if all_styles is not None:
-                    anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
-                elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"):
-                    w_indices = None
-                    if z_idx is not None:
-                        w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
-                    anonymized_im = G.multi_modal_truncate(
-                        **batch, truncation_value=truncation_value,
-                        w_indices=w_indices)["img"]
-                else:
-                    z = None
-                    if z_idx is not None:
-                        state = np.random.RandomState(seed=z_idx[idx])
-                        z = state.normal(size=(1, G.z_channels))
-                        z = tops.to_cuda(torch.from_numpy(z))
-                    anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
-            if orig_shape is not None:
-                anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True)
-            anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte()
-
-            # Resize and denormalize image
-            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True)
-            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
-            # Remove padding
-            pad = [max(-x0,0), max(-y0,0)]
-            pad = [*pad, max(x1-W,0), max(y1-H,0)]
-            remove_pad = lambda x: x[...,pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
-            gim = remove_pad(gim)
-            mask = remove_pad(mask)
-            x0, y0 = max(x0, 0), max(y0, 0)
-            x1, y1 = min(x1, W), min(y1, H)
-            mask = mask.logical_not()[None].repeat(3, 1, 1)
-            im[:, y0:y1, x0:x1][mask] = gim[mask]
-
-        return im
-
-    def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
-        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
-        for det in all_detections:
-            im = det.visualize(im)
-        return im
-
-    @torch.no_grad()
-    def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor:
-        assert im.dtype == torch.uint8
-        im = tops.to_cuda(im)
-        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
-        if hasattr(self, "tracker") and track:
-            [_.pre_process() for _ in all_detections]
-            import numpy as np
-            boxes = np.concatenate([_.boxes for _ in all_detections])
-            boxes = [Detection(box) for box in boxes]
-            self.tracker.step(boxes)
-            track_ids = self.tracker.detections_matched_ids
-            z_idx = []
-            for track_id in track_ids:
-                if track_id not in self.track_to_z_idx:
-                    self.track_to_z_idx[track_id] = self.cur_z_idx
-                    self.cur_z_idx += 1
-                z_idx.append(self.track_to_z_idx[track_id])
-            z_idx = np.array(z_idx)
-            idx_offset = 0
-
-        for detection in all_detections:
-            zs = None
-            if hasattr(self, "tracker") and track:
-                zs = z_idx[idx_offset:idx_offset+len(detection)]
-                idx_offset += len(detection)
-            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
-
-        return im.cpu()
-
-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
-
dp2/data/__init__.py
DELETED
File without changes
dp2/data/build.py
DELETED
@@ -1,148 +0,0 @@
-import io
-import torch
-import tops
-from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder
-
-def get_dataloader(
-    dataset, gpu_transform: torch.nn.Module,
-    num_workers,
-    batch_size,
-    infinite: bool,
-    drop_last: bool,
-    prefetch_factor: int,
-    shuffle,
-    channels_last=False
-):
-    sampler = None
-    dl_kwargs = dict(
-        pin_memory=True,
-    )
-    if infinite:
-        sampler = tops.InfiniteSampler(
-            dataset, rank=tops.rank(),
-            num_replicas=tops.world_size(),
-            shuffle=shuffle
-        )
-    elif tops.world_size() > 1:
-        sampler = torch.utils.data.DistributedSampler(
-            dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
-        dl_kwargs["drop_last"] = drop_last
-    else:
-        dl_kwargs["shuffle"] = shuffle
-        dl_kwargs["drop_last"] = drop_last
-    dataloader = torch.utils.data.DataLoader(
-        dataset, sampler=sampler, collate_fn=collate_fn,
-        batch_size=batch_size,
-        num_workers=num_workers, prefetch_factor=prefetch_factor,
-        **dl_kwargs
-    )
-    dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
-    return dataloader
-
-
-def get_dataloader_places2_wds(
-    path,
-    batch_size: int,
-    num_workers: int,
-    transform: torch.nn.Module,
-    gpu_transform: torch.nn.Module,
-    infinite: bool,
-    shuffle: bool,
-    partial_batches: bool,
-    sample_shuffle=10_000,
-    tar_shuffle=100,
-    channels_last=False,
-):
-    import webdataset as wds
-    import os
-    os.environ["RANK"] = str(tops.rank())
-    os.environ["WORLD_SIZE"] = str(tops.world_size())
-
-    if infinite:
-        pipeline = [wds.ResampledShards(str(path))]
-    else:
-        pipeline = [wds.SimpleShardList(str(path))]
-    if shuffle:
-        pipeline.append(wds.shuffle(tar_shuffle))
-    pipeline.extend([
-        wds.split_by_node,
-        wds.split_by_worker,
-    ])
-    if shuffle:
-        pipeline.append(wds.shuffle(sample_shuffle))
-
-    pipeline.extend([
-        wds.tarfile_to_samples(),
-        wds.decode("torchrgb8"),
-        wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
-    ])
-    if transform is not None:
-        pipeline.append(wds.map(transform))
-    pipeline.extend([
-        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
-    ])
-    pipeline = wds.DataPipeline(*pipeline)
-    if infinite:
-        pipeline = pipeline.repeat(nepochs=1000000)
-    loader = wds.WebLoader(
-        pipeline, batch_size=None, shuffle=False,
-        num_workers=get_num_workers(num_workers),
-        persistent_workers=True,
-    )
-    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
-    return loader
-
-
-
-
-def get_dataloader_celebAHQ_wds(
-    path,
-    batch_size: int,
-    num_workers: int,
-    transform: torch.nn.Module,
-    gpu_transform: torch.nn.Module,
-    infinite: bool,
-    shuffle: bool,
-    partial_batches: bool,
-    sample_shuffle=10_000,
-    tar_shuffle=100,
-    channels_last=False,
-):
-    import webdataset as wds
-    import os
-    os.environ["RANK"] = str(tops.rank())
-    os.environ["WORLD_SIZE"] = str(tops.world_size())
-
-    if infinite:
-        pipeline = [wds.ResampledShards(str(path))]
-    else:
-        pipeline = [wds.SimpleShardList(str(path))]
-    if shuffle:
-        pipeline.append(wds.shuffle(tar_shuffle))
-    pipeline.extend([
-        wds.split_by_node,
-        wds.split_by_worker,
-    ])
-    if shuffle:
-        pipeline.append(wds.shuffle(sample_shuffle))
-
-    pipeline.extend([
-        wds.tarfile_to_samples(),
-        wds.decode(wds.handle_extension(".png", png_decoder)),
-        wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
-    ])
-    if transform is not None:
-        pipeline.append(wds.map(transform))
-    pipeline.extend([
-        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
-    ])
-    pipeline = wds.DataPipeline(*pipeline)
-    if infinite:
-        pipeline = pipeline.repeat(nepochs=1000000)
-    loader = wds.WebLoader(
-        pipeline, batch_size=None, shuffle=False,
-        num_workers=get_num_workers(num_workers),
-        persistent_workers=True,
-    )
-    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
-    return loader
dp2/data/datasets/__init__.py
DELETED
File without changes
dp2/data/datasets/coco_cse.py
DELETED
@@ -1,148 +0,0 @@
-import pickle
-import torchvision
-import torch
-import pathlib
-import numpy as np
-from typing import Callable, Optional, Union
-from torch.hub import get_dir as get_hub_dir
-
-
-def cache_embed_stats(embed_map: torch.Tensor):
-    mean = embed_map.mean(dim=0, keepdim=True)
-    rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
-
-    cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
-    path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
-    path.parent.mkdir(exist_ok=True, parents=True)
-    torch.save(cache, path)
-
-
-class CocoCSE(torch.utils.data.Dataset):
-
-    def __init__(self,
-            dirpath: Union[str, pathlib.Path],
-            transform: Optional[Callable],
-            normalize_E: bool,):
-        dirpath = pathlib.Path(dirpath)
-        self.dirpath = dirpath
-
-        self.transform = transform
-        assert self.dirpath.is_dir(),\
-            f"Did not find dataset at: {dirpath}"
-        self.image_paths, self.embedding_paths = self._load_impaths()
-        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
-        mean = self.embed_map.mean(dim=0, keepdim=True)
-        rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
-        self.embed_map = (self.embed_map - mean) * rstd
-        cache_embed_stats(self.embed_map)
-
-    def _load_impaths(self):
-        image_dir = self.dirpath.joinpath("images")
-        image_paths = list(image_dir.glob("*.png"))
-        image_paths.sort()
-        embedding_paths = [
-            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
-        ]
-        return image_paths, embedding_paths
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        im = torchvision.io.read_image(str(self.image_paths[idx]))
-        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
-        vertices = torch.from_numpy(vertices.squeeze()).long()
-        mask = torch.from_numpy(mask.squeeze()).float()
-        border = torch.from_numpy(border.squeeze()).float()
-        E_mask = 1 - mask - border
-        batch = {
-            "img": im,
-            "vertices": vertices[None],
-            "mask": mask[None],
-            "embed_map": self.embed_map,
-            "border": border[None],
-            "E_mask": E_mask[None]
-        }
-        if self.transform is None:
-            return batch
-        return self.transform(batch)
-
-
-class CocoCSEWithFace(CocoCSE):
-
-    def __init__(self,
-            dirpath: Union[str, pathlib.Path],
-            transform: Optional[Callable],
-            **kwargs):
-        super().__init__(dirpath, transform, **kwargs)
-        with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
-            self.face_boxes = pickle.load(fp)
-
-    def __getitem__(self, idx):
-        item = super().__getitem__(idx)
-        item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
-        return item
-
-
-class CocoCSESemantic(torch.utils.data.Dataset):
-
-    def __init__(self,
-            dirpath: Union[str, pathlib.Path],
-            transform: Optional[Callable],
-            **kwargs):
-        dirpath = pathlib.Path(dirpath)
-        self.dirpath = dirpath
-
-        self.transform = transform
-        assert self.dirpath.is_dir(),\
-            f"Did not find dataset at: {dirpath}"
-        self.image_paths, self.embedding_paths = self._load_impaths()
-        self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy")))
-        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
-
-    def _load_impaths(self):
-        image_dir = self.dirpath.joinpath("images")
-        image_paths = list(image_dir.glob("*.png"))
-        image_paths.sort()
-        embedding_paths = [
-            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
-        ]
-        return image_paths, embedding_paths
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        im = torchvision.io.read_image(str(self.image_paths[idx]))
-        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
-        vertices = torch.from_numpy(vertices.squeeze()).long()
-        mask = torch.from_numpy(mask.squeeze()).float()
-        border = torch.from_numpy(border.squeeze()).float()
-        E_mask = 1 - mask - border
-        batch = {
-            "img": im,
-            "vertices": vertices[None],
-            "mask": mask[None],
-            "border": border[None],
-            "vertx2cat": self.vertx2cat,
-            "embed_map": self.embed_map,
-        }
-        if self.transform is None:
-            return batch
-        return self.transform(batch)
-
-
-class CocoCSESemanticWithFace(CocoCSESemantic):
-
-    def __init__(self,
-            dirpath: Union[str, pathlib.Path],
-            transform: Optional[Callable],
-            **kwargs):
-        super().__init__(dirpath, transform, **kwargs)
-        with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
-            self.face_boxes = pickle.load(fp)
-
-    def __getitem__(self, idx):
-        item = super().__getitem__(idx)
-        item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
-        return item
dp2/data/datasets/fdf.py
DELETED
@@ -1,129 +0,0 @@
import pathlib
from typing import Tuple
import numpy as np
import torch
import pathlib
try:
    import pyspng
    PYSPNG_IMPORTED = True
except ImportError:
    PYSPNG_IMPORTED = False
    print("Could not load pyspng. Defaulting to pillow image backend.")
    from PIL import Image
from tops import logger


class FDFDataset:

    def __init__(self,
                 dirpath,
                 imsize: Tuple[int],
                 load_keypoints: bool,
                 transform):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.imsize = imsize[0]
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath.joinpath("images", str(self.imsize))
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        self.image_paths.sort(key=lambda x: int(x.stem))
        self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)

        self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")

    def get_mask(self, idx):
        mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
        bounding_box = self.bounding_boxes[idx]
        x0, y0, x1, y1 = bounding_box
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        masks = self.get_mask(index)
        landmark = self.landmarks[index]
        batch = {
            "img": im,
            "mask": masks,
        }
        if self.load_keypoints:
            batch["keypoints"] = landmark
        if self.transform is None:
            return batch
        return self.transform(batch)


class FDF256Dataset:

    def __init__(self,
                 dirpath,
                 load_keypoints: bool,
                 transform):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath.joinpath("images")
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        self.image_paths.sort(key=lambda x: int(x.stem))
        self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
        self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")

    def get_mask(self, idx):
        mask = torch.ones((1, 256, 256), dtype=torch.bool)
        bounding_box = self.bounding_boxes[idx]
        x0, y0, x1, y1 = bounding_box
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        masks = self.get_mask(index)
        landmark = self.landmarks[index]
        batch = {
            "img": im,
            "mask": masks,
        }
        if self.load_keypoints:
            batch["keypoints"] = landmark
        if self.transform is None:
            return batch
        return self.transform(batch)

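A minimal usage sketch of the FDF256 dataset class deleted above, assuming the class is still importable (e.g. through the deep_privacy2 submodule this commit adds) and that an FDF256 dataset exists at the placeholder path below with images/, landmarks.npy and bounding_box.npy:

from dp2.data.datasets.fdf import FDF256Dataset

dataset = FDF256Dataset(dirpath="data/fdf256/val", load_keypoints=True, transform=None)  # placeholder path
sample = dataset[0]
print(sample["img"].shape)        # torch.Size([3, 256, 256]), uint8 image
print(sample["mask"].shape)       # torch.Size([1, 256, 256]), False inside the face box
print(sample["keypoints"].shape)  # (7, 2) landmark array
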
dp2/data/datasets/fdh.py
DELETED
@@ -1,104 +0,0 @@
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn


def kp_decoder(x):
    # Keypoints are between [0, 1] for webdataset
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    keypoints[:, 0] /= 160
    keypoints[:, 1] /= 288
    check_outside = lambda x: (x < 0).logical_or(x > 1)
    is_outside = check_outside(keypoints[:, 0]).logical_or(
        check_outside(keypoints[:, 1])
    )
    keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
    return keypoints


def vertices_decoder(x):
    vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
    return vertices.squeeze()[None]


def get_dataloader_fdh_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        load_embedding: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        read_condition=False,
        channels_last=False,
        ):
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]

    rename_keys = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"]
    ]
    if load_embedding:
        decoder.extend([
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ])
        rename_keys.extend([
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ])

    if read_condition:
        decoder.append(
            wds.handle_extension("condition.png", png_decoder)
        )
        rename_keys.append(["condition", "condition.png"])

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
        wds.rename_keys(*rename_keys),
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader

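A hedged sketch of how get_dataloader_fdh_wds is typically invoked. The shard pattern, batch size and transform choices below are placeholders; the only things taken from the deleted code are the function signature and the fact that the returned loader is a tops.DataPrefetcher which applies gpu_transform on the GPU:

import torch
from dp2.data.datasets.fdh import get_dataloader_fdh_wds
from dp2.data.transforms.transforms import ToFloat, CreateCondition

loader = get_dataloader_fdh_wds(
    path="data/fdh/train-{000000..000099}.tar",   # placeholder shard pattern
    batch_size=8,
    num_workers=2,
    transform=None,                                # optional CPU-side transform
    gpu_transform=torch.nn.Sequential(             # handed to tops.DataPrefetcher
        ToFloat(keys=["img"]), CreateCondition()),
    infinite=True,
    shuffle=True,
    partial_batches=False,
    load_embedding=False,
)
batch = next(iter(loader))                         # dict with "img", "mask", "keypoints", ...
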
dp2/data/transforms/__init__.py
DELETED
@@ -1,2 +0,0 @@
from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
from .stylegan2_transform import StyleGANAugmentPipe

dp2/data/transforms/functional.py
DELETED
@@ -1,61 +0,0 @@
import torchvision.transforms.functional as F
import torch
import pickle
from tops import download_file, assert_shape
from typing import Dict
from functools import lru_cache

global symmetry_transform

@lru_cache(maxsize=1)
def get_symmetry_transform(symmetry_url):
    file_name = download_file(symmetry_url)
    with open(file_name, "rb") as fp:
        symmetry = pickle.load(fp)
    return torch.from_numpy(symmetry["vertex_transforms"]).long()


hflip_handled_cases = set([
    "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
    "embedding", "vertx2cat", "maskrcnn_mask", "__key__",
    "img_hr", "condition_hr", "mask_hr"])

def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
    container["img"] = F.hflip(container["img"])
    if "condition" in container:
        container["condition"] = F.hflip(container["condition"])
    if "embedding" in container:
        container["embedding"] = F.hflip(container["embedding"])
    assert all([key in hflip_handled_cases for key in container]), container.keys()
    if "keypoints" in container:
        assert flip_map is not None
        if container["keypoints"].ndim == 3:
            keypoints = container["keypoints"][:, flip_map, :]
            keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
        else:
            assert_shape(container["keypoints"], (None, 3))
            keypoints = container["keypoints"][flip_map, :]
            keypoints[:, 0] = 1 - keypoints[:, 0]
        container["keypoints"] = keypoints
    if "mask" in container:
        container["mask"] = F.hflip(container["mask"])
    if "border" in container:
        container["border"] = F.hflip(container["border"])
    if "semantic_mask" in container:
        container["semantic_mask"] = F.hflip(container["semantic_mask"])
    if "vertices" in container:
        symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
        container["vertices"] = F.hflip(container["vertices"])
        symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
        container["vertices"] = symmetry_transform_[container["vertices"].long()]
    if "E_mask" in container:
        container["E_mask"] = F.hflip(container["E_mask"])
    if "maskrcnn_mask" in container:
        container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
    if "img_hr" in container:
        container["img_hr"] = F.hflip(container["img_hr"])
    if "condition_hr" in container:
        container["condition_hr"] = F.hflip(container["condition_hr"])
    if "mask_hr" in container:
        container["mask_hr"] = F.hflip(container["mask_hr"])
    return container

dp2/data/transforms/stylegan2_transform.py
DELETED
@@ -1,394 +0,0 @@
import numpy as np
import scipy.signal
import torch
try:
    from sg3_torch_utils import misc
    from sg3_torch_utils.ops import upfirdn2d
    from sg3_torch_utils.ops import grid_sample_gradfix
    from sg3_torch_utils.ops import conv2d_gradfix
except:
    pass
#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.

wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1': [0.7071067811865476, 0.7071067811865476],
    'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}

#----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.


def matrix(*rows, device=None):
    assert all(len(row) == len(rows[0]) for row in rows)
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return misc.constant(np.asarray(rows), device=device)
    assert device is None or device == ref[0].device
    elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
    return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))


def translate2d(tx, ty, **kwargs):
    return matrix(
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
        **kwargs)


def translate3d(tx, ty, tz, **kwargs):
    return matrix(
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
        **kwargs)


def scale2d(sx, sy, **kwargs):
    return matrix(
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
        **kwargs)


def scale3d(sx, sy, sz, **kwargs):
    return matrix(
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
        **kwargs)


def rotate2d(theta, **kwargs):
    return matrix(
        [torch.cos(theta), torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta), 0],
        [0, 0, 1],
        **kwargs)


def rotate3d(v, theta, **kwargs):
    vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
    s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
    return matrix(
        [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
        [0, 0, 0, 1],
        **kwargs)


def translate2d_inv(tx, ty, **kwargs):
    return translate2d(-tx, -ty, **kwargs)


def scale2d_inv(sx, sy, **kwargs):
    return scale2d(1 / sx, 1 / sy, **kwargs)


def rotate2d_inv(theta, **kwargs):
    return rotate2d(-theta, **kwargs)


class StyleGANAugmentPipe(torch.nn.Module):
    def __init__(self,
                 rotate90=0, xint=0, xint_max=0.125,
                 scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
                 brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
                 hue_max=1, saturation_std=1,
                 imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
                 ):
        super().__init__()
        self.register_buffer('p', torch.ones([]))  # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.rotate90 = float(rotate90)  # Probability multiplier for 90 degree rotations.
        self.xint = float(xint)  # Probability multiplier for integer translation.
        self.xint_max = float(xint_max)  # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale = float(scale)  # Probability multiplier for isotropic scaling.
        self.rotate = float(rotate)  # Probability multiplier for arbitrary rotation.
        self.aniso = float(aniso)  # Probability multiplier for anisotropic scaling.
        self.xfrac = float(xfrac)  # Probability multiplier for fractional translation.
        self.scale_std = float(scale_std)  # Log2 standard deviation of isotropic scaling.
        self.rotate_max = float(rotate_max)  # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std = float(aniso_std)  # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(xfrac_std)  # Standard deviation of frational translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(brightness)  # Probability multiplier for brightness.
        self.contrast = float(contrast)  # Probability multiplier for contrast.
        self.lumaflip = float(lumaflip)  # Probability multiplier for luma flip.
        self.hue = float(hue)  # Probability multiplier for hue rotation.
        self.saturation = float(saturation)  # Probability multiplier for saturation.
        self.brightness_std = float(brightness_std)  # Standard deviation of brightness.
        self.contrast_std = float(contrast_std)  # Log2 standard deviation of contrast.
        self.hue_max = float(hue_max)  # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(saturation_std)  # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter = float(imgfilter)  # Probability multiplier for image-space filtering.
        self.imgfilter_bands = list(imgfilter_bands)  # Probability multipliers for individual frequency bands.
        self.imgfilter_std = float(imgfilter_std)  # Log2 standard deviation of image-space filter amplification.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])  # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size))  # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2  # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2  # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)  # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, batch, debug_percentile=None):
        images = batch["img"]
        batch["vertices"] = batch["vertices"].float()
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        self.Hz_fbank = self.Hz_fbank.to(device)
        self.Hz_geom = self.Hz_geom.to(device)
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1))  # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:
            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device)  # [idx, xyz]
            cp = G_inv @ cp.t()  # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1)  # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values  # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
            batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
            batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
            batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
            batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            batch["mask"] = torch.nn.functional.grid_sample(
                input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
            batch["E_mask"] = torch.nn.functional.grid_sample(
                input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
            batch["vertices"] = torch.nn.functional.grid_sample(
                input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
            batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
            batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
            batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)  # Luma axis.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C  # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device)  # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device)  # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device)  # Temporary gain vector.
                t[:, i] = t_i  # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt()  # Normalize power.
                g = g * t  # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank  # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])  # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1])  # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------
        batch["img"] = images
        batch["vertices"] = batch["vertices"].long()
        batch["border"] = 1 - batch["E_mask"] - batch["mask"]
        return batch

dp2/data/transforms/transforms.py
DELETED
@@ -1,247 +0,0 @@
from pathlib import Path
from typing import Dict, List
import torchvision
import torch
import tops
import torchvision.transforms.functional as F
from .functional import hflip


class RandomHorizontalFlip(torch.nn.Module):

    def __init__(self, p: float, flip_map=None, **kwargs):
        super().__init__()
        self.flip_ratio = p
        self.flip_map = flip_map
        if self.flip_ratio is None:
            self.flip_ratio = 0.5
        assert 0 <= self.flip_ratio <= 1

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        if torch.rand(1) > self.flip_ratio:
            return container
        return hflip(container, self.flip_map)


class CenterCrop(torch.nn.Module):
    """
    Performs the transform on the image.
    NOTE: Does not transform the mask to improve runtime.
    """

    def __init__(self, size: List[int]):
        super().__init__()
        self.size = tuple(size)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        min_size = min(container["img"].shape[1], container["img"].shape[2])
        if min_size < self.size[0]:
            container["img"] = F.center_crop(container["img"], min_size)
            container["img"] = F.resize(container["img"], self.size)
            return container
        container["img"] = F.center_crop(container["img"], self.size)
        return container


class Resize(torch.nn.Module):
    """
    Performs the transform on the image.
    NOTE: Does not transform the mask to improve runtime.
    """

    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
        if "semantic_mask" in container:
            container["semantic_mask"] = F.resize(
                container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
        if "embedding" in container:
            container["embedding"] = F.resize(
                container["embedding"], self.size, self.interpolation)
        if "mask" in container:
            container["mask"] = F.resize(
                container["mask"], self.size, F.InterpolationMode.NEAREST)
        if "E_mask" in container:
            container["E_mask"] = F.resize(
                container["E_mask"], self.size, F.InterpolationMode.NEAREST)
        if "maskrcnn_mask" in container:
            container["maskrcnn_mask"] = F.resize(
                container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
        if "vertices" in container:
            container["vertices"] = F.resize(
                container["vertices"], self.size, F.InterpolationMode.NEAREST)
        return container

    def __repr__(self):
        repr = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])


class InsertHRImage(torch.nn.Module):
    """
    Resizes mask by maxpool and assumes condition is already created
    """
    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        assert container["img"].dtype == torch.float32
        container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
        container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
        mask = container["mask"] > 0
        container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
        container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"]
        return container

    def __repr__(self):
        repr = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        return repr + " "


class CopyHRImage(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        container["img_hr"] = container["img"]
        container["condition_hr"] = container["condition"]
        container["mask_hr"] = container["mask"]
        return container


class Resize2(torch.nn.Module):
    """
    Resizes mask by maxpool and assumes condition is already created
    """
    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition=True):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation
        self.downsample_condition = downsample_condition
        self.mask_condition = mask_condition

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # assert container["img"].dtype == torch.float32
        container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
        mask = container["mask"] > 0
        container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()

        if self.downsample_condition:
            container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
            if self.mask_condition:
                container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"]
        return container

    def __repr__(self):
        repr = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])


class Normalize(torch.nn.Module):
    """
    Performs the transform on the image.
    NOTE: Does not transform the mask to improve runtime.
    """

    def __init__(self, mean, std, inplace, keys=["img"]):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace
        self.keys = keys

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for key in self.keys:
            container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
        return container

    def __repr__(self):
        repr = super().__repr__()
        vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
        return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])


class ToFloat(torch.nn.Module):

    def __init__(self, keys=["img"], norm=True) -> None:
        super().__init__()
        self.keys = keys
        self.gain = 255 if norm else 1

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for key in self.keys:
            container[key] = container[key].float() / self.gain
        return container


class RandomCrop(torchvision.transforms.RandomCrop):
    """
    Performs the transform on the image.
    NOTE: Does not transform the mask to improve runtime.
    """

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        container["img"] = super().forward(container["img"])
        return container


class CreateCondition(torch.nn.Module):

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        if container["img"].dtype == torch.uint8:
            container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
            return container
        container["condition"] = container["img"] * container["mask"]
        return container


class CreateEmbedding(torch.nn.Module):

    def __init__(self, embed_path: Path, cuda=True) -> None:
        super().__init__()
        self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
        if cuda:
            self.embed_map = tops.to_cuda(self.embed_map)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        vertices = container["vertices"]
        if vertices.ndim == 3:
            embedding = self.embed_map[vertices.long()].squeeze(dim=0)
            embedding = embedding.permute(2, 0, 1) * container["E_mask"]
            pass
        else:
            assert vertices.ndim == 4
            embedding = self.embed_map[vertices.long()].squeeze(dim=1)
            embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
        container["embedding"] = embedding
        container["embed_map"] = self.embed_map.clone()
        return container


class UpdateMask(torch.nn.Module):

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float()
        return container


class LoadClassEmbedding(torch.nn.Module):

    def __init__(self, embedding_path: Path) -> None:
        super().__init__()
        self.embedding = torch.load(embedding_path, map_location="cpu")

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1])
        container["class_embedding"] = self.embedding[key].view(-1)
        return container

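Because every transform above maps a dict of tensors to a dict of tensors, they compose directly with torch.nn.Sequential. A small self-contained sketch on random data, assuming the module above is importable as dp2.data.transforms.transforms (e.g. via the deep_privacy2 submodule):

import torch
from dp2.data.transforms.transforms import CreateCondition, ToFloat, Normalize

batch = {
    "img": torch.randint(0, 256, (1, 3, 288, 160), dtype=torch.uint8),
    "mask": torch.randint(0, 2, (1, 1, 288, 160)).float(),
}
pipeline = torch.nn.Sequential(
    CreateCondition(),                    # keep img where mask == 1, grey (127) where mask == 0
    ToFloat(keys=["img", "condition"]),   # uint8 [0, 255] -> float [0, 1]
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False, keys=["img", "condition"]),
)
out = pipeline(batch)
print(out["condition"].shape, out["condition"].dtype)  # torch.Size([1, 3, 288, 160]) torch.float32
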
dp2/data/utils.py
DELETED
@@ -1,102 +0,0 @@
import torch
from PIL import Image
import numpy as np
import multiprocessing
import io
from tops import logger
from torch.utils.data._utils.collate import default_collate

try:
    import pyspng

    PYSPNG_IMPORTED = True
except ImportError:
    PYSPNG_IMPORTED = False
    print("Could not load pyspng. Defaulting to pillow image backend.")
    from PIL import Image


def get_coco_keypoints():
    return [
        "nose",
        "left_eye",
        "right_eye",
        "left_ear",
        "right_ear",
        "left_shoulder",
        "right_shoulder",
        "left_elbow",
        "right_elbow",
        "left_wrist",
        "right_wrist",
        "left_hip",
        "right_hip",
        "left_knee",
        "right_knee",
        "left_ankle",
        "right_ankle",
    ]


def get_coco_flipmap():
    keypoints = get_coco_keypoints()
    keypoint_flip_map = {
        "left_eye": "right_eye",
        "left_ear": "right_ear",
        "left_shoulder": "right_shoulder",
        "left_elbow": "right_elbow",
        "left_wrist": "right_wrist",
        "left_hip": "right_hip",
        "left_knee": "right_knee",
        "left_ankle": "right_ankle",
    }
    for key, value in list(keypoint_flip_map.items()):
        keypoint_flip_map[value] = key
    keypoint_flip_map["nose"] = "nose"
    keypoint_flip_map_idx = []
    for source in keypoints:
        keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
    return keypoint_flip_map_idx


def mask_decoder(x):
    mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
    mask = mask > 0  # This fixes bug causing mask.float().max() == 255.
    return mask


def png_decoder(x):
    if PYSPNG_IMPORTED:
        return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
    with Image.open(io.BytesIO(x)) as im:
        im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
    return im


def jpg_decoder(x):
    with Image.open(io.BytesIO(x)) as im:
        im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
    return im


def get_num_workers(num_workers: int):
    n_cpus = multiprocessing.cpu_count()
    if num_workers > n_cpus:
        logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
        return n_cpus
    return num_workers


def collate_fn(batch):
    elem = batch[0]
    ignore_keys = set(["embed_map", "vertx2cat"])
    batch_ = {
        key: default_collate([d[key] for d in batch])
        for key in elem
        if key not in ignore_keys
    }
    if "embed_map" in elem:
        batch_["embed_map"] = elem["embed_map"]
    if "vertx2cat" in elem:
        batch_["vertx2cat"] = elem["vertx2cat"]
    return batch_

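A short sketch of what get_coco_flipmap returns: for every keypoint index, the index whose value should take its place after a horizontal flip. The import path assumes dp2.data.utils is still reachable (e.g. from the deep_privacy2 submodule):

from dp2.data.utils import get_coco_keypoints, get_coco_flipmap

names = get_coco_keypoints()
flip_map = get_coco_flipmap()
# left_eye (1) and right_eye (2) swap, left/right ears swap, the nose maps to itself.
print(flip_map[:5])        # [0, 2, 1, 4, 3]
print(names[flip_map[1]])  # right_eye
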
dp2/detection/__init__.py
DELETED
@@ -1,3 +0,0 @@
from .cse_mask_face_detector import CSeMaskFaceDetector
from .person_detector import CSEPersonDetector
from .structures import PersonDetection, VehicleDetection, FaceDetection

dp2/detection/base.py
DELETED
@@ -1,45 +0,0 @@
import pickle
import torch
import lzma
from pathlib import Path
from tops import logger


class BaseDetector:

    def __init__(self, cache_directory: str) -> None:
        if cache_directory is not None:
            self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
            self.cache_directory.mkdir(exist_ok=True, parents=True)

    def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
        logger.log(f"Caching detection to: {cache_path}")
        with lzma.open(cache_path, "wb") as fp:
            torch.save(
                [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
                pickle_protocol=pickle.HIGHEST_PROTOCOL)

    def load_from_cache(self, cache_path: Path):
        logger.log(f"Loading detection from cache path: {cache_path}")
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        return [
            state["cls"].from_state_dict(state_dict=state) for state in state_dict
        ]

    def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
        if cache_id is None:
            return self.forward(im)
        cache_path = self.cache_directory.joinpath(cache_id + ".torch")
        if cache_path.is_file() and load_cache:
            try:
                return self.load_from_cache(cache_path)
            except Exception as e:
                logger.warn(f"The cache file was corrupted: {cache_path}")
                exit()
        detections = self.forward(im)
        self.save_to_cache(detections, cache_path)
        return detections

dp2/detection/box_utils.py
DELETED
@@ -1,104 +0,0 @@
import numpy as np


def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
    x0, y0, x1, y1 = [int(_) for _ in bbox]
    h, w = y1 - y0, x1 - x0
    cur_ratio = h / w

    if cur_ratio == target_aspect_ratio:
        return [x0, y0, x1, y1]
    if cur_ratio < target_aspect_ratio:
        target_height = int(w*target_aspect_ratio)
        y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
    else:
        target_width = int(h/target_aspect_ratio)
        x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
    return x0, y0, x1, y1


def expand_axis(start, end, target_width, limit):
    # Can return a bbox outside of limit
    cur_width = end - start
    start = start - (target_width-cur_width)//2
    end = end + (target_width-cur_width)//2
    if end - start != target_width:
        end += 1
    assert end - start == target_width
    if start < 0 and end > limit:
        return start, end
    if start < 0 and end < limit:
        to_shift = min(0 - start, limit - end)
        start += to_shift
        end += to_shift
    if end > limit and start > 0:
        to_shift = min(end - limit, start)
        end -= to_shift
        start -= to_shift
    assert end - start == target_width
    return start, end


def expand_box(bbox, imshape, mask, percentage_background: float):
    assert isinstance(bbox[0], int)
    assert 0 < percentage_background < 1
    # Percentage in S
    mask_pixels = mask.long().sum().cpu()
    total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
    percentage_mask = mask_pixels / total_pixels
    if (1 - percentage_mask) > percentage_background:
        return bbox
    target_pixels = mask_pixels / (1 - percentage_background)
    x0, y0, x1, y1 = bbox
    H = y1 - y0
    W = x1 - x0
    p = np.sqrt(target_pixels/(H*W))
    target_width = int(np.ceil(p * W))
    target_height = int(np.ceil(p * H))
    x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
    y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
    return [x0, y0, x1, y1]


def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
    x0, y0, x1, y1 = bbox_XYXY
    H = y1 - y0
    W = x1 - x0
    expansion = int(((H*W)**0.5) * percentage)
    new_width = W + expansion
    new_height = H + expansion
    x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
    y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
    return [x0, y0, x1, y1]


def get_expanded_bbox(
        bbox_XYXY,
        imshape,
        mask,
        percentage_background: float,
        axis_minimum_expansion: float,
        target_aspect_ratio: float):
    bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
    # Expand each axis of the bounding box by a minimum percentage
    bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
    # Find the minimum bbox with the aspect ratio. Can be outside of imshape
    bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
    # Expands square box such that X% of the bbox is background
    bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
    assert isinstance(bbox_XYXY[0], (int, np.int64))
    return bbox_XYXY


def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
    def area_inside_ratio(bbox, imshape):
        area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
        area_inside = (min(bbox[2], imshape[1]) - max(0, bbox[0])) * (min(imshape[0], bbox[3]) - max(0, bbox[1]))
        return area_inside / area
    ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
    area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
    if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
        return False
    if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
        return False
    return True

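An illustrative sketch of how the helpers above are combined, assuming the module is still importable from the deep_privacy2 submodule added in this commit; the box, mask and parameter values below are invented, not the repository's configs:

import torch
from dp2.detection.box_utils import get_expanded_bbox, include_box

imshape = (512, 512)                        # (H, W)
mask = torch.zeros(imshape, dtype=torch.bool)
mask[100:300, 150:250] = True               # fake person segmentation
bbox = torch.tensor([150, 100, 250, 300])   # XYXY box around the mask

exp_box = get_expanded_bbox(
    bbox, imshape, mask,
    percentage_background=0.3,
    axis_minimum_expansion=0.1,
    target_aspect_ratio=288 / 160)
keep = include_box(
    exp_box, minimum_area=32 * 32, aspect_ratio_range=(0, 99999),
    min_bbox_ratio_inside=0, imshape=imshape)
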
dp2/detection/box_utils_fdf.py
DELETED
@@ -1,203 +0,0 @@
"""
The FDF dataset expands bound boxes differently from what is used for CSE.
"""

import numpy as np


def quadratic_bounding_box(x0, y0, width, height, imshape):
    # We assume that we can create a image that is quadratic without
    # minimizing any of the sides
    assert width <= min(imshape[:2])
    assert height <= min(imshape[:2])
    min_side = min(height, width)
    if height != width:
        side_diff = abs(height - width)
        # Want to extend the shortest side
        if min_side == height:
            # Vertical side
            height += side_diff
            if height > imshape[0]:
                # Take full frame, and shrink width
                y0 = 0
                height = imshape[0]

                side_diff = abs(height - width)
                width -= side_diff
                x0 += side_diff // 2
            else:
                y0 -= side_diff // 2
                y0 = max(0, y0)
        else:
            # Horizontal side
            width += side_diff
            if width > imshape[1]:
                # Take full frame width, and shrink height
                x0 = 0
                width = imshape[1]

                side_diff = abs(height - width)
                height -= side_diff
                y0 += side_diff // 2
            else:
                x0 -= side_diff // 2
                x0 = max(0, x0)
    # Check that bbox goes outside image
    x1 = x0 + width
    y1 = y0 + height
    if imshape[1] < x1:
        diff = x1 - imshape[1]
        x0 -= diff
    if imshape[0] < y1:
        diff = y1 - imshape[0]
        y0 -= diff
    assert x0 >= 0, "Bounding box outside image."
    assert y0 >= 0, "Bounding box outside image."
    assert x0 + width <= imshape[1], "Bounding box outside image."
    assert y0 + height <= imshape[0], "Bounding box outside image."
    return x0, y0, width, height


def expand_bounding_box(bbox, percentage, imshape):
    orig_bbox = bbox.copy()
    x0, y0, x1, y1 = bbox
    width = x1 - x0
    height = y1 - y0
    x0, y0, width, height = quadratic_bounding_box(
        x0, y0, width, height, imshape)
    expanding_factor = int(max(height, width) * percentage)

    possible_max_expansion = [(imshape[0] - width) // 2,
                              (imshape[1] - height) // 2,
                              expanding_factor]

    expanding_factor = min(possible_max_expansion)
    # Expand height

    if expanding_factor > 0:

        y0 = y0 - expanding_factor
        y0 = max(0, y0)

        height += expanding_factor * 2
        if height > imshape[0]:
            y0 -= (imshape[0] - height)
            height = imshape[0]

        if height + y0 > imshape[0]:
            y0 -= (height + y0 - imshape[0])

        # Expand width
        x0 = x0 - expanding_factor
        x0 = max(0, x0)

        width += expanding_factor * 2
        if width > imshape[1]:
            x0 -= (imshape[1] - width)
            width = imshape[1]

        if width + x0 > imshape[1]:
            x0 -= (width + x0 - imshape[1])
    y1 = y0 + height
    x1 = x0 + width
    assert y0 >= 0, "Y0 is minus"
    assert height <= imshape[0], "Height is larger than image."
    assert x0 + width <= imshape[1]
    assert y0 + height <= imshape[0]
    assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
    assert x0 >= 0, "Y0 is minus"
    assert width <= imshape[1], "Height is larger than image."
    # Check that original bbox is within new
    x0_o, y0_o, x1_o, y1_o = orig_bbox
    assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}"
    assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}"
    assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}"
    assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}"

    x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
    x1 = x0 + width
    y1 = y0 + height
    return np.array([x0, y0, x1, y1])


def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
    keypoint = keypoint[:, :3]  # only nose + eyes are relevant
    kp_X = keypoint[0, :]
    kp_Y = keypoint[1, :]
    within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
    within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
    return within_X and within_Y


def expand_bbox_simple(bbox, percentage):
    x0, y0, x1, y1 = bbox.astype(float)
    width = x1 - x0
    height = y1 - y0
    x_c = int(x0) + width // 2
    y_c = int(y0) + height // 2
    avg_size = max(width, height)
    new_width = avg_size * (1 + percentage)
    x0 = x_c - new_width // 2
    y0 = y_c - new_width // 2
    x1 = x_c + new_width // 2
    y1 = y_c + new_width // 2
    return np.array([x0, y0, x1, y1]).astype(int)


def pad_image(im, bbox, pad_value):
    x0, y0, x1, y1 = bbox
    if x0 < 0:
        pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
                          dtype=np.uint8) + pad_value
        im = np.concatenate((pad_im, im), axis=1)
        x1 += abs(x0)
        x0 = 0
    if y0 < 0:
        pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
                          dtype=np.uint8) + pad_value
        im = np.concatenate((pad_im, im), axis=0)
        y1 += abs(y0)
        y0 = 0
    if x1 >= im.shape[1]:
        pad_im = np.zeros(
            (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
            dtype=np.uint8) + pad_value
        im = np.concatenate((im, pad_im), axis=1)
    if y1 >= im.shape[0]:
        pad_im = np.zeros(
            (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
            dtype=np.uint8) + pad_value
        im = np.concatenate((im, pad_im), axis=0)
    return im[y0:y1, x0:x1]


def clip_box(bbox, im):
    bbox[0] = max(0, bbox[0])
    bbox[1] = max(0, bbox[1])
    bbox[2] = min(im.shape[1] - 1, bbox[2])
    bbox[3] = min(im.shape[0] - 1, bbox[3])
    return bbox


def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
    outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
    if simple_expand or (outside_im and pad_im):
        return pad_image(im, bbox, pad_value)
    bbox = clip_box(bbox, im)
    x0, y0, x1, y1 = bbox
    return im[y0:y1, x0:x1]


def expand_bbox(
        bbox_ltrb, imshape, simple_expand, default_to_simple=False,
        expansion_factor=0.35):
    assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox.shape}"
    bbox = bbox_ltrb.astype(float)
    # FDF256 uses simple expand with ratio 0.4
    if simple_expand:
        return expand_bbox_simple(bbox, 0.4)
    try:
        return expand_bounding_box(bbox, expansion_factor, imshape)
    except AssertionError:
        return expand_bbox_simple(bbox, expansion_factor * 2)

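A small usage sketch of the FDF-style expansion above, with a dummy HWC uint8 image and an invented face box:

import numpy as np
from dp2.detection.box_utils_fdf import expand_bbox, cut_face

im = np.zeros((480, 640, 3), dtype=np.uint8)
face_box = np.array([300, 200, 360, 280])               # XYXY (ltrb)
exp_box = expand_bbox(face_box, im.shape[:2], simple_expand=True)
face_crop = cut_face(im, exp_box, simple_expand=True)   # zero-padded if the box leaves the image
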
dp2/detection/cse_mask_face_detector.py
DELETED
@@ -1,116 +0,0 @@
import torch
import lzma
import tops
from pathlib import Path
from dp2.detection.base import BaseDetector
from .utils import combine_cse_maskrcnn_dets
from face_detection import build_detector as build_face_detector
from .models.cse import CSEDetector
from .models.mask_rcnn import MaskRCNNDetector
from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
from tops import logger


def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
    assert len(box1.shape) == 2
    assert len(box2.shape) == 2
    box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
    # This can be batched
    for i, box in enumerate(box1):
        is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
        is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
        is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
        box1_inside[i] = is_outside.logical_not().any()
    return box1_inside


class CSeMaskFaceDetector(BaseDetector):

    def __init__(
            self,
            mask_rcnn_cfg,
            face_detector_cfg: dict,
            cse_cfg: dict,
            face_post_process_cfg: dict,
            cse_post_process_cfg,
            score_threshold: float,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        if "confidence_threshold" not in face_detector_cfg:
            face_detector_cfg["confidence_threshold"] = score_threshold
        if "score_thres" not in cse_cfg:
            cse_cfg["score_thres"] = score_threshold
        self.cse_detector = CSEDetector(**cse_cfg)
        self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
        self.cse_post_process_cfg = cse_post_process_cfg
        self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
        self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
        self.face_post_process_cfg = face_post_process_cfg

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _detect_faces(self, im: torch.Tensor):
        H, W = im.shape[1:]
        im = im.float() - self.face_mean
        im = self.face_detector.resize(im[None], 1.0)
        boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1]  # Remove score
        boxes_XYXY[:, [0, 2]] *= W
        boxes_XYXY[:, [1, 3]] *= H
        return boxes_XYXY.round().long()

    def load_from_cache(self, cache_path: Path):
        logger.log(f"Loading detection from cache path: {cache_path}",)
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp, map_location="cpu")
        kwargs = dict(
            post_process_cfg=self.cse_post_process_cfg,
            embed_map=self.cse_detector.embed_map,
            **self.face_post_process_cfg
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        maskrcnn_dets = self.mask_rcnn(im)
        cse_dets = self.cse_detector(im)
        embed_map = self.cse_detector.embed_map
        print("Calling face detector.")
        face_boxes = self._detect_faces(im).cpu()
        maskrcnn_person = {
            k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
        }
        maskrcnn_other = {
            k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
        }
        maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
        combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
            maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)

        persons_with_cse = CSEPersonDetection(
            combined_segmentation, cse_dets, **self.cse_post_process_cfg,
            embed_map=embed_map, orig_imshape_CHW=im.shape
        )
        persons_with_cse.pre_process()
        not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
        persons_without_cse = PersonDetection(
            maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
            orig_imshape_CHW=im.shape
        )
        persons_without_cse.pre_process()

        face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
            box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
        )
        face_boxes = face_boxes[face_boxes_covered.logical_not()]
        face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)

        # Order matters. The anonymizer will anonymize FIFO.
        # Later detections will overwrite.
        all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
        return all_detections

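A toy check of the box1_inside_box2 helper defined above: a face box counts as covered only when it lies fully inside some person box (the boxes are invented):

import torch

face_boxes = torch.tensor([[110, 110, 150, 160],    # inside the person box
                           [400, 400, 450, 460]])   # outside every person box
person_boxes = torch.tensor([[100, 100, 300, 400]])
covered = box1_inside_box2(face_boxes, person_boxes)
# covered -> tensor([True, False]); only uncovered faces are kept as FaceDetection
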
dp2/detection/face_detector.py
DELETED
@@ -1,62 +0,0 @@
import torch
import lzma
import tops
from pathlib import Path
from dp2.detection.base import BaseDetector
from face_detection import build_detector as build_face_detector
from .structures import FaceDetection
from tops import logger


def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
    assert len(box1.shape) == 2
    assert len(box2.shape) == 2
    box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
    # This can be batched
    for i, box in enumerate(box1):
        is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
        is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
        is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
        box1_inside[i] = is_outside.logical_not().any()
    return box1_inside


class FaceDetector(BaseDetector):

    def __init__(
            self,
            face_detector_cfg: dict,
            score_threshold: float,
            face_post_process_cfg: dict,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
        self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
        self.face_post_process_cfg = face_post_process_cfg

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _detect_faces(self, im: torch.Tensor):
        H, W = im.shape[1:]
        im = im.float() - self.face_mean
        im = self.face_detector.resize(im[None], 1.0)
        boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1]  # Remove score
        boxes_XYXY[:, [0, 2]] *= W
        boxes_XYXY[:, [1, 3]] *= H
        return boxes_XYXY.round().long().cpu()

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        face_boxes = self._detect_faces(im)
        face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
        return [face_boxes]

    def load_from_cache(self, cache_path: Path):
        logger.log(f"Loading detection from cache path: {cache_path}")
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        return [
            state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
        ]

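A hypothetical wiring of the face-only detector above; the config keys and values below are guesses for illustration, not the repository's configs, and the example assumes the face_detection package and a GPU are available:

import torch
from dp2.detection.face_detector import FaceDetector

detector = FaceDetector(
    face_detector_cfg=dict(name="DSFDDetector"),    # assumed detector name
    score_threshold=0.3,
    face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
    cache_directory=None,                           # assumed BaseDetector kwarg
)
faces = detector(torch.zeros((3, 480, 640), dtype=torch.uint8))
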
dp2/detection/models/__init__.py
DELETED
File without changes
dp2/detection/models/cse.py
DELETED
@@ -1,135 +0,0 @@
import torch
from typing import List
import tops
from torchvision.transforms.functional import InterpolationMode, resize
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose import add_densepose_config
from densepose.structures import DensePoseEmbeddingPredictorOutput
from densepose.vis.extractor import DensePoseOutputsExtractor
from densepose.modeling import build_densepose_embedder
from detectron2.config import get_cfg
from detectron2.data.transforms import ResizeShortestEdge
from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model


model_urls = {
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
    "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
}


def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
    assert len(S.shape) == 3
    H, W = imshape
    N = len(boxes_XYXY)
    segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
    boxes_XYXY = boxes_XYXY.long()
    for i in range(N):
        x0, y0, x1, y1 = boxes_XYXY[i]
        assert x0 >= 0 and y0 >= 0
        assert x1 <= imshape[1]
        assert y1 <= imshape[0]
        h = y1 - y0
        w = x1 - x0
        segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
    return segmentation


class CSEDetector:

    def __init__(
            self,
            cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
            cfg_2_download: List[str] = [
                "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
                "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
                "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
            score_thres: float = 0.9,
            nms_thresh: float = None,
    ) -> None:
        with tops.logger.capture_log_stdout():
            cfg = get_cfg()
            self.device = tops.get_device()
            add_densepose_config(cfg)
        cfg_path = tops.download_file(cfg_url)
        for p in cfg_2_download:
            tops.download_file(p)
        with tops.logger.capture_log_stdout():
            cfg.merge_from_file(cfg_path)
        assert cfg_url in model_urls, cfg_url
        model_path = tops.download_file(model_urls[cfg_url])
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
        if nms_thresh is not None:
            cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
        cfg.MODEL.WEIGHTS = str(model_path)
        cfg.MODEL.DEVICE = str(self.device)
        cfg.freeze()
        with tops.logger.capture_log_stdout():
            self.model = build_model(cfg)
            self.model.eval()
        DetectionCheckpointer(self.model).load(str(model_path))
        self.input_format = cfg.INPUT.FORMAT
        self.densepose_extractor = DensePoseOutputsExtractor()
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)

        self.embedder = build_densepose_embedder(cfg)
        self.mesh_vertex_embeddings = {
            mesh_name: self.embedder(mesh_name).to(self.device)
            for mesh_name in self.class_to_mesh_name.values()
            if self.embedder.has_embeddings(mesh_name)
        }
        self.cfg = cfg
        self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
        tops.logger.log("CSEDetector built.")

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def resize_im(self, im):
        H, W = im.shape[1:]
        newH, newW = ResizeShortestEdge.get_output_shape(
            H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
        return resize(
            im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)

    @torch.no_grad()
    def forward(self, im):
        assert im.dtype == torch.uint8
        if self.input_format == "BGR":
            im = im.flip(0)
        H, W = im.shape[1:]
        im = self.resize_im(im)
        output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
        scores = output.get("scores")
        if len(scores) == 0:
            return dict(
                instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
                instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
                embed_map=self.mesh_vertex_embeddings["smpl_27554"],
                bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
                im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
                scores=torch.empty((0), dtype=torch.float, device=im.device)
            )
        pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
        assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
        S = pred_densepose.coarse_segm.argmax(dim=1)  # Segmentation channel Nx2xHxW (2 because only 2 classes)
        E = pred_densepose.embedding
        mesh_name = self.class_to_mesh_name[classes[0]]
        assert mesh_name == "smpl_27554"
        x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
        boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
        boxes_XYXY = boxes_XYXY.round_().long()

        non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
        S = S[non_empty_boxes]
        E = E[non_empty_boxes]
        boxes_XYXY = boxes_XYXY[non_empty_boxes]
        scores = scores[non_empty_boxes]
        im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
        return dict(
            instance_segmentation=S, instance_embedding=E,
            bbox_XYXY=boxes_XYXY,
            im_segmentation=im_segmentation,
            scores=scores.view(-1))

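A toy illustration of cse_det_to_global above: each per-instance 112x112 coarse segmentation is pasted back into a full-image boolean mask at its box (shapes and values invented):

import torch

boxes = torch.tensor([[10, 20, 110, 220]])      # one detection, XYXY
S = torch.ones((1, 112, 112))                   # its coarse segmentation
full = cse_det_to_global(boxes, S, imshape=(256, 256))
# full.shape == (1, 256, 256); True inside the box, False elsewhere
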
dp2/detection/models/keypoint_maskrcnn.py
DELETED
@@ -1,111 +0,0 @@
import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
from detectron2.data.transforms import ResizeShortestEdge
from detectron2.structures import Instances
from detectron2 import model_zoo
from detectron2.config import instantiate
from detectron2.config import LazyCall as L
from PIL import Image
import tops
import functools
from torchvision.transforms.functional import resize


def get_rn50_fpn_keypoint_rcnn(weight_path: str):
    from detectron2.modeling.poolers import ROIPooler
    from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
    from detectron2.layers import ShapeSpec
    model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
    model.roi_heads.update(
        num_classes=1,
        keypoint_in_features=["p2", "p3", "p4", "p5"],
        keypoint_pooler=L(ROIPooler)(
            output_size=14,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
            input_shape=ShapeSpec(channels=256, width=14, height=14),
            num_keypoints=17,
            conv_dims=[512] * 8,
            loss_normalizer="visible",
        ),
    )

    # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
    # 1000 proposals per-image is found to hurt box AP.
    # Therefore we increase it to 1500 per-image.
    model.proposal_generator.post_nms_topk = (1500, 1000)

    # Keypoint AP degrades (though box AP improves) when using plain L1 loss
    model.roi_heads.box_predictor.smooth_l1_beta = 0.5
    model = instantiate(model)

    dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
    test_transform = instantiate(dataloader.test.mapper.augmentations)
    DetectionCheckpointer(model).load(weight_path)
    return model, test_transform


models = {
    "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
}


class KeypointMaskRCNN:

    def __init__(self, model_name: str, score_threshold: float) -> None:
        assert model_name in models, f"Did not find {model_name} in models"
        model, test_transform = models[model_name]()
        self.model = model.eval().to(tops.get_device())
        if isinstance(self.model.roi_heads, CascadeROIHeads):
            for head in self.model.roi_heads.box_predictors:
                assert hasattr(head, "test_score_thresh")
                head.test_score_thresh = score_threshold
        else:
            assert isinstance(self.model.roi_heads, StandardROIHeads)
            assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
            self.model.roi_heads.box_predictor.test_score_thresh = score_threshold

        self.test_transform = test_transform
        assert len(self.test_transform) == 1
        self.test_transform = self.test_transform[0]
        assert isinstance(self.test_transform, ResizeShortestEdge)
        assert self.test_transform.interp == Image.BILINEAR
        self.image_format = self.model.input_format

    def resize_im(self, im):
        H, W = im.shape[-2:]
        if self.test_transform.is_range:
            size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
        else:
            size = np.random.choice(self.test_transform.short_edge_length)
        newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
        return resize(
            im, (newH, newW), antialias=True)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @torch.no_grad()
    def forward(self, im: torch.Tensor) -> Instances:
        assert im.ndim == 3
        if self.image_format == "BGR":
            im = im.flip(0)
        H, W = im.shape[-2:]
        im = self.resize_im(im)
        im = im.float()
        inputs = dict(image=im, height=H, width=W)
        # instances contains
        # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
        instances = self.model([inputs])[0]["instances"]
        return dict(
            scores=instances.get("scores").cpu(),
            segmentation=instances.get("pred_masks").cpu(),
            keypoints=instances.get("pred_keypoints").cpu()
        )

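A hypothetical usage sketch for the keypoint detector above (requires detectron2 and the checkpoint URL above to resolve; the image is a dummy):

import torch

model = KeypointMaskRCNN("rn50_fpn_maskrcnn", score_threshold=0.5)
im = torch.zeros((3, 480, 640), dtype=torch.uint8)   # CHW uint8
out = model(im)   # dict with "scores", "segmentation", "keypoints", all on CPU
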
dp2/detection/models/mask_rcnn.py
DELETED
@@ -1,78 +0,0 @@
import torch
import tops
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.structures import Boxes
from detectron2.data import MetadataCatalog
from detectron2 import model_zoo
from typing import Dict
from detectron2.data.transforms import ResizeShortestEdge
from torchvision.transforms.functional import resize


model_urls = {
    "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl",
}


class MaskRCNNDetector:

    def __init__(
            self,
            cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
            score_thres: float = 0.9,
            class_filter=["person"],  # ["car", "bicycle","truck", "bus", "backpack"]
            fp16_inference: bool = False
    ) -> None:
        cfg = model_zoo.get_config(cfg_name)
        cfg.MODEL.DEVICE = str(tops.get_device())
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
        cfg.freeze()
        self.cfg = cfg
        with tops.logger.capture_log_stdout():
            self.model = build_model(cfg)
            DetectionCheckpointer(self.model).load(model_urls[cfg_name])
        self.model.eval()
        self.input_format = cfg.INPUT.FORMAT
        self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
        self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter])
        self.person_class = self.class_names.index("person")
        self.fp16_inference = fp16_inference
        tops.logger.log("Mask R-CNN built.")

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def resize_im(self, im):
        H, W = im.shape[1:]
        newH, newW = ResizeShortestEdge.get_output_shape(
            H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
        return resize(
            im, (newH, newW), antialias=True)

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        if self.input_format == "BGR":
            im = im.flip(0)
        else:
            assert self.input_format == "RGB"
        H, W = im.shape[-2:]
        im = self.resize_im(im)
        with torch.cuda.amp.autocast(enabled=self.fp16_inference):
            output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
        scores = output.get("scores")
        N = len(scores)
        classes = output.get("pred_classes")
        idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep]
        classes = classes[idx2keep]
        assert isinstance(output.get("pred_boxes"), Boxes)
        segmentation = output.get("pred_masks")[idx2keep]
        assert segmentation.dtype == torch.bool
        is_person = classes == self.person_class
        return {
            "scores": output.get("scores")[idx2keep],
            "segmentation": segmentation,
            "classes": output.get("pred_classes")[idx2keep],
            "is_person": is_person
        }

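A hypothetical usage sketch for the Mask R-CNN wrapper above (weights are downloaded on construction and detectron2 is required; the image is a dummy):

import torch

detector = MaskRCNNDetector(score_thres=0.7)          # keeps only "person" by default
im = torch.zeros((3, 480, 640), dtype=torch.uint8)    # CHW uint8 image
out = detector(im)
# out["segmentation"]: boolean instance masks, out["is_person"]: per-detection flags
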
dp2/detection/person_detector.py
DELETED
@@ -1,135 +0,0 @@
import torch
import lzma
from dp2.detection.base import BaseDetector
from .utils import combine_cse_maskrcnn_dets
from .models.cse import CSEDetector
from .models.mask_rcnn import MaskRCNNDetector
from .models.keypoint_maskrcnn import KeypointMaskRCNN
from .structures import CSEPersonDetection, PersonDetection
from pathlib import Path


class CSEPersonDetector(BaseDetector):
    def __init__(
            self,
            score_threshold: float,
            mask_rcnn_cfg: dict,
            cse_cfg: dict,
            cse_post_process_cfg: dict,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
        self.post_process_cfg = cse_post_process_cfg
        self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        kwargs = dict(
            post_process_cfg=self.post_process_cfg,
            embed_map=self.cse_detector.embed_map,
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cse_dets=None):
        mask_dets = self.mask_rcnn(im)
        if cse_dets is None:
            cse_dets = self.cse_detector(im)
        segmentation = mask_dets["segmentation"]
        segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
            segmentation, cse_dets, self.iou_combine_threshold
        )
        det = CSEPersonDetection(
            segmentation=segmentation,
            cse_dets=cse_dets,
            embed_map=self.cse_detector.embed_map,
            orig_imshape_CHW=im.shape,
            **self.post_process_cfg
        )
        return [det]


class MaskRCNNPersonDetector(BaseDetector):
    def __init__(
            self,
            score_threshold: float,
            mask_rcnn_cfg: dict,
            cse_post_process_cfg: dict,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        self.post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        kwargs = dict(
            post_process_cfg=self.post_process_cfg,
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        mask_dets = self.mask_rcnn(im)
        segmentation = mask_dets["segmentation"]
        det = PersonDetection(
            segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape
        )
        return [det]


class KeypointMaskRCNNPersonDetector(BaseDetector):
    def __init__(
            self,
            score_threshold: float,
            mask_rcnn_cfg: dict,
            cse_post_process_cfg: dict,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.mask_rcnn = KeypointMaskRCNN(
            **mask_rcnn_cfg, score_threshold=score_threshold
        )
        self.post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        kwargs = dict(
            post_process_cfg=self.post_process_cfg,
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        mask_dets = self.mask_rcnn(im)
        segmentation = mask_dets["segmentation"]
        det = PersonDetection(
            segmentation,
            **self.post_process_cfg,
            orig_imshape_CHW=im.shape,
            keypoints=mask_dets["keypoints"]
        )
        return [det]

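A sketch of how a person detector above could be configured; every value below is a guess for illustration only, not the repository's config files:

post_process_cfg = dict(
    target_imsize=(288, 160),
    exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=0.1),
    exp_bbox_filter=dict(minimum_area=32 * 32, aspect_ratio_range=(0, 99999), min_bbox_ratio_inside=0),
    dilation_percentage=0.02,
)
detector = MaskRCNNPersonDetector(
    score_threshold=0.3,
    mask_rcnn_cfg=dict(),
    cse_post_process_cfg=post_process_cfg,
    cache_directory=None,   # assumed BaseDetector kwarg, see dp2/detection/base.py
)
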
dp2/detection/structures.py
DELETED
@@ -1,464 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import numpy as np
|
3 |
-
from dp2 import utils
|
4 |
-
from dp2.utils import vis_utils, crop_box
|
5 |
-
from .utils import (
|
6 |
-
cut_pad_resize, masks_to_boxes,
|
7 |
-
get_kernel, transform_embedding, initialize_cse_boxes
|
8 |
-
)
|
9 |
-
from .box_utils import get_expanded_bbox, include_box
|
10 |
-
import torchvision
|
11 |
-
import tops
|
12 |
-
from .box_utils_fdf import expand_bbox as expand_bbox_fdf
|
13 |
-
|
14 |
-
|
15 |
-
class VehicleDetection:
|
16 |
-
|
17 |
-
def __init__(self, segmentation: torch.BoolTensor) -> None:
|
18 |
-
self.segmentation = segmentation
|
19 |
-
self.boxes = masks_to_boxes(segmentation)
|
20 |
-
assert self.boxes.shape[1] == 4, self.boxes.shape
|
21 |
-
self.n_detections = self.segmentation.shape[0]
|
22 |
-
area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
|
23 |
-
|
24 |
-
sorted_idx = torch.argsort(area, descending=True)
|
25 |
-
self.segmentation = self.segmentation[sorted_idx]
|
26 |
-
self.boxes = self.boxes[sorted_idx].cpu()
|
27 |
-
|
28 |
-
def pre_process(self):
|
29 |
-
pass
|
30 |
-
|
31 |
-
def get_crop(self, idx: int, im):
|
32 |
-
assert idx < len(self)
|
33 |
-
box = self.boxes[idx]
|
34 |
-
im = crop_box(self.im, box)
|
35 |
-
mask = crop_box(self.segmentation[idx])
|
36 |
-
mask = mask == 0
|
37 |
-
return dict(img=im, mask=mask.float(), boxes=box)
|
38 |
-
|
39 |
-
def visualize(self, im):
|
40 |
-
if len(self) == 0:
|
41 |
-
return im
|
42 |
-
im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
|
43 |
-
return im
|
44 |
-
|
45 |
-
def __len__(self):
|
46 |
-
return self.n_detections
|
47 |
-
|
48 |
-
@staticmethod
|
49 |
-
def from_state_dict(state_dict, **kwargs):
|
50 |
-
numel = np.prod(state_dict["shape"])
|
51 |
-
arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
|
52 |
-
segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
|
53 |
-
return VehicleDetection(segmentation)
|
54 |
-
|
55 |
-
def state_dict(self, **kwargs):
|
56 |
-
segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
|
57 |
-
return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
|
58 |
-
|
59 |
-
|
60 |
-
class FaceDetection:
|
61 |
-
|
62 |
-
def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None:
|
63 |
-
self.boxes = boxes_ltrb.cpu()
|
64 |
-
assert self.boxes.shape[1] == 4, self.boxes.shape
|
65 |
-
self.target_imsize = tuple(target_imsize)
|
66 |
-
# Sory by area to paste in largest faces last
|
67 |
-
area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
|
68 |
-
idx = area.argsort(descending=False)
|
69 |
-
self.boxes = self.boxes[idx]
|
70 |
-
self.fdf128_expand = fdf128_expand
|
71 |
-
|
72 |
-
def visualize(self, im):
|
73 |
-
if len(self) == 0:
|
74 |
-
return im
|
75 |
-
orig_device = im.device
|
76 |
-
for box in self.boxes:
|
77 |
-
simple_expand = False if self.fdf128_expand else True
|
78 |
-
e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
|
79 |
-
im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
|
80 |
-
im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
|
81 |
-
|
82 |
-
return im.to(device=orig_device)
|
83 |
-
|
84 |
-
def get_crop(self, idx: int, im):
|
85 |
-
assert idx < len(self)
|
86 |
-
box = self.boxes[idx].numpy()
|
87 |
-
simple_expand = False if self.fdf128_expand else True
|
88 |
-
expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand=simple_expand)
|
89 |
-
im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
|
90 |
-
area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
|
91 |
-
|
92 |
-
# Find the square mask corresponding to box.
|
93 |
-
box_mask = box.copy().astype(float)
|
94 |
-
box_mask[[0, 2]] -= expanded_boxes[0]
|
95 |
-
box_mask[[1, 3]] -= expanded_boxes[1]
|
96 |
-
|
97 |
-
width = expanded_boxes[2] - expanded_boxes[0]
|
98 |
-
resize_factor = self.target_imsize[0] / width
|
99 |
-
box_mask = (box_mask * resize_factor).astype(int)
|
100 |
-
mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
|
101 |
-
crop_box(mask, box_mask).fill_(0)
|
102 |
-
return dict(
|
103 |
-
img=im[None], mask=mask[None],
|
104 |
-
boxes=torch.from_numpy(expanded_boxes).view(1, -1))
|
105 |
-
|
106 |
-
def __len__(self):
|
107 |
-
return len(self.boxes)
|
108 |
-
|
109 |
-
@staticmethod
|
110 |
-
def from_state_dict(state_dict, **kwargs):
|
111 |
-
return FaceDetection(state_dict["boxes"].cpu(), **kwargs)
|
112 |
-
|
113 |
-
def state_dict(self, **kwargs):
|
114 |
-
return dict(boxes=self.boxes, cls=self.__class__)
|
115 |
-
|
116 |
-
def pre_process(self):
|
117 |
-
pass
|
118 |
-
|
119 |
-
|
120 |
-
def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
|
121 |
-
"""
|
122 |
-
Dilation happens after padding, which could place dilation in the padded area.
|
123 |
-
Remove this.
|
124 |
-
"""
|
125 |
-
x0, y0, x1, y1 = exp_box
|
126 |
-
H, W = orig_imshape
|
127 |
-
# Padding in original image space
|
128 |
-
p_y0 = max(0, -y0)
|
129 |
-
p_y1 = max(y1 - H, 0)
|
130 |
-
p_x0 = max(0, -x0)
|
131 |
-
p_x1 = max(x1 - W, 0)
|
132 |
-
resize_ratio = mask.shape[-2] / (y1-y0)
|
133 |
-
p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
|
134 |
-
mask[..., :p_y0, :] = 0
|
135 |
-
mask[..., :p_x0] = 0
|
136 |
-
mask[..., mask.shape[-2] - p_y1:, :] = 0
|
137 |
-
mask[..., mask.shape[-1] - p_x1:] = 0
|
138 |
-
|
139 |
-
|
140 |
-
class CSEPersonDetection:
|
141 |
-
|
142 |
-
def __init__(self,
|
143 |
-
segmentation, cse_dets,
|
144 |
-
target_imsize,
|
145 |
-
exp_bbox_cfg, exp_bbox_filter,
|
146 |
-
dilation_percentage: float,
|
147 |
-
embed_map: torch.Tensor,
|
148 |
-
orig_imshape_CHW,
|
149 |
-
normalize_embedding: bool) -> None:
|
150 |
-
self.segmentation = segmentation
|
151 |
-
self.cse_dets = cse_dets
|
152 |
-
self.target_imsize = list(target_imsize)
|
153 |
-
self.pre_processed = False
|
154 |
-
self.exp_bbox_cfg = exp_bbox_cfg
|
155 |
-
self.exp_bbox_filter = exp_bbox_filter
|
156 |
-
self.dilation_percentage = dilation_percentage
|
157 |
-
self.embed_map = embed_map
|
158 |
-
self.normalize_embedding = normalize_embedding
|
159 |
-
if self.normalize_embedding:
|
160 |
-
embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
|
161 |
-
embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
|
162 |
-
self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
|
163 |
-
self.orig_imshape_CHW = orig_imshape_CHW
|
164 |
-
|
165 |
-
@torch.no_grad()
|
166 |
-
def pre_process(self):
|
167 |
-
if self.pre_processed:
|
168 |
-
return
|
169 |
-
boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
|
170 |
-
expanded_boxes = []
|
171 |
-
included_boxes = []
|
172 |
-
for i in range(len(boxes)):
|
173 |
-
exp_box = get_expanded_bbox(
|
174 |
-
boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
|
175 |
-
target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
|
176 |
-
if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
|
177 |
-
continue
|
178 |
-
included_boxes.append(i)
|
179 |
-
expanded_boxes.append(exp_box)
|
180 |
-
expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
|
181 |
-
self.segmentation = self.segmentation[included_boxes]
|
182 |
-
self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}
|
183 |
-
|
184 |
-
self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
|
185 |
-
area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
|
186 |
-
for i, box in enumerate(expanded_boxes):
|
187 |
-
self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
|
188 |
-
|
189 |
-
dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
|
190 |
-
self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
|
191 |
-
self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
|
192 |
-
[remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
|
193 |
-
self.boxes = expanded_boxes.cpu()
|
194 |
-
self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
|
195 |
-
|
196 |
-
self.pre_processed = True
|
197 |
-
self.n_detections = len(self.boxes)
|
198 |
-
self.mask = self.mask.logical_not()
|
199 |
-
|
200 |
-
E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
|
201 |
-
self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
|
202 |
-
for i in range(self.n_detections):
|
203 |
-
E_, E_mask[i] = transform_embedding(
|
204 |
-
self.cse_dets["instance_embedding"][i],
|
205 |
-
self.cse_dets["instance_segmentation"][i],
|
206 |
-
self.boxes[i],
|
207 |
-
self.cse_dets["bbox_XYXY"][i].cpu(),
|
208 |
-
self.target_imsize
|
209 |
-
)
|
210 |
-
self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
|
211 |
-
self.E_mask = E_mask
|
212 |
-
|
213 |
-
sorted_idx = torch.argsort(area, descending=False)
|
214 |
-
self.mask = self.mask[sorted_idx]
|
215 |
-
self.boxes = self.boxes[sorted_idx.cpu()]
|
216 |
-
self.vertices = self.vertices[sorted_idx]
|
217 |
-
self.E_mask = self.E_mask[sorted_idx]
|
218 |
-
self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
|
219 |
-
|
220 |
-
def get_crop(self, idx: int, im):
|
221 |
-
self.pre_process()
|
222 |
-
assert idx < len(self)
|
223 |
-
box = self.boxes[idx]
|
224 |
-
mask = self.mask[idx]
|
225 |
-
im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
|
226 |
-
|
227 |
-
vertices_ = self.vertices[idx]
|
228 |
-
E_mask_ = self.E_mask[idx].float()
|
229 |
-
if self.normalize_embedding:
|
230 |
-
embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
|
231 |
-
else:
|
232 |
-
embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
|
233 |
-
|
234 |
-
return dict(
|
235 |
-
img=im,
|
236 |
-
mask=mask.float()[None],
|
237 |
-
boxes=box.reshape(1, -1),
|
238 |
-
E_mask=E_mask_[None],
|
239 |
-
vertices=vertices_[None],
|
240 |
-
embed_map=self.embed_map,
|
241 |
-
embedding=embedding[None],
|
242 |
-
maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
|
243 |
-
)
|
244 |
-
|
245 |
-
def __len__(self):
|
246 |
-
self.pre_process()
|
247 |
-
return self.n_detections
|
248 |
-
|
249 |
-
def state_dict(self, after_preprocess=False):
|
250 |
-
"""
|
251 |
-
The processed annotations occupy more space than the original detections.
|
252 |
-
"""
|
253 |
-
if not after_preprocess:
|
254 |
-
return {
|
255 |
-
"combined_segmentation": self.segmentation.bool(),
|
256 |
-
"cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
|
257 |
-
"cse_instance_embedding": self.cse_dets["instance_embedding"],
|
258 |
-
"cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
|
259 |
-
"cls": self.__class__,
|
260 |
-
"orig_imshape_CHW": self.orig_imshape_CHW
|
261 |
-
}
|
262 |
-
self.pre_process()
|
263 |
-
return dict(
|
264 |
-
E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())),
|
265 |
-
-            mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())),
-            maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())),
-            vertices=self.vertices.to(torch.int16).cpu(),
-            cls=self.__class__,
-            boxes=self.boxes,
-            orig_imshape_CHW=self.orig_imshape_CHW,
-        )
-
-    @staticmethod
-    def from_state_dict(
-            state_dict, embed_map,
-            post_process_cfg, **kwargs):
-        after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
-        if after_preprocess:
-            detection = CSEPersonDetection(
-                segmentation=None, cse_dets=None, embed_map=embed_map,
-                orig_imshape_CHW=state_dict["orig_imshape_CHW"],
-                **post_process_cfg)
-            detection.vertices = tops.to_cuda(state_dict["vertices"].long())
-            numel = np.prod(detection.vertices.shape)
-            detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
-            detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape)
-            detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
-            detection.n_detections = len(detection.mask)
-            detection.pre_processed = True
-
-            if isinstance(state_dict["boxes"], np.ndarray):
-                state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
-            detection.boxes = state_dict["boxes"]
-            return detection
-
-        cse_dets = dict(
-            instance_segmentation=state_dict["cse_instance_segmentation"],
-            instance_embedding=state_dict["cse_instance_embedding"],
-            embed_map=embed_map,
-            bbox_XYXY=state_dict["cse_bbox_XYXY"])
-        cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}
-
-        segmentation = state_dict["combined_segmentation"]
-        return CSEPersonDetection(
-            segmentation, cse_dets, embed_map=embed_map,
-            orig_imshape_CHW=state_dict["orig_imshape_CHW"],
-            **post_process_cfg)
-
-    def visualize(self, im):
-        self.pre_process()
-        if len(self) == 0:
-            return im
-        im = vis_utils.draw_cropped_masks(
-            im.clone(), self.mask, self.boxes, visualize_instances=False)
-        E = self.embed_map[self.vertices.long()].squeeze(1).permute(0, 3, 1, 2)
-        im = im.to(E.device)
-        im = vis_utils.draw_cse_all(
-            E, self.E_mask.squeeze(1).bool(), im,
-            self.boxes, self.embed_map)
-        im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
-        return im
-
-
-def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
-    keypoints = keypoints.clone()
-    N = boxes.shape[0]
-    tops.assert_shape(keypoints, (N, None, 3))
-    tops.assert_shape(boxes, (N, 4))
-    x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
-
-    w = x1 - x0
-    h = y1 - y0
-    keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
-    keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h
-    check_outside = lambda x: (x < 0).logical_or(x > 1)
-    is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
-    keypoints[:, :, 2] = keypoints[:, :, 2] >= 0
-    keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
-    return keypoints
-
-
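Note on the removed shift_and_preprocess_keypoints helper above: it maps keypoint x/y coordinates into the [0, 1] range of their expanded box and clears the visibility flag for points that fall outside it. Below is a minimal standalone sketch of that normalization with hypothetical values (plain PyTorch, without the dp2/tops helpers; the extra handling of negative visibility sentinels is omitted).

import torch

# One detection box (x0, y0, x1, y1) and two keypoints (x, y, visibility).
boxes = torch.tensor([[10., 20., 110., 220.]])                   # a 100 x 200 box
keypoints = torch.tensor([[[60., 120., 1.], [200., 30., 1.]]])   # second point lies outside the box

x0, y0, x1, y1 = [c.view(-1, 1) for c in boxes.T]
w, h = x1 - x0, y1 - y0

kp = keypoints.clone()
kp[:, :, 0] = (kp[:, :, 0] - x0) / w          # x -> box-relative [0, 1]
kp[:, :, 1] = (kp[:, :, 1] - y0) / h          # y -> box-relative [0, 1]
outside = ((kp[:, :, 0] < 0) | (kp[:, :, 0] > 1)) | ((kp[:, :, 1] < 0) | (kp[:, :, 1] > 1))
kp[:, :, 2] = (kp[:, :, 2] > 0) & ~outside    # visibility is dropped for out-of-box points
print(kp)
# tensor([[[0.5000, 0.5000, 1.0000],
#          [1.9000, 0.0500, 0.0000]]])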
-class PersonDetection:
-
-    def __init__(
-            self,
-            segmentation,
-            target_imsize,
-            exp_bbox_cfg, exp_bbox_filter,
-            dilation_percentage: float,
-            orig_imshape_CHW,
-            keypoints=None,
-            **kwargs) -> None:
-        self.segmentation = segmentation
-        self.target_imsize = list(target_imsize)
-        self.pre_processed = False
-        self.exp_bbox_cfg = exp_bbox_cfg
-        self.exp_bbox_filter = exp_bbox_filter
-        self.dilation_percentage = dilation_percentage
-        self.orig_imshape_CHW = orig_imshape_CHW
-        self.keypoints = keypoints
-
-    @torch.no_grad()
-    def pre_process(self):
-        if self.pre_processed:
-            return
-        boxes = masks_to_boxes(self.segmentation).cpu()
-        expanded_boxes = []
-        included_boxes = []
-        for i in range(len(boxes)):
-            exp_box = get_expanded_bbox(
-                boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
-                target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
-            if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
-                continue
-            included_boxes.append(i)
-            expanded_boxes.append(exp_box)
-        expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
-        self.segmentation = self.segmentation[included_boxes]
-        if self.keypoints is not None:
-            self.keypoints = self.keypoints[included_boxes]
-        area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
-        self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
-        for i, box in enumerate(expanded_boxes):
-            self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
-        if self.keypoints is not None:
-            self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)
-        dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
-        self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
-        self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
-
-        [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
-        self.boxes = expanded_boxes
-        self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
-
-        self.pre_processed = True
-        self.n_detections = len(self.boxes)
-        self.mask = self.mask.logical_not()
-
-        sorted_idx = torch.argsort(area, descending=False)
-        self.mask = self.mask[sorted_idx]
-        self.boxes = self.boxes[sorted_idx.cpu()]
-        self.segmentation = self.segmentation[sorted_idx]
-        self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
-        if self.keypoints is not None:
-            self.keypoints = self.keypoints[sorted_idx]
-
-    def get_crop(self, idx: int, im: torch.Tensor):
-        assert idx < len(self)
-        self.pre_process()
-        box = self.boxes[idx]
-        mask = self.mask[idx][None].float()
-        im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
-        batch = dict(
-            img=im, mask=mask, boxes=box.reshape(1, -1),
-            maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
-        if self.keypoints is not None:
-            batch["keypoints"] = self.keypoints[idx:idx+1]
-        return batch
-
-    def __len__(self):
-        self.pre_process()
-        return self.n_detections
-
-    def state_dict(self, **kwargs):
-        return dict(
-            segmentation=self.segmentation.bool(),
-            cls=self.__class__,
-            orig_imshape_CHW=self.orig_imshape_CHW,
-            keypoints=self.keypoints
-        )
-
-    @staticmethod
-    def from_state_dict(
-            state_dict,
-            post_process_cfg, **kwargs):
-        return PersonDetection(
-            state_dict["segmentation"],
-            orig_imshape_CHW=state_dict["orig_imshape_CHW"],
-            **post_process_cfg,
-            keypoints=state_dict["keypoints"])
-
-    def visualize(self, im):
-        self.pre_process()
-        im = im.cpu()
-        if len(self) == 0:
-            return im
-        im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False)
-        im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
-        return im
-
-
-def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
-    """
-    mask: resized mask
-    """
-    assert exp_bbox.shape[0] == mask.shape[0]
-    boxes = masks_to_boxes(mask.squeeze(1)).cpu()
-    H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
-    boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
-    boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
-    boxes[:, [0, 2]] += exp_bbox[:, 0:1]
-    boxes[:, [1, 3]] += exp_bbox[:, 1:2]
-    return boxes
-
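Similarly, the removed get_dilated_boxes helper maps a box measured on the resized mask grid back to original image coordinates: rescale by the expanded crop size, then shift by the crop origin. A small sketch with hypothetical numbers (plain PyTorch, not the dp2 implementation):

import torch

# Hypothetical numbers: an expanded crop box in image coordinates and a box
# found on the resized (128 x 128) mask grid for that crop.
exp_bbox = torch.tensor([[40, 60, 296, 572]])    # x0, y0, x1, y1 -> 256 x 512 crop
mask_hw = (128, 128)                             # (H, W) of the resized mask
box_in_mask = torch.tensor([[32, 16, 96, 112]])  # tight box around the dilated mask

H = exp_bbox[:, 3] - exp_bbox[:, 1]              # crop height (512)
W = exp_bbox[:, 2] - exp_bbox[:, 0]              # crop width (256)

box = box_in_mask.clone()
box[:, [0, 2]] = box[:, [0, 2]] * W[:, None] // mask_hw[1]  # rescale x to crop size
box[:, [1, 3]] = box[:, [1, 3]] * H[:, None] // mask_hw[0]  # rescale y to crop size
box[:, [0, 2]] += exp_bbox[:, 0:1]               # shift back to image coordinates
box[:, [1, 3]] += exp_bbox[:, 1:2]
print(box)  # tensor([[104, 124, 232, 508]])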