haakohu committed
Commit 31c6733 · 1 Parent(s): e24da0e

update demo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitmodules +3 -0
  3. app.py +21 -62
  4. configs/anonymizers/FB_cse.py +0 -28
  5. configs/anonymizers/FB_cse_mask.py +0 -29
  6. configs/anonymizers/FB_cse_mask_face.py +0 -29
  7. configs/anonymizers/face.py +0 -18
  8. configs/anonymizers/market1501/blackout.py +0 -8
  9. configs/anonymizers/market1501/person.py +0 -6
  10. configs/anonymizers/market1501/pixelation16.py +0 -8
  11. configs/anonymizers/market1501/pixelation8.py +0 -8
  12. configs/datasets/coco_cse.py +0 -69
  13. configs/datasets/fdf128.py +0 -24
  14. configs/datasets/fdf256.py +0 -69
  15. configs/datasets/fdh.py +0 -89
  16. configs/datasets/utils.py +0 -12
  17. configs/defaults.py +0 -45
  18. configs/discriminators/sg2_discriminator.py +0 -42
  19. configs/fdf/stylegan.py +0 -14
  20. configs/fdf/stylegan_fdf128.py +0 -13
  21. configs/fdh/styleganL.py +0 -16
  22. configs/fdh/styleganL_nocse.py +0 -14
  23. configs/generators/stylegan_unet.py +0 -22
  24. deep_privacy2 +1 -0
  25. dp2/__init__.py +0 -0
  26. dp2/anonymizer/__init__.py +0 -1
  27. dp2/anonymizer/anonymizer.py +0 -159
  28. dp2/data/__init__.py +0 -0
  29. dp2/data/build.py +0 -148
  30. dp2/data/datasets/__init__.py +0 -0
  31. dp2/data/datasets/coco_cse.py +0 -148
  32. dp2/data/datasets/fdf.py +0 -129
  33. dp2/data/datasets/fdh.py +0 -104
  34. dp2/data/transforms/__init__.py +0 -2
  35. dp2/data/transforms/functional.py +0 -61
  36. dp2/data/transforms/stylegan2_transform.py +0 -394
  37. dp2/data/transforms/transforms.py +0 -247
  38. dp2/data/utils.py +0 -102
  39. dp2/detection/__init__.py +0 -3
  40. dp2/detection/base.py +0 -45
  41. dp2/detection/box_utils.py +0 -104
  42. dp2/detection/box_utils_fdf.py +0 -203
  43. dp2/detection/cse_mask_face_detector.py +0 -116
  44. dp2/detection/face_detector.py +0 -62
  45. dp2/detection/models/__init__.py +0 -0
  46. dp2/detection/models/cse.py +0 -135
  47. dp2/detection/models/keypoint_maskrcnn.py +0 -111
  48. dp2/detection/models/mask_rcnn.py +0 -78
  49. dp2/detection/person_detector.py +0 -135
  50. dp2/detection/structures.py +0 -464
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  erling.jpg filter=lfs diff=lfs merge=lfs -text
  *.jpg filter=lfs diff=lfs merge=lfs -text
+ torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text
+ torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "deep_privacy2"]
+ path = deep_privacy2
+ url = https://github.com/hukkelas/deep_privacy2
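(Note: with the demo code now vendored as a git submodule, a fresh clone of this Space also needs the submodule fetched, typically with `git submodule update --init --recursive`.)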
app.py CHANGED
@@ -1,78 +1,37 @@
+ import gradio
+ import sys
  import os
+ from pathlib import Path
+ from tops.config import instantiate
+ import gradio.inputs
  os.system("pip install --upgrade pip")
  os.system("pip install ftfy regex tqdm")
- os.system("pip install git+https://github.com/openai/CLIP.git")
+ os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
  os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
- os.system("pip install git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
- import gradio
- import numpy as np
- import torch
- from PIL import Image
+ os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
+ sys.path.insert(0, Path(os.getcwd(), "deep_privacy2"))
+ os.environ["TORCH_HOME"] = "torch_home"
  from dp2 import utils
- from tops.config import instantiate
- import tops
- import gradio.inputs
-
-
- cfg_body = utils.load_config("configs/anonymizers/FB_cse.py")
- anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False)
- anonymizer_body.initialize_tracker(fps=1)
- cfg_face = utils.load_config("configs/anonymizers/face.py")
- anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
- anonymizer_face.initialize_tracker(fps=1)
-
-
- class ExampleDemo:
+ from gradio_demos.modules import ExampleDemo, WebcamDemo

-     def __init__(self, anonymizer, multi_modal_truncation=False) -> None:
-         self.multi_modal_truncation = multi_modal_truncation
-         self.anonymizer = anonymizer
-         with gradio.Row():
-             input_image = gradio.Image(type="pil", label="Upload your image or try the example below!")
-             output_image = gradio.Image(type="numpy", label="Output")
-         with gradio.Row():
-             update_btn = gradio.Button("Update Anonymization").style(full_width=True)
-             visualize_det = gradio.Checkbox(value=False, label="Show Detections")
-         visualize_det.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
-         gradio.Examples(
-             ["erling.jpg", "g7-summit-leaders-distraction.jpg"], inputs=[input_image]
-         )
-         update_btn.click(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
-         input_image.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
-         self.track = False
+ cfg_face = utils.load_config("deep_privacy2/configs/anonymizers/face.py")
+ for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
+     if key in cfg_face.anonymizer:
+         cfg_face.anonymizer[key] = Path("deep_privacy2", cfg_face.anonymizer[key])

-     def anonymize(self, img: Image, visualize_detection: bool):
-         img, cache_id = pil2torch(img)
-         img = tops.to_cuda(img)
-         if visualize_detection:
-             img = self.anonymizer.visualize_detection(img, cache_id=cache_id)
-         else:
-             img = self.anonymizer(
-                 img, truncation_value=0 if self.multi_modal_truncation else 1, multi_modal_truncation=self.multi_modal_truncation, amp=True,
-                 cache_id=cache_id, track=self.track)
-         img = utils.im2numpy(img)
-         return img

+ anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)

- def pil2torch(img: Image.Image):
-     img = img.convert("RGB")
-     img = np.array(img)
-     img = np.rollaxis(img, 2)
-     return torch.from_numpy(img), None
+ anonymizer_face.initialize_tracker(fps=1)


  with gradio.Blocks() as demo:
      gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
      gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
-     gradio.Markdown("<center> DeepPrivacy2 is a toolbox for realistic anonymization of humans, including a face and a full-body anonymizer. </center>")
      gradio.Markdown("<center> See more information at: <a href='https://github.com/hukkelas/deep_privacy2'> https://github.com/hukkelas/deep_privacy2 </a> </center>")
+     with gradio.Tab("Face Anonymization"):
+         ExampleDemo(anonymizer_face)
+     with gradio.Tab("Live Webcam"):
+         WebcamDemo(anonymizer_face)

-
-
-     with gradio.Tab("Full-Body Anonymization"):
-         ExampleDemo(anonymizer_body, multi_modal_truncation=True)
-     with gradio.Tab("Face Anonymization"):
-         ExampleDemo(anonymizer_face, multi_modal_truncation=False)
-
-
- demo.launch()
+ demo.launch()

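A short illustrative sketch (not part of the commit) of what the path-patching loop in the new app.py accomplishes; a hypothetical dict stands in for cfg_face.anonymizer:

    from pathlib import Path

    # Hypothetical stand-in for the loaded anonymizer config (keys taken from the diff above).
    anonymizer_cfg = {"face_G_cfg": "configs/fdf/stylegan.py"}
    for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
        if key in anonymizer_cfg:
            # Prefix the relative path so it resolves inside the deep_privacy2 submodule checkout.
            anonymizer_cfg[key] = Path("deep_privacy2", anonymizer_cfg[key])
    print(anonymizer_cfg["face_G_cfg"])  # deep_privacy2/configs/fdf/stylegan.py
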
configs/anonymizers/FB_cse.py DELETED
@@ -1,28 +0,0 @@
- from dp2.anonymizer import Anonymizer
- from dp2.detection.person_detector import CSEPersonDetector
- from ..defaults import common
- from tops.config import LazyCall as L
- from dp2.generator.dummy_generators import MaskOutGenerator
-
-
- maskout_G = L(MaskOutGenerator)(noise="constant")
-
- detector = L(CSEPersonDetector)(
-     mask_rcnn_cfg=dict(),
-     cse_cfg=dict(),
-     cse_post_process_cfg=dict(
-         target_imsize=(288, 160),
-         exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
-         exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
-         iou_combine_threshold=0.4,
-         dilation_percentage=0.02,
-         normalize_embedding=False
-     ),
-     score_threshold=0.3,
-     cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
- )
-
- anonymizer = L(Anonymizer)(
-     detector="${detector}",
-     cse_person_G_cfg="configs/fdh/styleganL.py",
- )

configs/anonymizers/FB_cse_mask.py DELETED
@@ -1,29 +0,0 @@
- from dp2.anonymizer import Anonymizer
- from dp2.detection.person_detector import CSEPersonDetector
- from ..defaults import common
- from tops.config import LazyCall as L
- from dp2.generator.dummy_generators import MaskOutGenerator
-
-
- maskout_G = L(MaskOutGenerator)(noise="constant")
-
- detector = L(CSEPersonDetector)(
-     mask_rcnn_cfg=dict(),
-     cse_cfg=dict(),
-     cse_post_process_cfg=dict(
-         target_imsize=(288, 160),
-         exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
-         exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
-         iou_combine_threshold=0.4,
-         dilation_percentage=0.02,
-         normalize_embedding=False
-     ),
-     score_threshold=0.3,
-     cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
- )
-
- anonymizer = L(Anonymizer)(
-     detector="${detector}",
-     person_G_cfg="configs/fdh/styleganL_nocse.py",
-     cse_person_G_cfg="configs/fdh/styleganL.py",
- )

configs/anonymizers/FB_cse_mask_face.py DELETED
@@ -1,29 +0,0 @@
- from dp2.anonymizer import Anonymizer
- from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
- from ..defaults import common
- from tops.config import LazyCall as L
-
- detector = L(CSeMaskFaceDetector)(
-     mask_rcnn_cfg=dict(),
-     face_detector_cfg=dict(),
-     face_post_process_cfg=dict(target_imsize=(256, 256)),
-     cse_cfg=dict(),
-     cse_post_process_cfg=dict(
-         target_imsize=(288, 160),
-         exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
-         exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
-         iou_combine_threshold=0.4,
-         dilation_percentage=0.02,
-         normalize_embedding=False
-     ),
-     score_threshold=0.3,
-     cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
- )
-
- anonymizer = L(Anonymizer)(
-     detector="${detector}",
-     face_G_cfg="configs/fdf/stylegan.py",
-     person_G_cfg="configs/fdh/styleganL_nocse.py",
-     cse_person_G_cfg="configs/fdh/styleganL.py",
-     car_G_cfg="configs/generators/dummy/pixelation8.py"
- )

configs/anonymizers/face.py DELETED
@@ -1,18 +0,0 @@
- from dp2.anonymizer import Anonymizer
- from dp2.detection.face_detector import FaceDetector
- from ..defaults import common
- from tops.config import LazyCall as L
-
-
- detector = L(FaceDetector)(
-     face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
-     face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
-     score_threshold=0.3,
-     cache_directory=common.output_dir.joinpath("face_detection_cache")
- )
-
-
- anonymizer = L(Anonymizer)(
-     detector="${detector}",
-     face_G_cfg="configs/fdf/stylegan.py",
- )

configs/anonymizers/market1501/blackout.py DELETED
@@ -1,8 +0,0 @@
- from ..FB_cse_mask_face import anonymizer, detector, common
-
- detector.score_threshold = .1
- detector.face_detector_cfg.confidence_threshold = .5
- detector.cse_cfg.score_thres = 0.3
- anonymizer.generators.face_G_cfg = None
- anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
- anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"

configs/anonymizers/market1501/person.py DELETED
@@ -1,6 +0,0 @@
- from ..FB_cse_mask_face import anonymizer, detector, common
-
- detector.score_threshold = .1
- detector.face_detector_cfg.confidence_threshold = .5
- detector.cse_cfg.score_thres = 0.3
- anonymizer.generators.face_G_cfg = None

configs/anonymizers/market1501/pixelation16.py DELETED
@@ -1,8 +0,0 @@
- from ..FB_cse_mask_face import anonymizer, detector, common
-
- detector.score_threshold = .1
- detector.face_detector_cfg.confidence_threshold = .5
- detector.cse_cfg.score_thres = 0.3
- anonymizer.generators.face_G_cfg = None
- anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
- anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"

configs/anonymizers/market1501/pixelation8.py DELETED
@@ -1,8 +0,0 @@
- from ..FB_cse_mask_face import anonymizer, detector, common
-
- detector.score_threshold = .1
- detector.face_detector_cfg.confidence_threshold = .5
- detector.cse_cfg.score_thres = 0.3
- anonymizer.generators.face_G_cfg = None
- anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
- anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"

configs/datasets/coco_cse.py DELETED
@@ -1,69 +0,0 @@
- import os
- from pathlib import Path
- from tops.config import LazyCall as L
- import torch
- import functools
- from dp2.data.datasets import CocoCSE
- from dp2.data.build import get_dataloader
- from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
- from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
- from dp2.metrics.torch_metrics import compute_metrics_iteratively
- from .utils import final_eval_fn
-
-
- dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
- metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
- data_dir = Path(dataset_base_dir, "coco_cse")
- data = dict(
-     imsize=(288, 160),
-     im_channels=3,
-     semantic_nc=26,
-     cse_nc=16,
-     train=dict(
-         dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
-         loader=L(get_dataloader)(
-             shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
-             batch_size="${train.batch_size}",
-             dataset="${..dataset}",
-             infinite=True,
-             gpu_transform=L(torch.nn.Sequential)(*[
-                 L(ToFloat)(),
-                 L(StyleGANAugmentPipe)(
-                     rotate=0.5, rotate_max=.05,
-                     xint=.5, xint_max=0.05,
-                     scale=.5, scale_std=.05,
-                     aniso=0.5, aniso_std=.05,
-                     xfrac=.5, xfrac_std=.05,
-                     brightness=.5, brightness_std=.05,
-                     contrast=.5, contrast_std=.1,
-                     hue=.5, hue_max=.05,
-                     saturation=.5, saturation_std=.5,
-                     imgfilter=.5, imgfilter_std=.1),
-                 L(RandomHorizontalFlip)(p=0.5),
-                 L(CreateEmbedding)(),
-                 L(Resize)(size="${data.imsize}"),
-                 L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                 L(CreateCondition)(),
-             ])
-         )
-     ),
-     val=dict(
-         dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
-         loader=L(get_dataloader)(
-             shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
-             batch_size="${train.batch_size}",
-             dataset="${..dataset}",
-             infinite=False,
-             gpu_transform=L(torch.nn.Sequential)(*[
-                 L(ToFloat)(),
-                 L(CreateEmbedding)(),
-                 L(Resize)(size="${data.imsize}"),
-                 L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                 L(CreateCondition)(),
-             ])
-         )
-     ),
-     # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
-     train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
-     evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
- )

configs/datasets/fdf128.py DELETED
@@ -1,24 +0,0 @@
- from pathlib import Path
- from functools import partial
- from dp2.data.datasets.fdf import FDFDataset
- from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn
-
- data_dir = Path(dataset_base_dir, "fdf")
- data.train.dataset.dirpath = data_dir.joinpath("train")
- data.val.dataset.dirpath = data_dir.joinpath("val")
- data.imsize = (128, 128)
-
-
- data.train_evaluation_fn = partial(
-     final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
- data.evaluation_fn = partial(
-     final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
-
- data.train.dataset.update(
-     _target_ = FDFDataset,
-     imsize="${data.imsize}"
- )
- data.val.dataset.update(
-     _target_ = FDFDataset,
-     imsize="${data.imsize}"
- )

configs/datasets/fdf256.py DELETED
@@ -1,69 +0,0 @@
- import os
- from pathlib import Path
- from tops.config import LazyCall as L
- import torch
- import functools
- from dp2.data.datasets.fdf import FDF256Dataset
- from dp2.data.build import get_dataloader
- from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
- from dp2.metrics.torch_metrics import compute_metrics_iteratively
- from dp2.metrics.fid_clip import compute_fid_clip
- from dp2.metrics.ppl import calculate_ppl
- from .utils import final_eval_fn
-
-
- def final_eval_fn(*args, **kwargs):
-     result = compute_metrics_iteratively(*args, **kwargs)
-     result2 = compute_fid_clip(*args, **kwargs)
-     assert all(key not in result for key in result2)
-     result.update(result2)
-     result3 = calculate_ppl(*args, **kwargs,)
-     assert all(key not in result for key in result3)
-     result.update(result3)
-     return result
-
-
- dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
- metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
- data_dir = Path(dataset_base_dir, "fdf256")
- data = dict(
-     imsize=(256, 256),
-     im_channels=3,
-     semantic_nc=None,
-     cse_nc=None,
-     n_keypoints=None,
-     train=dict(
-         dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
-         loader=L(get_dataloader)(
-             shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
-             batch_size="${train.batch_size}",
-             dataset="${..dataset}",
-             infinite=True,
-             gpu_transform=L(torch.nn.Sequential)(*[
-                 L(ToFloat)(),
-                 L(RandomHorizontalFlip)(p=0.5),
-                 L(Resize)(size="${data.imsize}"),
-                 L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                 L(CreateCondition)(),
-             ])
-         )
-     ),
-     val=dict(
-         dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
-         loader=L(get_dataloader)(
-             shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
-             batch_size="${train.batch_size}",
-             dataset="${..dataset}",
-             infinite=False,
-             gpu_transform=L(torch.nn.Sequential)(*[
-                 L(ToFloat)(),
-                 L(Resize)(size="${data.imsize}"),
-                 L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
-                 L(CreateCondition)(),
-             ])
-         )
-     ),
-     # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
-     train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "fdf_val_train")),
-     evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
- )

configs/datasets/fdh.py DELETED
@@ -1,89 +0,0 @@
- import os
- from pathlib import Path
- from tops.config import LazyCall as L
- import torch
- import functools
- from dp2.data.datasets.fdh import get_dataloader_fdh_wds
- from dp2.data.utils import get_coco_flipmap
- from dp2.data.transforms.transforms import (
-     Normalize,
-     ToFloat,
-     CreateCondition,
-     RandomHorizontalFlip,
-     CreateEmbedding,
- )
- from dp2.metrics.torch_metrics import compute_metrics_iteratively
- from dp2.metrics.fid_clip import compute_fid_clip
- from .utils import final_eval_fn
-
-
- def train_eval_fn(*args, **kwargs):
-     result = compute_metrics_iteratively(*args, **kwargs)
-     result2 = compute_fid_clip(*args, **kwargs)
-     assert all(key not in result for key in result2)
-     result.update(result2)
-     return result
-
-
- dataset_base_dir = (
-     os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
- )
- metrics_cache = (
-     os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
- )
- data_dir = Path(dataset_base_dir, "fdh")
- data = dict(
-     imsize=(288, 160),
-     im_channels=3,
-     cse_nc=16,
-     n_keypoints=17,
-     train=dict(
-         loader=L(get_dataloader_fdh_wds)(
-             path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
-             batch_size="${train.batch_size}",
-             num_workers=6,
-             transform=L(torch.nn.Sequential)(
-                 L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
-             ),
-             gpu_transform=L(torch.nn.Sequential)(
-                 L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
-                 L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
-                 L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
-                 L(CreateCondition)(),
-             ),
-             infinite=True,
-             shuffle=True,
-             partial_batches=False,
-             load_embedding=True,
-         )
-     ),
-     val=dict(
-         loader=L(get_dataloader_fdh_wds)(
-             path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
-             batch_size="${train.batch_size}",
-             num_workers=6,
-             transform=None,
-             gpu_transform=L(torch.nn.Sequential)(
-                 L(ToFloat)(keys=["img", "mask", "E_mask", "maskrcnn_mask"], norm=False),
-                 L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
-                 L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
-                 L(CreateCondition)(),
-             ),
-             infinite=False,
-             shuffle=False,
-             partial_batches=True,
-             load_embedding=True,
-         )
-     ),
-     # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
-     train_evaluation_fn=functools.partial(
-         train_eval_fn,
-         cache_directory=Path(metrics_cache, "fdh_v7_train"),
-         data_len=int(30e3),
-     ),
-     evaluation_fn=functools.partial(
-         final_eval_fn,
-         cache_directory=Path(metrics_cache, "fdh_v6_val"),
-         data_len=int(30e3),
-     ),
- )

configs/datasets/utils.py DELETED
@@ -1,12 +0,0 @@
- from dp2.metrics.ppl import calculate_ppl
- from dp2.metrics.torch_metrics import compute_metrics_iteratively
- from dp2.metrics.fid_clip import compute_fid_clip
-
-
- def final_eval_fn(*args, **kwargs):
-     result = compute_metrics_iteratively(*args, **kwargs)
-     result2 = calculate_ppl(*args, **kwargs,)
-     result2 = compute_fid_clip(*args, **kwargs)
-     assert all(key not in result for key in result2)
-     result.update(result2)
-     return result

configs/defaults.py DELETED
@@ -1,45 +0,0 @@
- import pathlib
- import os
- import torch
- from tops.config import LazyCall as L
-
- if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
-     PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
- else:
-     PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
- if "BASE_OUTPUT_DIR" in os.environ:
-     BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
- else:
-     BASE_OUTPUT_DIR = pathlib.Path("outputs")
-
-
-
- common = dict(
-     logger_backend=["wandb", "stdout", "json", "image_dumper"],
-     wandb_project="fba_test",
-     output_dir=BASE_OUTPUT_DIR,
-     experiment_name=None, # Optional experiment name to show on wandb
- )
-
- train = dict(
-     batch_size=32,
-     seed=0,
-     ims_per_log=1024,
-     ims_per_val=int(200e3),
-     max_images_to_train=int(12e6),
-     amp=dict(
-         enabled=True,
-         scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
-         scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
-     ),
-     fp16_ddp_accumulate=False, # All gather gradients in fp16?
-     broadcast_buffers=False,
-     bias_act_plugin_enabled=True,
-     grid_sample_gradfix_enabled=True,
-     conv2d_gradfix_enabled=False,
-     channels_last=False,
- )
-
- # exponential moving average
- EMA = dict(rampup=0.05)
-

configs/discriminators/sg2_discriminator.py DELETED
@@ -1,42 +0,0 @@
- from tops.config import LazyCall as L
- from dp2.discriminator import SG2Discriminator
- import torch
- from dp2.loss import StyleGAN2Loss
-
-
- discriminator = L(SG2Discriminator)(
-     imsize="${data.imsize}",
-     im_channels="${data.im_channels}",
-     min_fmap_resolution=4,
-     max_cnum_mul=8,
-     cnum=80,
-     input_condition=True,
-     conv_clamp=256,
-     input_cse=False,
-     cse_nc="${data.cse_nc}"
- )
-
-
- loss_fnc = L(StyleGAN2Loss)(
-     lazy_regularization=True,
-     lazy_reg_interval=16,
-     r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
-     EP_lambd=0.001,
-     pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01)
- )
-
- def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
-     if lazy_regularization:
-         # From Analyzing and improving the image quality of stylegan, CVPR 2020
-         c = lazy_reg_interval / (lazy_reg_interval + 1)
-         betas = [beta ** c for beta in betas]
-         lr *= c
-         print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
-     return type(lr=lr, betas=betas, **kwargs)
-
-
- D_optim = L(build_D_optim)(
-     type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
-     lazy_regularization="${loss_fnc.lazy_regularization}",
-     lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
- G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))

configs/fdf/stylegan.py DELETED
@@ -1,14 +0,0 @@
- from ..generators.stylegan_unet import generator
- from ..datasets.fdf256 import data
- from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
- from ..defaults import train, common, EMA
-
- train.max_images_to_train = int(35e6)
- G_optim.lr = 0.002
- D_optim.lr = 0.002
- generator.input_cse = False
- loss_fnc.r1_opts.lambd = 1
- train.ims_per_val = int(2e6)
-
- common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
- common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"

configs/fdf/stylegan_fdf128.py DELETED
@@ -1,13 +0,0 @@
- from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
- from ..datasets.fdf128 import data
- from ..generators.stylegan_unet import generator
- from ..defaults import train, common, EMA
- from tops.config import LazyCall as L
-
- train.max_images_to_train = int(25e6)
- G_optim.lr = 0.002
- D_optim.lr = 0.002
- generator.cnum = 128
- generator.max_cnum_mul = 4
- generator.input_cse = False
- loss_fnc.r1_opts.lambd = .1

configs/fdh/styleganL.py DELETED
@@ -1,16 +0,0 @@
- from tops.config import LazyCall as L
- from ..generators.stylegan_unet import generator
- from ..datasets.fdh import data
- from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
- from ..defaults import train, common, EMA
-
- train.max_images_to_train = int(50e6)
- train.batch_size = 64
- G_optim.lr = 0.002
- D_optim.lr = 0.002
- data.train.loader.num_workers = 4
- train.ims_per_val = int(1e6)
- loss_fnc.r1_opts.lambd = .1
-
- common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
- common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"

configs/fdh/styleganL_nocse.py DELETED
@@ -1,14 +0,0 @@
- from tops.config import LazyCall as L
- from ..generators.stylegan_unet import generator
- from ..datasets.fdh import data
- from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
- from ..defaults import train, common, EMA
-
- train.max_images_to_train = int(50e6)
- G_optim.lr = 0.002
- D_optim.lr = 0.002
- generator.input_cse = False
- data.load_embeddings = False
- common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
- common.model_md5sum = "fda0d809741bc67487abada793975c37"
- generator.fix_errors = False

configs/generators/stylegan_unet.py DELETED
@@ -1,22 +0,0 @@
- from dp2.generator.stylegan_unet import StyleGANUnet
- from tops.config import LazyCall as L
-
- generator = L(StyleGANUnet)(
-     imsize="${data.imsize}",
-     im_channels="${data.im_channels}",
-     min_fmap_resolution=8,
-     cnum=64,
-     max_cnum_mul=8,
-     n_middle_blocks=0,
-     z_channels=512,
-     mask_output=True,
-     conv_clamp=256,
-     input_cse=True,
-     scale_grad=True,
-     cse_nc="${data.cse_nc}",
-     w_dim=512,
-     n_keypoints="${data.n_keypoints}",
-     input_keypoints=False,
-     input_keypoint_indices=[],
-     fix_errors=True
- )

deep_privacy2 ADDED
@@ -0,0 +1 @@
+ Subproject commit 37dcbeb23a1f51121d53bcd80d32d086d6822b7b
dp2/__init__.py DELETED
File without changes
dp2/anonymizer/__init__.py DELETED
@@ -1 +0,0 @@
- from .anonymizer import Anonymizer

dp2/anonymizer/anonymizer.py DELETED
@@ -1,159 +0,0 @@
- from pathlib import Path
- from typing import Union, Optional
- import numpy as np
- import torch
- import tops
- import torchvision.transforms.functional as F
- from motpy import Detection, MultiObjectTracker
- from dp2.utils import load_config
- from dp2.infer import build_trained_generator
- from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
-
-
- def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
-     cfg = load_config(cfg_path)
-     G = build_trained_generator(cfg)
-     tops.logger.log(f"Loaded generator from: {cfg_path}")
-     return G
-
-
- def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs):
-     img = F.resize(img, imsize, antialias=True)
-     mask = (F.resize(mask, imsize, antialias=True) > 0.99).float()
-     maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float()
-
-     condition = img * mask
-     return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition)
-
-
- class Anonymizer:
-
-     def __init__(
-             self,
-             detector,
-             load_cache: bool,
-             person_G_cfg: Optional[Union[str, Path]] = None,
-             cse_person_G_cfg: Optional[Union[str, Path]] = None,
-             face_G_cfg: Optional[Union[str, Path]] = None,
-             car_G_cfg: Optional[Union[str, Path]] = None,
-     ) -> None:
-         self.detector = detector
-         self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
-         self.load_cache = load_cache
-         if cse_person_G_cfg is not None:
-             self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
-         if person_G_cfg is not None:
-             self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
-         if face_G_cfg is not None:
-             self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
-         if car_G_cfg is not None:
-             self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
-
-     def initialize_tracker(self, fps: float):
-         self.tracker = MultiObjectTracker(dt=1/fps)
-         self.track_to_z_idx = dict()
-         self.cur_z_idx = 0
-
-     @torch.no_grad()
-     def anonymize_detections(self,
-             im, detection, truncation_value: float,
-             multi_modal_truncation: bool, amp: bool, z_idx,
-             all_styles=None,
-             update_identity=None,
-     ):
-         G = self.generators[type(detection)]
-         if G is None:
-             return im
-         C, H, W = im.shape
-         orig_im = im.clone()
-         if update_identity is None:
-             update_identity = [True for i in range(len(detection))]
-         for idx in range(len(detection)):
-             if not update_identity[idx]:
-                 continue
-             batch = detection.get_crop(idx, im)
-             x0, y0, x1, y1 = batch.pop("boxes")[0]
-             batch = {k: tops.to_cuda(v) for k, v in batch.items()}
-             batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
-             batch["img"] = batch["img"].float()
-             batch["condition"] = batch["mask"] * batch["img"]
-             orig_shape = None
-             if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]:
-                 orig_shape = batch["img"].shape[-2:]
-                 batch = resize_batch(**batch, imsize=G.imsize)
-             with torch.cuda.amp.autocast(amp):
-                 if all_styles is not None:
-                     anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
-                 elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"):
-                     w_indices = None
-                     if z_idx is not None:
-                         w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
-                     anonymized_im = G.multi_modal_truncate(
-                         **batch, truncation_value=truncation_value,
-                         w_indices=w_indices)["img"]
-                 else:
-                     z = None
-                     if z_idx is not None:
-                         state = np.random.RandomState(seed=z_idx[idx])
-                         z = state.normal(size=(1, G.z_channels))
-                         z = tops.to_cuda(torch.from_numpy(z))
-                     anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
-             if orig_shape is not None:
-                 anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True)
-             anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte()
-
-             # Resize and denormalize image
-             gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True)
-             mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
-             # Remove padding
-             pad = [max(-x0,0), max(-y0,0)]
-             pad = [*pad, max(x1-W,0), max(y1-H,0)]
-             remove_pad = lambda x: x[...,pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
-             gim = remove_pad(gim)
-             mask = remove_pad(mask)
-             x0, y0 = max(x0, 0), max(y0, 0)
-             x1, y1 = min(x1, W), min(y1, H)
-             mask = mask.logical_not()[None].repeat(3, 1, 1)
-             im[:, y0:y1, x0:x1][mask] = gim[mask]
-
-         return im
-
-     def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
-         all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
-         for det in all_detections:
-             im = det.visualize(im)
-         return im
-
-     @torch.no_grad()
-     def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor:
-         assert im.dtype == torch.uint8
-         im = tops.to_cuda(im)
-         all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
-         if hasattr(self, "tracker") and track:
-             [_.pre_process() for _ in all_detections]
-             import numpy as np
-             boxes = np.concatenate([_.boxes for _ in all_detections])
-             boxes = [Detection(box) for box in boxes]
-             self.tracker.step(boxes)
-             track_ids = self.tracker.detections_matched_ids
-             z_idx = []
-             for track_id in track_ids:
-                 if track_id not in self.track_to_z_idx:
-                     self.track_to_z_idx[track_id] = self.cur_z_idx
-                     self.cur_z_idx += 1
-                 z_idx.append(self.track_to_z_idx[track_id])
-             z_idx = np.array(z_idx)
-             idx_offset = 0
-
-         for detection in all_detections:
-             zs = None
-             if hasattr(self, "tracker") and track:
-                 zs = z_idx[idx_offset:idx_offset+len(detection)]
-                 idx_offset += len(detection)
-             im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
-
-         return im.cpu()
-
-     def __call__(self, *args, **kwargs):
-         return self.forward(*args, **kwargs)
-

dp2/data/__init__.py DELETED
File without changes
dp2/data/build.py DELETED
@@ -1,148 +0,0 @@
- import io
- import torch
- import tops
- from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder
-
- def get_dataloader(
-         dataset, gpu_transform: torch.nn.Module,
-         num_workers,
-         batch_size,
-         infinite: bool,
-         drop_last: bool,
-         prefetch_factor: int,
-         shuffle,
-         channels_last=False
- ):
-     sampler = None
-     dl_kwargs = dict(
-         pin_memory=True,
-     )
-     if infinite:
-         sampler = tops.InfiniteSampler(
-             dataset, rank=tops.rank(),
-             num_replicas=tops.world_size(),
-             shuffle=shuffle
-         )
-     elif tops.world_size() > 1:
-         sampler = torch.utils.data.DistributedSampler(
-             dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
-         dl_kwargs["drop_last"] = drop_last
-     else:
-         dl_kwargs["shuffle"] = shuffle
-         dl_kwargs["drop_last"] = drop_last
-     dataloader = torch.utils.data.DataLoader(
-         dataset, sampler=sampler, collate_fn=collate_fn,
-         batch_size=batch_size,
-         num_workers=num_workers, prefetch_factor=prefetch_factor,
-         **dl_kwargs
-     )
-     dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
-     return dataloader
-
-
- def get_dataloader_places2_wds(
-         path,
-         batch_size: int,
-         num_workers: int,
-         transform: torch.nn.Module,
-         gpu_transform: torch.nn.Module,
-         infinite: bool,
-         shuffle: bool,
-         partial_batches: bool,
-         sample_shuffle=10_000,
-         tar_shuffle=100,
-         channels_last=False,
- ):
-     import webdataset as wds
-     import os
-     os.environ["RANK"] = str(tops.rank())
-     os.environ["WORLD_SIZE"] = str(tops.world_size())
-
-     if infinite:
-         pipeline = [wds.ResampledShards(str(path))]
-     else:
-         pipeline = [wds.SimpleShardList(str(path))]
-     if shuffle:
-         pipeline.append(wds.shuffle(tar_shuffle))
-     pipeline.extend([
-         wds.split_by_node,
-         wds.split_by_worker,
-     ])
-     if shuffle:
-         pipeline.append(wds.shuffle(sample_shuffle))
-
-     pipeline.extend([
-         wds.tarfile_to_samples(),
-         wds.decode("torchrgb8"),
-         wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
-     ])
-     if transform is not None:
-         pipeline.append(wds.map(transform))
-     pipeline.extend([
-         wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
-     ])
-     pipeline = wds.DataPipeline(*pipeline)
-     if infinite:
-         pipeline = pipeline.repeat(nepochs=1000000)
-     loader = wds.WebLoader(
-         pipeline, batch_size=None, shuffle=False,
-         num_workers=get_num_workers(num_workers),
-         persistent_workers=True,
-     )
-     loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
-     return loader
-
-
-
-
- def get_dataloader_celebAHQ_wds(
-         path,
-         batch_size: int,
-         num_workers: int,
-         transform: torch.nn.Module,
-         gpu_transform: torch.nn.Module,
-         infinite: bool,
-         shuffle: bool,
-         partial_batches: bool,
-         sample_shuffle=10_000,
-         tar_shuffle=100,
-         channels_last=False,
- ):
-     import webdataset as wds
-     import os
-     os.environ["RANK"] = str(tops.rank())
-     os.environ["WORLD_SIZE"] = str(tops.world_size())
-
-     if infinite:
-         pipeline = [wds.ResampledShards(str(path))]
-     else:
-         pipeline = [wds.SimpleShardList(str(path))]
-     if shuffle:
-         pipeline.append(wds.shuffle(tar_shuffle))
-     pipeline.extend([
-         wds.split_by_node,
-         wds.split_by_worker,
-     ])
-     if shuffle:
-         pipeline.append(wds.shuffle(sample_shuffle))
-
-     pipeline.extend([
-         wds.tarfile_to_samples(),
-         wds.decode(wds.handle_extension(".png", png_decoder)),
-         wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
-     ])
-     if transform is not None:
-         pipeline.append(wds.map(transform))
-     pipeline.extend([
-         wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
-     ])
-     pipeline = wds.DataPipeline(*pipeline)
-     if infinite:
-         pipeline = pipeline.repeat(nepochs=1000000)
-     loader = wds.WebLoader(
-         pipeline, batch_size=None, shuffle=False,
-         num_workers=get_num_workers(num_workers),
-         persistent_workers=True,
-     )
-     loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
-     return loader

dp2/data/datasets/__init__.py DELETED
File without changes
dp2/data/datasets/coco_cse.py DELETED
@@ -1,148 +0,0 @@
- import pickle
- import torchvision
- import torch
- import pathlib
- import numpy as np
- from typing import Callable, Optional, Union
- from torch.hub import get_dir as get_hub_dir
-
-
- def cache_embed_stats(embed_map: torch.Tensor):
-     mean = embed_map.mean(dim=0, keepdim=True)
-     rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
-
-     cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
-     path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
-     path.parent.mkdir(exist_ok=True, parents=True)
-     torch.save(cache, path)
-
-
- class CocoCSE(torch.utils.data.Dataset):
-
-     def __init__(self,
-             dirpath: Union[str, pathlib.Path],
-             transform: Optional[Callable],
-             normalize_E: bool,):
-         dirpath = pathlib.Path(dirpath)
-         self.dirpath = dirpath
-
-         self.transform = transform
-         assert self.dirpath.is_dir(),\
-             f"Did not find dataset at: {dirpath}"
-         self.image_paths, self.embedding_paths = self._load_impaths()
-         self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
-         mean = self.embed_map.mean(dim=0, keepdim=True)
-         rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
-         self.embed_map = (self.embed_map - mean) * rstd
-         cache_embed_stats(self.embed_map)
-
-     def _load_impaths(self):
-         image_dir = self.dirpath.joinpath("images")
-         image_paths = list(image_dir.glob("*.png"))
-         image_paths.sort()
-         embedding_paths = [
-             self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
-         ]
-         return image_paths, embedding_paths
-
-     def __len__(self):
-         return len(self.image_paths)
-
-     def __getitem__(self, idx):
-         im = torchvision.io.read_image(str(self.image_paths[idx]))
-         vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
-         vertices = torch.from_numpy(vertices.squeeze()).long()
-         mask = torch.from_numpy(mask.squeeze()).float()
-         border = torch.from_numpy(border.squeeze()).float()
-         E_mask = 1 - mask - border
-         batch = {
-             "img": im,
-             "vertices": vertices[None],
-             "mask": mask[None],
-             "embed_map": self.embed_map,
-             "border": border[None],
-             "E_mask": E_mask[None]
-         }
-         if self.transform is None:
-             return batch
-         return self.transform(batch)
-
-
- class CocoCSEWithFace(CocoCSE):
-
-     def __init__(self,
-             dirpath: Union[str, pathlib.Path],
-             transform: Optional[Callable],
-             **kwargs):
-         super().__init__(dirpath, transform, **kwargs)
-         with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
-             self.face_boxes = pickle.load(fp)
-
-     def __getitem__(self, idx):
-         item = super().__getitem__(idx)
-         item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
-         return item
-
-
- class CocoCSESemantic(torch.utils.data.Dataset):
-
-     def __init__(self,
-             dirpath: Union[str, pathlib.Path],
-             transform: Optional[Callable],
-             **kwargs):
-         dirpath = pathlib.Path(dirpath)
-         self.dirpath = dirpath
-
-         self.transform = transform
-         assert self.dirpath.is_dir(),\
-             f"Did not find dataset at: {dirpath}"
-         self.image_paths, self.embedding_paths = self._load_impaths()
-         self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy")))
-         self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
-
-     def _load_impaths(self):
-         image_dir = self.dirpath.joinpath("images")
-         image_paths = list(image_dir.glob("*.png"))
-         image_paths.sort()
-         embedding_paths = [
-             self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
-         ]
-         return image_paths, embedding_paths
-
-     def __len__(self):
-         return len(self.image_paths)
-
-     def __getitem__(self, idx):
-         im = torchvision.io.read_image(str(self.image_paths[idx]))
-         vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
-         vertices = torch.from_numpy(vertices.squeeze()).long()
-         mask = torch.from_numpy(mask.squeeze()).float()
-         border = torch.from_numpy(border.squeeze()).float()
-         E_mask = 1 - mask - border
-         batch = {
-             "img": im,
-             "vertices": vertices[None],
-             "mask": mask[None],
-             "border": border[None],
-             "vertx2cat": self.vertx2cat,
-             "embed_map": self.embed_map,
-         }
-         if self.transform is None:
-             return batch
-         return self.transform(batch)
-
-
- class CocoCSESemanticWithFace(CocoCSESemantic):
-
-     def __init__(self,
-             dirpath: Union[str, pathlib.Path],
-             transform: Optional[Callable],
-             **kwargs):
-         super().__init__(dirpath, transform, **kwargs)
-         with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
-             self.face_boxes = pickle.load(fp)
-
-     def __getitem__(self, idx):
-         item = super().__getitem__(idx)
-         item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
-         return item

dp2/data/datasets/fdf.py DELETED
@@ -1,129 +0,0 @@
- import pathlib
- from typing import Tuple
- import numpy as np
- import torch
- import pathlib
- try:
-     import pyspng
-     PYSPNG_IMPORTED = True
- except ImportError:
-     PYSPNG_IMPORTED = False
-     print("Could not load pyspng. Defaulting to pillow image backend.")
-     from PIL import Image
- from tops import logger
-
-
- class FDFDataset:
-
-     def __init__(self,
-             dirpath,
-             imsize: Tuple[int],
-             load_keypoints: bool,
-             transform):
-         dirpath = pathlib.Path(dirpath)
-         self.dirpath = dirpath
-         self.transform = transform
-         self.imsize = imsize[0]
-         self.load_keypoints = load_keypoints
-         assert self.dirpath.is_dir(),\
-             f"Did not find dataset at: {dirpath}"
-         image_dir = self.dirpath.joinpath("images", str(self.imsize))
-         self.image_paths = list(image_dir.glob("*.png"))
-         assert len(self.image_paths) > 0,\
-             f"Did not find images in: {image_dir}"
-         self.image_paths.sort(key=lambda x: int(x.stem))
-         self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
-
-         self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
-         assert len(self.image_paths) == len(self.bounding_boxes)
-         assert len(self.image_paths) == len(self.landmarks)
-         logger.log(
-             f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")
-
-     def get_mask(self, idx):
-         mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
-         bounding_box = self.bounding_boxes[idx]
-         x0, y0, x1, y1 = bounding_box
-         mask[:, y0:y1, x0:x1] = 0
-         return mask
-
-     def __len__(self):
-         return len(self.image_paths)
-
-     def __getitem__(self, index):
-         impath = self.image_paths[index]
-         if PYSPNG_IMPORTED:
-             with open(impath, "rb") as fp:
-                 im = pyspng.load(fp.read())
-         else:
-             with Image.open(impath) as fp:
-                 im = np.array(fp)
-         im = torch.from_numpy(np.rollaxis(im, -1, 0))
-         masks = self.get_mask(index)
-         landmark = self.landmarks[index]
-         batch = {
-             "img": im,
-             "mask": masks,
-         }
-         if self.load_keypoints:
-             batch["keypoints"] = landmark
-         if self.transform is None:
-             return batch
-         return self.transform(batch)
-
-
- class FDF256Dataset:
-
-     def __init__(self,
-             dirpath,
-             load_keypoints: bool,
-             transform):
-         dirpath = pathlib.Path(dirpath)
-         self.dirpath = dirpath
-         self.transform = transform
-         self.load_keypoints = load_keypoints
-         assert self.dirpath.is_dir(),\
-             f"Did not find dataset at: {dirpath}"
-         image_dir = self.dirpath.joinpath("images")
-         self.image_paths = list(image_dir.glob("*.png"))
-         assert len(self.image_paths) > 0,\
-             f"Did not find images in: {image_dir}"
-         self.image_paths.sort(key=lambda x: int(x.stem))
-         self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
-         self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
-         assert len(self.image_paths) == len(self.bounding_boxes)
-         assert len(self.image_paths) == len(self.landmarks)
-         logger.log(
-             f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")
-
-     def get_mask(self, idx):
-         mask = torch.ones((1, 256, 256), dtype=torch.bool)
-         bounding_box = self.bounding_boxes[idx]
-         x0, y0, x1, y1 = bounding_box
-         mask[:, y0:y1, x0:x1] = 0
-         return mask
-
-     def __len__(self):
-         return len(self.image_paths)
-
-     def __getitem__(self, index):
-         impath = self.image_paths[index]
-         if PYSPNG_IMPORTED:
-             with open(impath, "rb") as fp:
-                 im = pyspng.load(fp.read())
-         else:
-             with Image.open(impath) as fp:
-                 im = np.array(fp)
-         im = torch.from_numpy(np.rollaxis(im, -1, 0))
-         masks = self.get_mask(index)
-         landmark = self.landmarks[index]
-         batch = {
-             "img": im,
-             "mask": masks,
-         }
-         if self.load_keypoints:
-             batch["keypoints"] = landmark
-         if self.transform is None:
-             return batch
-         return self.transform(batch)
-

dp2/data/datasets/fdh.py DELETED
@@ -1,104 +0,0 @@
1
- import torch
2
- import tops
3
- import numpy as np
4
- import io
5
- import webdataset as wds
6
- import os
7
- from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
8
-
9
-
10
- def kp_decoder(x):
11
- # Keypoints are between [0, 1] for webdataset
12
- keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
13
- keypoints[:, 0] /= 160
14
- keypoints[:, 1] /= 288
15
- check_outside = lambda x: (x < 0).logical_or(x > 1)
16
- is_outside = check_outside(keypoints[:, 0]).logical_or(
17
- check_outside(keypoints[:, 1])
18
- )
19
- keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
20
- return keypoints
21
-
22
-
23
- def vertices_decoder(x):
24
- vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
25
- return vertices.squeeze()[None]
26
-
27
-
28
- def get_dataloader_fdh_wds(
29
- path,
30
- batch_size: int,
31
- num_workers: int,
32
- transform: torch.nn.Module,
33
- gpu_transform: torch.nn.Module,
34
- infinite: bool,
35
- shuffle: bool,
36
- partial_batches: bool,
37
- load_embedding: bool,
38
- sample_shuffle=10_000,
39
- tar_shuffle=100,
40
- read_condition=False,
41
- channels_last=False,
42
- ):
43
- # Need to set this for split_by_node to work.
44
- os.environ["RANK"] = str(tops.rank())
45
- os.environ["WORLD_SIZE"] = str(tops.world_size())
46
- if infinite:
47
- pipeline = [wds.ResampledShards(str(path))]
48
- else:
49
- pipeline = [wds.SimpleShardList(str(path))]
50
- if shuffle:
51
- pipeline.append(wds.shuffle(tar_shuffle))
52
- pipeline.extend([
53
- wds.split_by_node,
54
- wds.split_by_worker,
55
- ])
56
- if shuffle:
57
- pipeline.append(wds.shuffle(sample_shuffle))
58
-
59
- decoder = [
60
- wds.handle_extension("image.png", png_decoder),
61
- wds.handle_extension("mask.png", mask_decoder),
62
- wds.handle_extension("maskrcnn_mask.png", mask_decoder),
63
- wds.handle_extension("keypoints.npy", kp_decoder),
64
- ]
65
-
66
- rename_keys = [
67
- ["img", "image.png"], ["mask", "mask.png"],
68
- ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"]
69
- ]
70
- if load_embedding:
71
- decoder.extend([
72
- wds.handle_extension("vertices.npy", vertices_decoder),
73
- wds.handle_extension("E_mask.png", mask_decoder)
74
- ])
75
- rename_keys.extend([
76
- ["vertices", "vertices.npy"],
77
- ["E_mask", "e_mask.png"]
78
- ])
79
-
80
- if read_condition:
81
- decoder.append(
82
- wds.handle_extension("condition.png", png_decoder)
83
- )
84
- rename_keys.append(["condition", "condition.png"])
85
-
86
- pipeline.extend([
87
- wds.tarfile_to_samples(),
88
- wds.decode(*decoder),
89
- wds.rename_keys(*rename_keys),
90
- wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
91
- ])
92
- if transform is not None:
93
- pipeline.append(wds.map(transform))
94
- pipeline = wds.DataPipeline(*pipeline)
95
- if infinite:
96
- pipeline = pipeline.repeat(nepochs=1000000)
97
-
98
- loader = wds.WebLoader(
99
- pipeline, batch_size=None, shuffle=False,
100
- num_workers=get_num_workers(num_workers),
101
- persistent_workers=True,
102
- )
103
- loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
104
- return loader
 
 
 
 
dp2/data/transforms/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
2
- from .stylegan2_transform import StyleGANAugmentPipe
 
 
 
dp2/data/transforms/functional.py DELETED
@@ -1,61 +0,0 @@
1
- import torchvision.transforms.functional as F
2
- import torch
3
- import pickle
4
- from tops import download_file, assert_shape
5
- from typing import Dict
6
- from functools import lru_cache
7
-
8
- global symmetry_transform
9
-
10
- @lru_cache(maxsize=1)
11
- def get_symmetry_transform(symmetry_url):
12
- file_name = download_file(symmetry_url)
13
- with open(file_name, "rb") as fp:
14
- symmetry = pickle.load(fp)
15
- return torch.from_numpy(symmetry["vertex_transforms"]).long()
16
-
17
-
18
- hflip_handled_cases = set([
19
- "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
20
- "embedding", "vertx2cat", "maskrcnn_mask", "__key__",
21
- "img_hr", "condition_hr", "mask_hr"])
22
-
23
- def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
24
- container["img"] = F.hflip(container["img"])
25
- if "condition" in container:
26
- container["condition"] = F.hflip(container["condition"])
27
- if "embedding" in container:
28
- container["embedding"] = F.hflip(container["embedding"])
29
- assert all([key in hflip_handled_cases for key in container]), container.keys()
30
- if "keypoints" in container:
31
- assert flip_map is not None
32
- if container["keypoints"].ndim == 3:
33
- keypoints = container["keypoints"][:, flip_map, :]
34
- keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
35
- else:
36
- assert_shape(container["keypoints"], (None, 3))
37
- keypoints = container["keypoints"][flip_map, :]
38
- keypoints[:, 0] = 1 - keypoints[:, 0]
39
- container["keypoints"] = keypoints
40
- if "mask" in container:
41
- container["mask"] = F.hflip(container["mask"])
42
- if "border" in container:
43
- container["border"] = F.hflip(container["border"])
44
- if "semantic_mask" in container:
45
- container["semantic_mask"] = F.hflip(container["semantic_mask"])
46
- if "vertices" in container:
47
- symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
48
- container["vertices"] = F.hflip(container["vertices"])
49
- symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
50
- container["vertices"] = symmetry_transform_[container["vertices"].long()]
51
- if "E_mask" in container:
52
- container["E_mask"] = F.hflip(container["E_mask"])
53
- if "maskrcnn_mask" in container:
54
- container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
55
- if "img_hr" in container:
56
- container["img_hr"] = F.hflip(container["img_hr"])
57
- if "condition_hr" in container:
58
- container["condition_hr"] = F.hflip(container["condition_hr"])
59
- if "mask_hr" in container:
60
- container["mask_hr"] = F.hflip(container["mask_hr"])
61
- return container
 
 
 
 
dp2/data/transforms/stylegan2_transform.py DELETED
@@ -1,394 +0,0 @@
1
- import numpy as np
2
- import scipy.signal
3
- import torch
4
- try:
5
- from sg3_torch_utils import misc
6
- from sg3_torch_utils.ops import upfirdn2d
7
- from sg3_torch_utils.ops import grid_sample_gradfix
8
- from sg3_torch_utils.ops import conv2d_gradfix
9
- except:
10
- pass
11
- #----------------------------------------------------------------------------
12
- # Coefficients of various wavelet decomposition low-pass filters.
13
-
14
- wavelets = {
15
- 'haar': [0.7071067811865476, 0.7071067811865476],
16
- 'db1': [0.7071067811865476, 0.7071067811865476],
17
- 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
18
- 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
19
- 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
20
- 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
21
- 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
22
- 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
23
- 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
24
- 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
25
- 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
26
- 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
27
- 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
28
- 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
29
- 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
30
- 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
31
- }
32
-
33
- #----------------------------------------------------------------------------
34
- # Helpers for constructing transformation matrices.
35
-
36
-
37
- def matrix(*rows, device=None):
38
- assert all(len(row) == len(rows[0]) for row in rows)
39
- elems = [x for row in rows for x in row]
40
- ref = [x for x in elems if isinstance(x, torch.Tensor)]
41
- if len(ref) == 0:
42
- return misc.constant(np.asarray(rows), device=device)
43
- assert device is None or device == ref[0].device
44
- elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
45
- return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
46
-
47
-
48
- def translate2d(tx, ty, **kwargs):
49
- return matrix(
50
- [1, 0, tx],
51
- [0, 1, ty],
52
- [0, 0, 1],
53
- **kwargs)
54
-
55
-
56
- def translate3d(tx, ty, tz, **kwargs):
57
- return matrix(
58
- [1, 0, 0, tx],
59
- [0, 1, 0, ty],
60
- [0, 0, 1, tz],
61
- [0, 0, 0, 1],
62
- **kwargs)
63
-
64
-
65
- def scale2d(sx, sy, **kwargs):
66
- return matrix(
67
- [sx, 0, 0],
68
- [0, sy, 0],
69
- [0, 0, 1],
70
- **kwargs)
71
-
72
-
73
- def scale3d(sx, sy, sz, **kwargs):
74
- return matrix(
75
- [sx, 0, 0, 0],
76
- [0, sy, 0, 0],
77
- [0, 0, sz, 0],
78
- [0, 0, 0, 1],
79
- **kwargs)
80
-
81
-
82
- def rotate2d(theta, **kwargs):
83
- return matrix(
84
- [torch.cos(theta), torch.sin(-theta), 0],
85
- [torch.sin(theta), torch.cos(theta), 0],
86
- [0, 0, 1],
87
- **kwargs)
88
-
89
-
90
- def rotate3d(v, theta, **kwargs):
91
- vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
92
- s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
93
- return matrix(
94
- [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
95
- [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
96
- [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
97
- [0, 0, 0, 1],
98
- **kwargs)
99
-
100
-
101
- def translate2d_inv(tx, ty, **kwargs):
102
- return translate2d(-tx, -ty, **kwargs)
103
-
104
-
105
- def scale2d_inv(sx, sy, **kwargs):
106
- return scale2d(1 / sx, 1 / sy, **kwargs)
107
-
108
-
109
- def rotate2d_inv(theta, **kwargs):
110
- return rotate2d(-theta, **kwargs)
111
-
112
-
113
- class StyleGANAugmentPipe(torch.nn.Module):
114
- def __init__(self,
115
- rotate90=0, xint=0, xint_max=0.125,
116
- scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
117
- brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
118
- hue_max=1, saturation_std=1,
119
- imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
120
- ):
121
- super().__init__()
122
- self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
123
-
124
- # Pixel blitting.
125
- self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
126
- self.xint = float(xint) # Probability multiplier for integer translation.
127
- self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
128
-
129
- # General geometric transformations.
130
- self.scale = float(scale) # Probability multiplier for isotropic scaling.
131
- self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
132
- self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
133
- self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
134
- self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
135
- self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
136
- self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
137
- self.xfrac_std = float(xfrac_std) # Standard deviation of fractional translation, relative to image dimensions.
138
-
139
- # Color transformations.
140
- self.brightness = float(brightness) # Probability multiplier for brightness.
141
- self.contrast = float(contrast) # Probability multiplier for contrast.
142
- self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
143
- self.hue = float(hue) # Probability multiplier for hue rotation.
144
- self.saturation = float(saturation) # Probability multiplier for saturation.
145
- self.brightness_std = float(brightness_std) # Standard deviation of brightness.
146
- self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
147
- self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
148
- self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
149
-
150
- # Image-space filtering.
151
- self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
152
- self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
153
- self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
154
-
155
- # Setup orthogonal lowpass filter for geometric augmentations.
156
- self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
157
-
158
- # Construct filter bank for image-space filtering.
159
- Hz_lo = np.asarray(wavelets['sym2']) # H(z)
160
- Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
161
- Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
162
- Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
163
- Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
164
- for i in range(1, Hz_fbank.shape[0]):
165
- Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
166
- Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
167
- Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
168
- self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
169
-
170
- def forward(self, batch, debug_percentile=None):
171
- images = batch["img"]
172
- batch["vertices"] = batch["vertices"].float()
173
- assert isinstance(images, torch.Tensor) and images.ndim == 4
174
- batch_size, num_channels, height, width = images.shape
175
- device = images.device
176
- self.Hz_fbank = self.Hz_fbank.to(device)
177
- self.Hz_geom = self.Hz_geom.to(device)
178
- if debug_percentile is not None:
179
- debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
180
-
181
- # -------------------------------------
182
- # Select parameters for pixel blitting.
183
- # -------------------------------------
184
-
185
- # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
186
- I_3 = torch.eye(3, device=device)
187
- G_inv = I_3
188
-
189
- # Apply integer translation with probability (xint * strength).
190
- if self.xint > 0:
191
- t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
192
- t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
193
- if debug_percentile is not None:
194
- t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
195
- G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
196
-
197
- # --------------------------------------------------------
198
- # Select parameters for general geometric transformations.
199
- # --------------------------------------------------------
200
-
201
- # Apply isotropic scaling with probability (scale * strength).
202
- if self.scale > 0:
203
- s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
204
- s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
205
- if debug_percentile is not None:
206
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
207
- G_inv = G_inv @ scale2d_inv(s, s)
208
-
209
- # Apply pre-rotation with probability p_rot.
210
- p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
211
- if self.rotate > 0:
212
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
213
- theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
214
- if debug_percentile is not None:
215
- theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
216
- G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
217
-
218
- # Apply anisotropic scaling with probability (aniso * strength).
219
- if self.aniso > 0:
220
- s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
221
- s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
222
- if debug_percentile is not None:
223
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
224
- G_inv = G_inv @ scale2d_inv(s, 1 / s)
225
-
226
- # Apply post-rotation with probability p_rot.
227
- if self.rotate > 0:
228
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
229
- theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
230
- if debug_percentile is not None:
231
- theta = torch.zeros_like(theta)
232
- G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
233
-
234
- # Apply fractional translation with probability (xfrac * strength).
235
- if self.xfrac > 0:
236
- t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
237
- t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
238
- if debug_percentile is not None:
239
- t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
240
- G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
241
-
242
- # ----------------------------------
243
- # Execute geometric transformations.
244
- # ----------------------------------
245
-
246
- # Execute if the transform is not identity.
247
- if G_inv is not I_3:
248
- # Calculate padding.
249
- cx = (width - 1) / 2
250
- cy = (height - 1) / 2
251
- cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
252
- cp = G_inv @ cp.t() # [batch, xyz, idx]
253
- Hz_pad = self.Hz_geom.shape[0] // 4
254
- margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
255
- margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
256
- margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
257
- margin = margin.max(misc.constant([0, 0] * 2, device=device))
258
- margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
259
- mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
260
-
261
- # Pad image and adjust origin.
262
- images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
263
- batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
264
- batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
265
- batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
266
- G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
267
-
268
- # Upsample.
269
- images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
270
- batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
271
- batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
272
- batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
273
- G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
274
- G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
275
-
276
- # Execute transformation.
277
- shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
278
- G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
279
- grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
280
- images = grid_sample_gradfix.grid_sample(images, grid)
281
-
282
- batch["mask"] = torch.nn.functional.grid_sample(
283
- input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
284
- batch["E_mask"] = torch.nn.functional.grid_sample(
285
- input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
286
- batch["vertices"] = torch.nn.functional.grid_sample(
287
- input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
288
-
289
-
290
- # Downsample and crop.
291
- images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
292
- batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
293
- batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
294
- batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
295
- # --------------------------------------------
296
- # Select parameters for color transformations.
297
- # --------------------------------------------
298
-
299
- # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
300
- I_4 = torch.eye(4, device=device)
301
- C = I_4
302
-
303
- # Apply brightness with probability (brightness * strength).
304
- if self.brightness > 0:
305
- b = torch.randn([batch_size], device=device) * self.brightness_std
306
- b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
307
- if debug_percentile is not None:
308
- b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
309
- C = translate3d(b, b, b) @ C
310
-
311
- # Apply contrast with probability (contrast * strength).
312
- if self.contrast > 0:
313
- c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
314
- c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
315
- if debug_percentile is not None:
316
- c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
317
- C = scale3d(c, c, c) @ C
318
-
319
- # Apply luma flip with probability (lumaflip * strength).
320
- v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
321
-
322
- # Apply hue rotation with probability (hue * strength).
323
- if self.hue > 0 and num_channels > 1:
324
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
325
- theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
326
- if debug_percentile is not None:
327
- theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
328
- C = rotate3d(v, theta) @ C # Rotate around v.
329
-
330
- # Apply saturation with probability (saturation * strength).
331
- if self.saturation > 0 and num_channels > 1:
332
- s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
333
- s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
334
- if debug_percentile is not None:
335
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
336
- C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
337
-
338
- # ------------------------------
339
- # Execute color transformations.
340
- # ------------------------------
341
-
342
- # Execute if the transform is not identity.
343
- if C is not I_4:
344
- images = images.reshape([batch_size, num_channels, height * width])
345
- if num_channels == 3:
346
- images = C[:, :3, :3] @ images + C[:, :3, 3:]
347
- elif num_channels == 1:
348
- C = C[:, :3, :].mean(dim=1, keepdims=True)
349
- images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
350
- else:
351
- raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
352
- images = images.reshape([batch_size, num_channels, height, width])
353
-
354
- # ----------------------
355
- # Image-space filtering.
356
- # ----------------------
357
-
358
- if self.imgfilter > 0:
359
- num_bands = self.Hz_fbank.shape[0]
360
- assert len(self.imgfilter_bands) == num_bands
361
- expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
362
-
363
- # Apply amplification for each band with probability (imgfilter * strength * band_strength).
364
- g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
365
- for i, band_strength in enumerate(self.imgfilter_bands):
366
- t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
367
- t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
368
- if debug_percentile is not None:
369
- t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
370
- t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
371
- t[:, i] = t_i # Replace i'th element.
372
- t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
373
- g = g * t # Accumulate into global gain.
374
-
375
- # Construct combined amplification filter.
376
- Hz_prime = g @ self.Hz_fbank # [batch, tap]
377
- Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
378
- Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
379
-
380
- # Apply filter.
381
- p = self.Hz_fbank.shape[1] // 2
382
- images = images.reshape([1, batch_size * num_channels, height, width])
383
- images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
384
- images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
385
- images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
386
- images = images.reshape([batch_size, num_channels, height, width])
387
-
388
- # ------------------------
389
- # Image-space corruptions.
390
- # ------------------------
391
- batch["img"] = images
392
- batch["vertices"] = batch["vertices"].long()
393
- batch["border"] = 1 - batch["E_mask"] - batch["mask"]
394
- return batch
 
 
 
 
dp2/data/transforms/transforms.py DELETED
@@ -1,247 +0,0 @@
1
- from pathlib import Path
2
- from typing import Dict, List
3
- import torchvision
4
- import torch
5
- import tops
6
- import torchvision.transforms.functional as F
7
- from .functional import hflip
8
-
9
-
10
- class RandomHorizontalFlip(torch.nn.Module):
11
-
12
- def __init__(self, p: float, flip_map=None,**kwargs):
13
- super().__init__()
14
- self.flip_ratio = p
15
- self.flip_map = flip_map
16
- if self.flip_ratio is None:
17
- self.flip_ratio = 0.5
18
- assert 0 <= self.flip_ratio <= 1
19
-
20
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
21
- if torch.rand(1) > self.flip_ratio:
22
- return container
23
- return hflip(container, self.flip_map)
24
-
25
-
26
- class CenterCrop(torch.nn.Module):
27
- """
28
- Performs the transform on the image.
29
- NOTE: Does not transform the mask to improve runtime.
30
- """
31
-
32
- def __init__(self, size: List[int]):
33
- super().__init__()
34
- self.size = tuple(size)
35
-
36
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
37
- min_size = min(container["img"].shape[1], container["img"].shape[2])
38
- if min_size < self.size[0]:
39
- container["img"] = F.center_crop(container["img"], min_size)
40
- container["img"] = F.resize(container["img"], self.size)
41
- return container
42
- container["img"] = F.center_crop(container["img"], self.size)
43
- return container
44
-
45
-
46
- class Resize(torch.nn.Module):
47
- """
48
- Performs the transform on the image.
49
- NOTE: Does not transform the mask to improve runtime.
50
- """
51
-
52
- def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
53
- super().__init__()
54
- self.size = tuple(size)
55
- self.interpolation = interpolation
56
-
57
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
58
- container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
59
- if "semantic_mask" in container:
60
- container["semantic_mask"] = F.resize(
61
- container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
62
- if "embedding" in container:
63
- container["embedding"] = F.resize(
64
- container["embedding"], self.size, self.interpolation)
65
- if "mask" in container:
66
- container["mask"] = F.resize(
67
- container["mask"], self.size, F.InterpolationMode.NEAREST)
68
- if "E_mask" in container:
69
- container["E_mask"] = F.resize(
70
- container["E_mask"], self.size, F.InterpolationMode.NEAREST)
71
- if "maskrcnn_mask" in container:
72
- container["maskrcnn_mask"] = F.resize(
73
- container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
74
- if "vertices" in container:
75
- container["vertices"] = F.resize(
76
- container["vertices"], self.size, F.InterpolationMode.NEAREST)
77
- return container
78
-
79
- def __repr__(self):
80
- repr = super().__repr__()
81
- vars_ = dict(size=self.size, interpolation=self.interpolation)
82
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
83
-
84
-
85
- class InsertHRImage(torch.nn.Module):
86
- """
87
- Resizes mask by maxpool and assumes condition is already created
88
- """
89
- def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
90
- super().__init__()
91
- self.size = tuple(size)
92
- self.interpolation = interpolation
93
-
94
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
95
- assert container["img"].dtype == torch.float32
96
- container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
97
- container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
98
- mask = container["mask"] > 0
99
- container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
100
- container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"]
101
- return container
102
-
103
- def __repr__(self):
104
- repr = super().__repr__()
105
- vars_ = dict(size=self.size, interpolation=self.interpolation)
106
- return repr + " "
107
-
108
-
109
- class CopyHRImage(torch.nn.Module):
110
- def __init__(self) -> None:
111
- super().__init__()
112
-
113
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
114
- container["img_hr"] = container["img"]
115
- container["condition_hr"] = container["condition"]
116
- container["mask_hr"] = container["mask"]
117
- return container
118
-
119
-
120
- class Resize2(torch.nn.Module):
121
- """
122
- Resizes mask by maxpool and assumes condition is already created
123
- """
124
- def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True):
125
- super().__init__()
126
- self.size = tuple(size)
127
- self.interpolation = interpolation
128
- self.downsample_condition = downsample_condition
129
- self.mask_condition = mask_condition
130
-
131
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
132
- # assert container["img"].dtype == torch.float32
133
- container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
134
- mask = container["mask"] > 0
135
- container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
136
-
137
- if self.downsample_condition:
138
- container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
139
- if self.mask_condition:
140
- container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"]
141
- return container
142
-
143
- def __repr__(self):
144
- repr = super().__repr__()
145
- vars_ = dict(size=self.size, interpolation=self.interpolation)
146
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
147
-
148
-
149
-
150
- class Normalize(torch.nn.Module):
151
- """
152
- Performs the transform on the image.
153
- NOTE: Does not transform the mask to improve runtime.
154
- """
155
-
156
- def __init__(self, mean, std, inplace, keys=["img"]):
157
- super().__init__()
158
- self.mean = mean
159
- self.std = std
160
- self.inplace = inplace
161
- self.keys = keys
162
-
163
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
164
- for key in self.keys:
165
- container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
166
- return container
167
-
168
- def __repr__(self):
169
- repr = super().__repr__()
170
- vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
171
- return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
172
-
173
-
174
- class ToFloat(torch.nn.Module):
175
-
176
- def __init__(self, keys=["img"], norm=True) -> None:
177
- super().__init__()
178
- self.keys = keys
179
- self.gain = 255 if norm else 1
180
-
181
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
182
- for key in self.keys:
183
- container[key] = container[key].float() / self.gain
184
- return container
185
-
186
-
187
- class RandomCrop(torchvision.transforms.RandomCrop):
188
- """
189
- Performs the transform on the image.
190
- NOTE: Does not transform the mask to improve runtime.
191
- """
192
-
193
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
194
- container["img"] = super().forward(container["img"])
195
- return container
196
-
197
-
198
- class CreateCondition(torch.nn.Module):
199
-
200
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
201
- if container["img"].dtype == torch.uint8:
202
- container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
203
- return container
204
- container["condition"] = container["img"] * container["mask"]
205
- return container
206
-
207
-
208
- class CreateEmbedding(torch.nn.Module):
209
-
210
- def __init__(self, embed_path: Path, cuda=True) -> None:
211
- super().__init__()
212
- self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
213
- if cuda:
214
- self.embed_map = tops.to_cuda(self.embed_map)
215
-
216
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
217
- vertices = container["vertices"]
218
- if vertices.ndim == 3:
219
- embedding = self.embed_map[vertices.long()].squeeze(dim=0)
220
- embedding = embedding.permute(2, 0, 1) * container["E_mask"]
221
- pass
222
- else:
223
- assert vertices.ndim == 4
224
- embedding = self.embed_map[vertices.long()].squeeze(dim=1)
225
- embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
226
- container["embedding"] = embedding
227
- container["embed_map"] = self.embed_map.clone()
228
- return container
229
-
230
-
231
- class UpdateMask(torch.nn.Module):
232
-
233
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
234
- container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float()
235
- return container
236
-
237
-
238
- class LoadClassEmbedding(torch.nn.Module):
239
-
240
- def __init__(self, embedding_path: Path) -> None:
241
- super().__init__()
242
- self.embedding = torch.load(embedding_path, map_location="cpu")
243
-
244
- def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
245
- key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1])
246
- container["class_embedding"] = self.embedding[key].view(-1)
247
- return container
 
 
 
 
dp2/data/utils.py DELETED
@@ -1,102 +0,0 @@
1
- import torch
2
- from PIL import Image
3
- import numpy as np
4
- import multiprocessing
5
- import io
6
- from tops import logger
7
- from torch.utils.data._utils.collate import default_collate
8
-
9
- try:
10
- import pyspng
11
-
12
- PYSPNG_IMPORTED = True
13
- except ImportError:
14
- PYSPNG_IMPORTED = False
15
- print("Could not load pyspng. Defaulting to pillow image backend.")
16
- from PIL import Image
17
-
18
-
19
- def get_coco_keypoints():
20
- return [
21
- "nose",
22
- "left_eye",
23
- "right_eye",
24
- "left_ear",
25
- "right_ear",
26
- "left_shoulder",
27
- "right_shoulder",
28
- "left_elbow",
29
- "right_elbow",
30
- "left_wrist",
31
- "right_wrist",
32
- "left_hip",
33
- "right_hip",
34
- "left_knee",
35
- "right_knee",
36
- "left_ankle",
37
- "right_ankle",
38
- ]
39
-
40
-
41
- def get_coco_flipmap():
42
- keypoints = get_coco_keypoints()
43
- keypoint_flip_map = {
44
- "left_eye": "right_eye",
45
- "left_ear": "right_ear",
46
- "left_shoulder": "right_shoulder",
47
- "left_elbow": "right_elbow",
48
- "left_wrist": "right_wrist",
49
- "left_hip": "right_hip",
50
- "left_knee": "right_knee",
51
- "left_ankle": "right_ankle",
52
- }
53
- for key, value in list(keypoint_flip_map.items()):
54
- keypoint_flip_map[value] = key
55
- keypoint_flip_map["nose"] = "nose"
56
- keypoint_flip_map_idx = []
57
- for source in keypoints:
58
- keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
59
- return keypoint_flip_map_idx
60
-
61
-
62
- def mask_decoder(x):
63
- mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
64
- mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
65
- return mask
66
-
67
-
68
- def png_decoder(x):
69
- if PYSPNG_IMPORTED:
70
- return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
71
- with Image.open(io.BytesIO(x)) as im:
72
- im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
73
- return im
74
-
75
-
76
- def jpg_decoder(x):
77
- with Image.open(io.BytesIO(x)) as im:
78
- im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
79
- return im
80
-
81
-
82
- def get_num_workers(num_workers: int):
83
- n_cpus = multiprocessing.cpu_count()
84
- if num_workers > n_cpus:
85
- logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
86
- return n_cpus
87
- return num_workers
88
-
89
-
90
- def collate_fn(batch):
91
- elem = batch[0]
92
- ignore_keys = set(["embed_map", "vertx2cat"])
93
- batch_ = {
94
- key: default_collate([d[key] for d in batch])
95
- for key in elem
96
- if key not in ignore_keys
97
- }
98
- if "embed_map" in elem:
99
- batch_["embed_map"] = elem["embed_map"]
100
- if "vertx2cat" in elem:
101
- batch_["vertx2cat"] = elem["vertx2cat"]
102
- return batch_
 
 
 
 
dp2/detection/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .cse_mask_face_detector import CSeMaskFaceDetector
2
- from .person_detector import CSEPersonDetector
3
- from .structures import PersonDetection, VehicleDetection, FaceDetection
 
 
 
 
dp2/detection/base.py DELETED
@@ -1,45 +0,0 @@
1
- import pickle
2
- import torch
3
- import lzma
4
- from pathlib import Path
5
- from tops import logger
6
-
7
-
8
- class BaseDetector:
9
-
10
-
11
- def __init__(self, cache_directory: str) -> None:
12
- if cache_directory is not None:
13
- self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
14
- self.cache_directory.mkdir(exist_ok=True, parents=True)
15
-
16
- def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
17
- logger.log(f"Caching detection to: {cache_path}")
18
- with lzma.open(cache_path, "wb") as fp:
19
- torch.save(
20
- [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
21
- pickle_protocol=pickle.HIGHEST_PROTOCOL)
22
-
23
- def load_from_cache(self, cache_path: Path):
24
- logger.log(f"Loading detection from cache path: {cache_path}")
25
- with lzma.open(cache_path, "rb") as fp:
26
- state_dict = torch.load(fp)
27
- return [
28
- state["cls"].from_state_dict(state_dict=state) for state in state_dict
29
- ]
30
-
31
- def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
32
- if cache_id is None:
33
- return self.forward(im)
34
- cache_path = self.cache_directory.joinpath(cache_id + ".torch")
35
- if cache_path.is_file() and load_cache:
36
- try:
37
- return self.load_from_cache(cache_path)
38
- except Exception as e:
39
- logger.warn(f"The cache file was corrupted: {cache_path}")
40
- exit()
41
- detections = self.forward(im)
42
- self.save_to_cache(detections, cache_path)
43
- return detections
44
-
45
-
 
 
dp2/detection/box_utils.py DELETED
@@ -1,104 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
5
- x0, y0, x1, y1 = [int(_) for _ in bbox]
6
- h, w = y1 - y0, x1 - x0
7
- cur_ratio = h / w
8
-
9
- if cur_ratio == target_aspect_ratio:
10
- return [x0, y0, x1, y1]
11
- if cur_ratio < target_aspect_ratio:
12
- target_height = int(w*target_aspect_ratio)
13
- y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
14
- else:
15
- target_width = int(h/target_aspect_ratio)
16
- x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
17
- return x0, y0, x1, y1
18
-
19
-
20
- def expand_axis(start, end, target_width, limit):
21
- # Can return a bbox outside of limit
22
- cur_width = end - start
23
- start = start - (target_width-cur_width)//2
24
- end = end + (target_width-cur_width)//2
25
- if end - start != target_width:
26
- end += 1
27
- assert end - start == target_width
28
- if start < 0 and end > limit:
29
- return start, end
30
- if start < 0 and end < limit:
31
- to_shift = min(0 - start, limit - end)
32
- start += to_shift
33
- end += to_shift
34
- if end > limit and start > 0:
35
- to_shift = min(end - limit, start)
36
- end -= to_shift
37
- start -= to_shift
38
- assert end - start == target_width
39
- return start, end
40
-
41
-
42
- def expand_box(bbox, imshape, mask, percentage_background: float):
43
- assert isinstance(bbox[0], int)
44
- assert 0 < percentage_background < 1
45
- # Percentage in S
46
- mask_pixels = mask.long().sum().cpu()
47
- total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
48
- percentage_mask = mask_pixels / total_pixels
49
- if (1 - percentage_mask) > percentage_background:
50
- return bbox
51
- target_pixels = mask_pixels / (1 - percentage_background)
52
- x0, y0, x1, y1 = bbox
53
- H = y1 - y0
54
- W = x1 - x0
55
- p = np.sqrt(target_pixels/(H*W))
56
- target_width = int(np.ceil(p * W))
57
- target_height = int(np.ceil(p * H))
58
- x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
59
- y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
60
- return [x0, y0, x1, y1]
61
-
62
-
63
- def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
64
- x0, y0, x1, y1 = bbox_XYXY
65
- H = y1 - y0
66
- W = x1 - x0
67
- expansion = int(((H*W)**0.5) * percentage)
68
- new_width = W + expansion
69
- new_height = H + expansion
70
- x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
71
- y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
72
- return [x0, y0, x1, y1]
73
-
74
-
75
- def get_expanded_bbox(
76
- bbox_XYXY,
77
- imshape,
78
- mask,
79
- percentage_background: float,
80
- axis_minimum_expansion: float,
81
- target_aspect_ratio: float):
82
- bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
83
- # Expand each axis of the bounding box by a minimum percentage
84
- bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
85
- # Find the minimum bbox with the aspect ratio. Can be outside of imshape
86
- bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
87
- # Expands square box such that X% of the bbox is background
88
- bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
89
- assert isinstance(bbox_XYXY[0], (int, np.int64))
90
- return bbox_XYXY
91
-
92
-
93
- def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
94
- def area_inside_ratio(bbox, imshape):
95
- area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
96
- area_inside = (min(bbox[2], imshape[1]) - max(0,bbox[0])) * (min(imshape[0],bbox[3]) - max(0,bbox[1]))
97
- return area_inside / area
98
- ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
99
- area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
100
- if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
101
- return False
102
- if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
103
- return False
104
- return True
 
 
 
 
dp2/detection/box_utils_fdf.py DELETED
@@ -1,203 +0,0 @@
1
- """
2
- The FDF dataset expands bounding boxes differently from what is used for CSE.
3
- """
4
-
5
- import numpy as np
6
-
7
-
8
- def quadratic_bounding_box(x0, y0, width, height, imshape):
9
- # We assume that we can create an image that is quadratic without
10
- # minimizing any of the sides
11
- assert width <= min(imshape[:2])
12
- assert height <= min(imshape[:2])
13
- min_side = min(height, width)
14
- if height != width:
15
- side_diff = abs(height - width)
16
- # Want to extend the shortest side
17
- if min_side == height:
18
- # Vertical side
19
- height += side_diff
20
- if height > imshape[0]:
21
- # Take full frame, and shrink width
22
- y0 = 0
23
- height = imshape[0]
24
-
25
- side_diff = abs(height - width)
26
- width -= side_diff
27
- x0 += side_diff // 2
28
- else:
29
- y0 -= side_diff // 2
30
- y0 = max(0, y0)
31
- else:
32
- # Horizontal side
33
- width += side_diff
34
- if width > imshape[1]:
35
- # Take full frame width, and shrink height
36
- x0 = 0
37
- width = imshape[1]
38
-
39
- side_diff = abs(height - width)
40
- height -= side_diff
41
- y0 += side_diff // 2
42
- else:
43
- x0 -= side_diff // 2
44
- x0 = max(0, x0)
45
- # Check that bbox goes outside image
46
- x1 = x0 + width
47
- y1 = y0 + height
48
- if imshape[1] < x1:
49
- diff = x1 - imshape[1]
50
- x0 -= diff
51
- if imshape[0] < y1:
52
- diff = y1 - imshape[0]
53
- y0 -= diff
54
- assert x0 >= 0, "Bounding box outside image."
55
- assert y0 >= 0, "Bounding box outside image."
56
- assert x0 + width <= imshape[1], "Bounding box outside image."
57
- assert y0 + height <= imshape[0], "Bounding box outside image."
58
- return x0, y0, width, height
59
-
60
-
61
- def expand_bounding_box(bbox, percentage, imshape):
62
- orig_bbox = bbox.copy()
63
- x0, y0, x1, y1 = bbox
64
- width = x1 - x0
65
- height = y1 - y0
66
- x0, y0, width, height = quadratic_bounding_box(
67
- x0, y0, width, height, imshape)
68
- expanding_factor = int(max(height, width) * percentage)
69
-
70
- possible_max_expansion = [(imshape[0] - width) // 2,
71
- (imshape[1] - height) // 2,
72
- expanding_factor]
73
-
74
- expanding_factor = min(possible_max_expansion)
75
- # Expand height
76
-
77
- if expanding_factor > 0:
78
-
79
- y0 = y0 - expanding_factor
80
- y0 = max(0, y0)
81
-
82
- height += expanding_factor * 2
83
- if height > imshape[0]:
84
- y0 -= (imshape[0] - height)
85
- height = imshape[0]
86
-
87
- if height + y0 > imshape[0]:
88
- y0 -= (height + y0 - imshape[0])
89
-
90
- # Expand width
91
- x0 = x0 - expanding_factor
92
- x0 = max(0, x0)
93
-
94
- width += expanding_factor * 2
95
- if width > imshape[1]:
96
- x0 -= (imshape[1] - width)
97
- width = imshape[1]
98
-
99
- if width + x0 > imshape[1]:
100
- x0 -= (width + x0 - imshape[1])
101
- y1 = y0 + height
102
- x1 = x0 + width
103
- assert y0 >= 0, "Y0 is minus"
104
- assert height <= imshape[0], "Height is larger than image."
105
- assert x0 + width <= imshape[1]
106
- assert y0 + height <= imshape[0]
107
- assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
108
- assert x0 >= 0, "X0 is negative"
109
- assert width <= imshape[1], "Width is larger than image."
110
- # Check that original bbox is within new
111
- x0_o, y0_o, x1_o, y1_o = orig_bbox
112
- assert x0 <= x0_o, f"New bbox is outside of original. O:{x0_o}, N: {x0}"
113
- assert x1 >= x1_o, f"New bbox is outside of original. O:{x1_o}, N: {x1}"
114
- assert y0 <= y0_o, f"New bbox is outside of original. O:{y0_o}, N: {y0}"
115
- assert y1 >= y1_o, f"New bbox is outside of original. O:{y1_o}, N: {y1}"
116
-
117
- x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
118
- x1 = x0 + width
119
- y1 = y0 + height
120
- return np.array([x0, y0, x1, y1])
121
-
122
-
123
- def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
124
- keypoint = keypoint[:, :3] # only nose + eyes are relevant
125
- kp_X = keypoint[0, :]
126
- kp_Y = keypoint[1, :]
127
- within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
128
- within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
129
- return within_X and within_Y
130
-
131
-
132
- def expand_bbox_simple(bbox, percentage):
133
- x0, y0, x1, y1 = bbox.astype(float)
134
- width = x1 - x0
135
- height = y1 - y0
136
- x_c = int(x0) + width // 2
137
- y_c = int(y0) + height // 2
138
- avg_size = max(width, height)
139
- new_width = avg_size * (1 + percentage)
140
- x0 = x_c - new_width // 2
141
- y0 = y_c - new_width // 2
142
- x1 = x_c + new_width // 2
143
- y1 = y_c + new_width // 2
144
- return np.array([x0, y0, x1, y1]).astype(int)
145
-
146
-
147
- def pad_image(im, bbox, pad_value):
148
- x0, y0, x1, y1 = bbox
149
- if x0 < 0:
150
- pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
151
- dtype=np.uint8) + pad_value
152
- im = np.concatenate((pad_im, im), axis=1)
153
- x1 += abs(x0)
154
- x0 = 0
155
- if y0 < 0:
156
- pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
157
- dtype=np.uint8) + pad_value
158
- im = np.concatenate((pad_im, im), axis=0)
159
- y1 += abs(y0)
160
- y0 = 0
161
- if x1 >= im.shape[1]:
162
- pad_im = np.zeros(
163
- (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
164
- dtype=np.uint8) + pad_value
165
- im = np.concatenate((im, pad_im), axis=1)
166
- if y1 >= im.shape[0]:
167
- pad_im = np.zeros(
168
- (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
169
- dtype=np.uint8) + pad_value
170
- im = np.concatenate((im, pad_im), axis=0)
171
- return im[y0:y1, x0:x1]
172
-
173
-
174
- def clip_box(bbox, im):
175
- bbox[0] = max(0, bbox[0])
176
- bbox[1] = max(0, bbox[1])
177
- bbox[2] = min(im.shape[1] - 1, bbox[2])
178
- bbox[3] = min(im.shape[0] - 1, bbox[3])
179
- return bbox
180
-
181
-
182
- def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
183
- outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
184
- if simple_expand or (outside_im and pad_im):
185
- return pad_image(im, bbox, pad_value)
186
- bbox = clip_box(bbox, im)
187
- x0, y0, x1, y1 = bbox
188
- return im[y0:y1, x0:x1]
189
-
190
-
191
- def expand_bbox(
192
- bbox_ltrb, imshape, simple_expand, default_to_simple=False,
193
- expansion_factor=0.35):
194
- assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
195
- bbox = bbox_ltrb.astype(float)
196
- # FDF256 uses simple expand with ratio 0.4
197
- if simple_expand:
198
- return expand_bbox_simple(bbox, 0.4)
199
- try:
200
- return expand_bounding_box(bbox, expansion_factor, imshape)
201
- except AssertionError:
202
- return expand_bbox_simple(bbox, expansion_factor * 2)
203
-
 
 
 
 
dp2/detection/cse_mask_face_detector.py DELETED
@@ -1,116 +0,0 @@
1
- import torch
2
- import lzma
3
- import tops
4
- from pathlib import Path
5
- from dp2.detection.base import BaseDetector
6
- from .utils import combine_cse_maskrcnn_dets
7
- from face_detection import build_detector as build_face_detector
8
- from .models.cse import CSEDetector
9
- from .models.mask_rcnn import MaskRCNNDetector
10
- from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
11
- from tops import logger
12
-
13
-
14
- def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
15
- assert len(box1.shape) == 2
16
- assert len(box2.shape) == 2
17
- box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
18
- # This can be batched
19
- for i, box in enumerate(box1):
20
- is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
21
- is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
22
- is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
23
- box1_inside[i] = is_outside.logical_not().any()
24
- return box1_inside
25
-
26
-
27
- class CSeMaskFaceDetector(BaseDetector):
28
-
29
- def __init__(
30
- self,
31
- mask_rcnn_cfg,
32
- face_detector_cfg: dict,
33
- cse_cfg: dict,
34
- face_post_process_cfg: dict,
35
- cse_post_process_cfg,
36
- score_threshold: float,
37
- **kwargs
38
- ) -> None:
39
- super().__init__(**kwargs)
40
- self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
41
- if "confidence_threshold" not in face_detector_cfg:
42
- face_detector_cfg["confidence_threshold"] = score_threshold
43
- if "score_thres" not in cse_cfg:
44
- cse_cfg["score_thres"] = score_threshold
45
- self.cse_detector = CSEDetector(**cse_cfg)
46
- self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
47
- self.cse_post_process_cfg = cse_post_process_cfg
48
- self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
49
- self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
50
- self.face_post_process_cfg = face_post_process_cfg
51
-
52
- def __call__(self, *args, **kwargs):
53
- return self.forward(*args, **kwargs)
54
-
55
- def _detect_faces(self, im: torch.Tensor):
56
- H, W = im.shape[1:]
57
- im = im.float() - self.face_mean
58
- im = self.face_detector.resize(im[None], 1.0)
59
- boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
60
- boxes_XYXY[:, [0, 2]] *= W
61
- boxes_XYXY[:, [1, 3]] *= H
62
- return boxes_XYXY.round().long()
63
-
64
- def load_from_cache(self, cache_path: Path):
65
- logger.log(f"Loading detection from cache path: {cache_path}",)
66
- with lzma.open(cache_path, "rb") as fp:
67
- state_dict = torch.load(fp, map_location="cpu")
68
- kwargs = dict(
69
- post_process_cfg=self.cse_post_process_cfg,
70
- embed_map=self.cse_detector.embed_map,
71
- **self.face_post_process_cfg
72
- )
73
- return [
74
- state["cls"].from_state_dict(**kwargs, state_dict=state)
75
- for state in state_dict
76
- ]
77
-
78
- @torch.no_grad()
79
- def forward(self, im: torch.Tensor):
80
- maskrcnn_dets = self.mask_rcnn(im)
81
- cse_dets = self.cse_detector(im)
82
- embed_map = self.cse_detector.embed_map
83
- print("Calling face detector.")
84
- face_boxes = self._detect_faces(im).cpu()
85
- maskrcnn_person = {
86
- k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
87
- }
88
- maskrcnn_other = {
89
- k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
90
- }
91
- maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
92
- combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
93
- maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
94
-
95
- persons_with_cse = CSEPersonDetection(
96
- combined_segmentation, cse_dets, **self.cse_post_process_cfg,
97
- embed_map=embed_map,orig_imshape_CHW=im.shape
98
- )
99
- persons_with_cse.pre_process()
100
- not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
101
- persons_without_cse = PersonDetection(
102
- maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
103
- orig_imshape_CHW=im.shape
104
- )
105
- persons_without_cse.pre_process()
106
-
107
- face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
108
- box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
109
- )
110
- face_boxes = face_boxes[face_boxes_covered.logical_not()]
111
- face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
112
-
113
- # Order matters. The anonymizer will anonymize FIFO.
114
- # Later detections will overwrite.
115
- all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
116
- return all_detections
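For reference, in the forward pass above a face box is dropped when it already lies inside one of the dilated person boxes, as decided by box1_inside_box2, and the remaining detections are ordered so that later entries overwrite earlier ones during anonymization. A minimal, self-contained sketch of that containment test, with made-up boxes:

import torch

def box1_inside_box2(box1, box2):
    # True for each XYXY box in box1 that lies strictly inside at least one box in box2.
    inside = torch.zeros(box1.shape[0], dtype=torch.bool)
    for i, box in enumerate(box1):
        outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
        outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
        inside[i] = (outside_lefttop | outside_rightbot).logical_not().any()
    return inside

person_boxes = torch.tensor([[10, 10, 100, 200]])    # one dilated person box
face_boxes = torch.tensor([[40, 20, 60, 50],         # covered by the person box
                           [150, 20, 170, 50]])      # not covered by any person box
print(box1_inside_box2(face_boxes, person_boxes))    # tensor([ True, False])

Only the second face would be kept as a standalone FaceDetection.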
 
 
dp2/detection/face_detector.py DELETED
@@ -1,62 +0,0 @@
1
- import torch
2
- import lzma
3
- import tops
4
- from pathlib import Path
5
- from dp2.detection.base import BaseDetector
6
- from face_detection import build_detector as build_face_detector
7
- from .structures import FaceDetection
8
- from tops import logger
9
-
10
-
11
- def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
12
- assert len(box1.shape) == 2
13
- assert len(box2.shape) == 2
14
- box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
15
- # This can be batched
16
- for i, box in enumerate(box1):
17
- is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
18
- is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
19
- is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
20
- box1_inside[i] = is_outside.logical_not().any()
21
- return box1_inside
22
-
23
-
24
- class FaceDetector(BaseDetector):
25
-
26
- def __init__(
27
- self,
28
- face_detector_cfg: dict,
29
- score_threshold: float,
30
- face_post_process_cfg: dict,
31
- **kwargs
32
- ) -> None:
33
- super().__init__(**kwargs)
34
- self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
35
- self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
36
- self.face_post_process_cfg = face_post_process_cfg
37
-
38
- def __call__(self, *args, **kwargs):
39
- return self.forward(*args, **kwargs)
40
-
41
- def _detect_faces(self, im: torch.Tensor):
42
- H, W = im.shape[1:]
43
- im = im.float() - self.face_mean
44
- im = self.face_detector.resize(im[None], 1.0)
45
- boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
46
- boxes_XYXY[:, [0, 2]] *= W
47
- boxes_XYXY[:, [1, 3]] *= H
48
- return boxes_XYXY.round().long().cpu()
49
-
50
- @torch.no_grad()
51
- def forward(self, im: torch.Tensor):
52
- face_boxes = self._detect_faces(im)
53
- face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
54
- return [face_boxes]
55
-
56
- def load_from_cache(self, cache_path: Path):
57
- logger.log(f"Loading detection from cache path: {cache_path}")
58
- with lzma.open(cache_path, "rb") as fp:
59
- state_dict = torch.load(fp)
60
- return [
61
- state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
62
- ]
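For reference, _detect_faces above gets boxes back normalized to [0, 1] and scales them to pixel coordinates of the input image before rounding. A toy illustration of that conversion (values made up):

import torch

H, W = 480, 640                                              # input image height and width
normalized_boxes = torch.tensor([[0.25, 0.10, 0.50, 0.40]])  # hypothetical detector output, XYXY in [0, 1]
boxes_XYXY = normalized_boxes.clone()
boxes_XYXY[:, [0, 2]] *= W                                   # x-coordinates scale with width
boxes_XYXY[:, [1, 3]] *= H                                   # y-coordinates scale with height
print(boxes_XYXY.round().long())                             # tensor([[160,  48, 320, 192]])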
 
 
dp2/detection/models/__init__.py DELETED
File without changes
dp2/detection/models/cse.py DELETED
@@ -1,135 +0,0 @@
1
- import torch
2
- from typing import List
3
- import tops
4
- from torchvision.transforms.functional import InterpolationMode, resize
5
- from densepose.data.utils import get_class_to_mesh_name_mapping
6
- from densepose import add_densepose_config
7
- from densepose.structures import DensePoseEmbeddingPredictorOutput
8
- from densepose.vis.extractor import DensePoseOutputsExtractor
9
- from densepose.modeling import build_densepose_embedder
10
- from detectron2.config import get_cfg
11
- from detectron2.data.transforms import ResizeShortestEdge
12
- from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
13
- from detectron2.modeling import build_model
14
-
15
-
16
- model_urls = {
17
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
18
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
19
- }
20
-
21
-
22
- def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
23
- assert len(S.shape) == 3
24
- H, W = imshape
25
- N = len(boxes_XYXY)
26
- segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
27
- boxes_XYXY = boxes_XYXY.long()
28
- for i in range(N):
29
- x0, y0, x1, y1 = boxes_XYXY[i]
30
- assert x0 >= 0 and y0 >= 0
31
- assert x1 <= imshape[1]
32
- assert y1 <= imshape[0]
33
- h = y1 - y0
34
- w = x1 - x0
35
- segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
36
- return segmentation
37
-
38
-
39
- class CSEDetector:
40
-
41
- def __init__(
42
- self,
43
- cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
44
- cfg_2_download: List[str] = [
45
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
46
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
47
- "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
48
- score_thres: float = 0.9,
49
- nms_thresh: float = None,
50
- ) -> None:
51
- with tops.logger.capture_log_stdout():
52
- cfg = get_cfg()
53
- self.device = tops.get_device()
54
- add_densepose_config(cfg)
55
- cfg_path = tops.download_file(cfg_url)
56
- for p in cfg_2_download:
57
- tops.download_file(p)
58
- with tops.logger.capture_log_stdout():
59
- cfg.merge_from_file(cfg_path)
60
- assert cfg_url in model_urls, cfg_url
61
- model_path = tops.download_file(model_urls[cfg_url])
62
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
63
- if nms_thresh is not None:
64
- cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
65
- cfg.MODEL.WEIGHTS = str(model_path)
66
- cfg.MODEL.DEVICE = str(self.device)
67
- cfg.freeze()
68
- with tops.logger.capture_log_stdout():
69
- self.model = build_model(cfg)
70
- self.model.eval()
71
- DetectionCheckpointer(self.model).load(str(model_path))
72
- self.input_format = cfg.INPUT.FORMAT
73
- self.densepose_extractor = DensePoseOutputsExtractor()
74
- self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
75
-
76
- self.embedder = build_densepose_embedder(cfg)
77
- self.mesh_vertex_embeddings = {
78
- mesh_name: self.embedder(mesh_name).to(self.device)
79
- for mesh_name in self.class_to_mesh_name.values()
80
- if self.embedder.has_embeddings(mesh_name)
81
- }
82
- self.cfg = cfg
83
- self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
84
- tops.logger.log("CSEDetector built.")
85
-
86
- def __call__(self, *args, **kwargs):
87
- return self.forward(*args, **kwargs)
88
-
89
- def resize_im(self, im):
90
- H, W = im.shape[1:]
91
- newH, newW = ResizeShortestEdge.get_output_shape(
92
- H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
93
- return resize(
94
- im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)
95
-
96
- @torch.no_grad()
97
- def forward(self, im):
98
- assert im.dtype == torch.uint8
99
- if self.input_format == "BGR":
100
- im = im.flip(0)
101
- H, W = im.shape[1:]
102
- im = self.resize_im(im)
103
- output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
104
- scores = output.get("scores")
105
- if len(scores) == 0:
106
- return dict(
107
- instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
108
- instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
109
- embed_map=self.mesh_vertex_embeddings["smpl_27554"],
110
- bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
111
- im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
112
- scores=torch.empty((0), dtype=torch.float, device=im.device)
113
- )
114
- pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
115
- assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
116
- S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes)
117
- E = pred_densepose.embedding
118
- mesh_name = self.class_to_mesh_name[classes[0]]
119
- assert mesh_name == "smpl_27554"
120
- x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
121
- boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
122
- boxes_XYXY = boxes_XYXY.round_().long()
123
-
124
- non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
125
- S = S[non_empty_boxes]
126
- E = E[non_empty_boxes]
127
- boxes_XYXY = boxes_XYXY[non_empty_boxes]
128
- scores = scores[non_empty_boxes]
129
- im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
130
- return dict(
131
- instance_segmentation=S, instance_embedding=E,
132
- bbox_XYXY=boxes_XYXY,
133
- im_segmentation=im_segmentation,
134
- scores=scores.view(-1))
135
-
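For reference, cse_det_to_global above pastes each instance's fixed-resolution coarse segmentation into a full-image boolean mask by nearest-neighbour resizing it to its detection box. The paste step in isolation, with toy shapes instead of the usual 112x112 instances:

import torch
from torchvision.transforms.functional import InterpolationMode, resize

H, W = 8, 8                                     # full image size
S = torch.zeros(1, 4, 4)                        # coarse segmentation for one instance
S[0, 1:3, 1:3] = 1
box_XYXY = torch.tensor([2, 2, 8, 8])           # detection box in image coordinates
segmentation = torch.zeros(1, H, W, dtype=torch.bool)
x0, y0, x1, y1 = box_XYXY.tolist()
segmentation[0:1, y0:y1, x0:x1] = resize(
    S[0:1], (y1 - y0, x1 - x0), interpolation=InterpolationMode.NEAREST) > 0
print(segmentation[0].long())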
 
 
dp2/detection/models/keypoint_maskrcnn.py DELETED
@@ -1,111 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from detectron2.checkpoint import DetectionCheckpointer
4
- from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
5
- from detectron2.data.transforms import ResizeShortestEdge
6
- from detectron2.structures import Instances
7
- from detectron2 import model_zoo
8
- from detectron2.config import instantiate
9
- from detectron2.config import LazyCall as L
10
- from PIL import Image
11
- import tops
12
- import functools
13
- from torchvision.transforms.functional import resize
14
-
15
-
16
- def get_rn50_fpn_keypoint_rcnn(weight_path: str):
17
- from detectron2.modeling.poolers import ROIPooler
18
- from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
19
- from detectron2.layers import ShapeSpec
20
- model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
21
- model.roi_heads.update(
22
- num_classes=1,
23
- keypoint_in_features=["p2", "p3", "p4", "p5"],
24
- keypoint_pooler=L(ROIPooler)(
25
- output_size=14,
26
- scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
27
- sampling_ratio=0,
28
- pooler_type="ROIAlignV2",
29
- ),
30
- keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
31
- input_shape=ShapeSpec(channels=256, width=14, height=14),
32
- num_keypoints=17,
33
- conv_dims=[512] * 8,
34
- loss_normalizer="visible",
35
- ),
36
- )
37
-
38
- # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
39
- # 1000 proposals per-image is found to hurt box AP.
40
- # Therefore we increase it to 1500 per-image.
41
- model.proposal_generator.post_nms_topk = (1500, 1000)
42
-
43
- # Keypoint AP degrades (though box AP improves) when using plain L1 loss
44
- model.roi_heads.box_predictor.smooth_l1_beta = 0.5
45
- model = instantiate(model)
46
-
47
- dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
48
- test_transform = instantiate(dataloader.test.mapper.augmentations)
49
- DetectionCheckpointer(model).load(weight_path)
50
- return model, test_transform
51
-
52
-
53
- models = {
54
- "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
55
- }
56
-
57
-
58
-
59
-
60
- class KeypointMaskRCNN:
61
-
62
- def __init__(self, model_name: str, score_threshold: float) -> None:
63
- assert model_name in models, f"Did not find {model_name} in models"
64
- model, test_transform = models[model_name]()
65
- self.model = model.eval().to(tops.get_device())
66
- if isinstance(self.model.roi_heads, CascadeROIHeads):
67
- for head in self.model.roi_heads.box_predictors:
68
- assert hasattr(head, "test_score_thresh")
69
- head.test_score_thresh = score_threshold
70
- else:
71
- assert isinstance(self.model.roi_heads, StandardROIHeads)
72
- assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
73
- self.model.roi_heads.box_predictor.test_score_thresh = score_threshold
74
-
75
- self.test_transform = test_transform
76
- assert len(self.test_transform) == 1
77
- self.test_transform = self.test_transform[0]
78
- assert isinstance(self.test_transform, ResizeShortestEdge)
79
- assert self.test_transform.interp == Image.BILINEAR
80
- self.image_format = self.model.input_format
81
-
82
- def resize_im(self, im):
83
- H, W = im.shape[-2:]
84
- if self.test_transform.is_range:
85
- size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
86
- else:
87
- size = np.random.choice(self.test_transform.short_edge_length)
88
- newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
89
- return resize(
90
- im, (newH, newW), antialias=True)
91
-
92
- def __call__(self, *args, **kwargs):
93
- return self.forward(*args, **kwargs)
94
-
95
- @torch.no_grad()
96
- def forward(self, im: torch.Tensor) -> Instances:
97
- assert im.ndim == 3
98
- if self.image_format == "BGR":
99
- im = im.flip(0)
100
- H, W = im.shape[-2:]
101
- im = self.resize_im(im)
102
- im = im.float()
103
- inputs = dict(image=im, height=H, width=W)
104
- # instances contains
105
- # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
106
- instances = self.model([inputs])[0]["instances"]
107
- return dict(
108
- scores=instances.get("scores").cpu(),
109
- segmentation=instances.get("pred_masks").cpu(),
110
- keypoints=instances.get("pred_keypoints").cpu()
111
- )
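For reference, resize_im above relies on detectron2's ResizeShortestEdge.get_output_shape: the short side is scaled to the sampled test size, and the result is shrunk again if the long side would exceed max_size. A close approximation in plain Python (not the detectron2 implementation itself; the example numbers are the common 800/1333 test defaults):

def shortest_edge_shape(h, w, size, max_size):
    # Scale so the short side becomes `size`, then shrink if the long side would exceed `max_size`.
    scale = size / min(h, w)
    if max(h, w) * scale > max_size:
        scale = max_size / max(h, w)
    return int(round(h * scale)), int(round(w * scale))

print(shortest_edge_shape(1080, 1920, size=800, max_size=1333))  # (750, 1333)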
 
 
dp2/detection/models/mask_rcnn.py DELETED
@@ -1,78 +0,0 @@
1
- import torch
2
- import tops
3
- from detectron2.modeling import build_model
4
- from detectron2.checkpoint import DetectionCheckpointer
5
- from detectron2.structures import Boxes
6
- from detectron2.data import MetadataCatalog
7
- from detectron2 import model_zoo
8
- from typing import Dict
9
- from detectron2.data.transforms import ResizeShortestEdge
10
- from torchvision.transforms.functional import resize
11
-
12
-
13
-
14
- model_urls = {
15
- "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl",
16
-
17
- }
18
- class MaskRCNNDetector:
19
-
20
- def __init__(
21
- self,
22
- cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
23
- score_thres: float = 0.9,
24
- class_filter=["person"], #["car", "bicycle","truck", "bus", "backpack"]
25
- fp16_inference: bool = False
26
- ) -> None:
27
- cfg = model_zoo.get_config(cfg_name)
28
- cfg.MODEL.DEVICE = str(tops.get_device())
29
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
30
- cfg.freeze()
31
- self.cfg = cfg
32
- with tops.logger.capture_log_stdout():
33
- self.model = build_model(cfg)
34
- DetectionCheckpointer(self.model).load(model_urls[cfg_name])
35
- self.model.eval()
36
- self.input_format = cfg.INPUT.FORMAT
37
- self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
38
- self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter])
39
- self.person_class = self.class_names.index("person")
40
- self.fp16_inference = fp16_inference
41
- tops.logger.log("Mask R-CNN built.")
42
-
43
- def __call__(self, *args, **kwargs):
44
- return self.forward(*args, **kwargs)
45
-
46
- def resize_im(self, im):
47
- H, W = im.shape[1:]
48
- newH, newW = ResizeShortestEdge.get_output_shape(
49
- H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
50
- return resize(
51
- im, (newH, newW), antialias=True)
52
-
53
- @torch.no_grad()
54
- def forward(self, im: torch.Tensor):
55
- if self.input_format == "BGR":
56
- im = im.flip(0)
57
- else:
58
- assert self.input_format == "RGB"
59
- H, W = im.shape[-2:]
60
- im = self.resize_im(im)
61
- with torch.cuda.amp.autocast(enabled=self.fp16_inference):
62
- output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
63
- scores = output.get("scores")
64
- N = len(scores)
65
- classes = output.get("pred_classes")
66
- idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep]
67
- classes = classes[idx2keep]
68
- assert isinstance(output.get("pred_boxes"), Boxes)
69
- segmentation = output.get("pred_masks")[idx2keep]
70
- assert segmentation.dtype == torch.bool
71
- is_person = classes == self.person_class
72
- return {
73
- "scores": output.get("scores")[idx2keep],
74
- "segmentation": segmentation,
75
- "classes": output.get("pred_classes")[idx2keep],
76
- "is_person": is_person
77
- }
78
-
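For reference, the forward pass above keeps only detections whose predicted class is in class_filter (just "person" by default) and flags which of those are persons. The filtering boils down to the following, with made-up predictions and an assumed class ordering:

import torch

class_names = ["person", "bicycle", "car"]          # assumed ordering for this sketch
class_to_keep = {class_names.index("person")}
person_class = class_names.index("person")

pred_classes = torch.tensor([0, 2, 0, 1])           # hypothetical per-detection classes
idx2keep = [i for i in range(len(pred_classes)) if pred_classes[i].item() in class_to_keep]
print(idx2keep)                                     # [0, 2]
print(pred_classes[idx2keep] == person_class)       # tensor([True, True])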
 
 
dp2/detection/person_detector.py DELETED
@@ -1,135 +0,0 @@
1
- import torch
2
- import lzma
3
- from dp2.detection.base import BaseDetector
4
- from .utils import combine_cse_maskrcnn_dets
5
- from .models.cse import CSEDetector
6
- from .models.mask_rcnn import MaskRCNNDetector
7
- from .models.keypoint_maskrcnn import KeypointMaskRCNN
8
- from .structures import CSEPersonDetection, PersonDetection
9
- from pathlib import Path
10
-
11
-
12
- class CSEPersonDetector(BaseDetector):
13
- def __init__(
14
- self,
15
- score_threshold: float,
16
- mask_rcnn_cfg: dict,
17
- cse_cfg: dict,
18
- cse_post_process_cfg: dict,
19
- **kwargs
20
- ) -> None:
21
- super().__init__(**kwargs)
22
- self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
23
- self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
24
- self.post_process_cfg = cse_post_process_cfg
25
- self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")
26
-
27
- def __call__(self, *args, **kwargs):
28
- return self.forward(*args, **kwargs)
29
-
30
- def load_from_cache(self, cache_path: Path):
31
- with lzma.open(cache_path, "rb") as fp:
32
- state_dict = torch.load(fp)
33
- kwargs = dict(
34
- post_process_cfg=self.post_process_cfg,
35
- embed_map=self.cse_detector.embed_map,
36
- )
37
- return [
38
- state["cls"].from_state_dict(**kwargs, state_dict=state)
39
- for state in state_dict
40
- ]
41
-
42
- @torch.no_grad()
43
- def forward(self, im: torch.Tensor, cse_dets=None):
44
- mask_dets = self.mask_rcnn(im)
45
- if cse_dets is None:
46
- cse_dets = self.cse_detector(im)
47
- segmentation = mask_dets["segmentation"]
48
- segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
49
- segmentation, cse_dets, self.iou_combine_threshold
50
- )
51
- det = CSEPersonDetection(
52
- segmentation=segmentation,
53
- cse_dets=cse_dets,
54
- embed_map=self.cse_detector.embed_map,
55
- orig_imshape_CHW=im.shape,
56
- **self.post_process_cfg
57
- )
58
- return [det]
59
-
60
-
61
- class MaskRCNNPersonDetector(BaseDetector):
62
- def __init__(
63
- self,
64
- score_threshold: float,
65
- mask_rcnn_cfg: dict,
66
- cse_post_process_cfg: dict,
67
- **kwargs
68
- ) -> None:
69
- super().__init__(**kwargs)
70
- self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
71
- self.post_process_cfg = cse_post_process_cfg
72
-
73
- def __call__(self, *args, **kwargs):
74
- return self.forward(*args, **kwargs)
75
-
76
- def load_from_cache(self, cache_path: Path):
77
- with lzma.open(cache_path, "rb") as fp:
78
- state_dict = torch.load(fp)
79
- kwargs = dict(
80
- post_process_cfg=self.post_process_cfg,
81
- )
82
- return [
83
- state["cls"].from_state_dict(**kwargs, state_dict=state)
84
- for state in state_dict
85
- ]
86
-
87
- @torch.no_grad()
88
- def forward(self, im: torch.Tensor):
89
- mask_dets = self.mask_rcnn(im)
90
- segmentation = mask_dets["segmentation"]
91
- det = PersonDetection(
92
- segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape
93
- )
94
- return [det]
95
-
96
-
97
- class KeypointMaskRCNNPersonDetector(BaseDetector):
98
- def __init__(
99
- self,
100
- score_threshold: float,
101
- mask_rcnn_cfg: dict,
102
- cse_post_process_cfg: dict,
103
- **kwargs
104
- ) -> None:
105
- super().__init__(**kwargs)
106
- self.mask_rcnn = KeypointMaskRCNN(
107
- **mask_rcnn_cfg, score_threshold=score_threshold
108
- )
109
- self.post_process_cfg = cse_post_process_cfg
110
-
111
- def __call__(self, *args, **kwargs):
112
- return self.forward(*args, **kwargs)
113
-
114
- def load_from_cache(self, cache_path: Path):
115
- with lzma.open(cache_path, "rb") as fp:
116
- state_dict = torch.load(fp)
117
- kwargs = dict(
118
- post_process_cfg=self.post_process_cfg,
119
- )
120
- return [
121
- state["cls"].from_state_dict(**kwargs, state_dict=state)
122
- for state in state_dict
123
- ]
124
-
125
- @torch.no_grad()
126
- def forward(self, im: torch.Tensor):
127
- mask_dets = self.mask_rcnn(im)
128
- segmentation = mask_dets["segmentation"]
129
- det = PersonDetection(
130
- segmentation,
131
- **self.post_process_cfg,
132
- orig_imshape_CHW=im.shape,
133
- keypoints=mask_dets["keypoints"]
134
- )
135
- return [det]
 
 
 
dp2/detection/structures.py DELETED
@@ -1,464 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from dp2 import utils
4
- from dp2.utils import vis_utils, crop_box
5
- from .utils import (
6
- cut_pad_resize, masks_to_boxes,
7
- get_kernel, transform_embedding, initialize_cse_boxes
8
- )
9
- from .box_utils import get_expanded_bbox, include_box
10
- import torchvision
11
- import tops
12
- from .box_utils_fdf import expand_bbox as expand_bbox_fdf
13
-
14
-
15
- class VehicleDetection:
16
-
17
- def __init__(self, segmentation: torch.BoolTensor) -> None:
18
- self.segmentation = segmentation
19
- self.boxes = masks_to_boxes(segmentation)
20
- assert self.boxes.shape[1] == 4, self.boxes.shape
21
- self.n_detections = self.segmentation.shape[0]
22
- area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
23
-
24
- sorted_idx = torch.argsort(area, descending=True)
25
- self.segmentation = self.segmentation[sorted_idx]
26
- self.boxes = self.boxes[sorted_idx].cpu()
27
-
28
- def pre_process(self):
29
- pass
30
-
31
- def get_crop(self, idx: int, im):
32
- assert idx < len(self)
33
- box = self.boxes[idx]
34
- im = crop_box(im, box)
35
- mask = crop_box(self.segmentation[idx], box)
36
- mask = mask == 0
37
- return dict(img=im, mask=mask.float(), boxes=box)
38
-
39
- def visualize(self, im):
40
- if len(self) == 0:
41
- return im
42
- im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
43
- return im
44
-
45
- def __len__(self):
46
- return self.n_detections
47
-
48
- @staticmethod
49
- def from_state_dict(state_dict, **kwargs):
50
- numel = np.prod(state_dict["shape"])
51
- arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
52
- segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
53
- return VehicleDetection(segmentation)
54
-
55
- def state_dict(self, **kwargs):
56
- segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
57
- return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
58
-
59
-
60
- class FaceDetection:
61
-
62
- def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None:
63
- self.boxes = boxes_ltrb.cpu()
64
- assert self.boxes.shape[1] == 4, self.boxes.shape
65
- self.target_imsize = tuple(target_imsize)
66
- # Sort by area to paste in largest faces last
67
- area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
68
- idx = area.argsort(descending=False)
69
- self.boxes = self.boxes[idx]
70
- self.fdf128_expand = fdf128_expand
71
-
72
- def visualize(self, im):
73
- if len(self) == 0:
74
- return im
75
- orig_device = im.device
76
- for box in self.boxes:
77
- simple_expand = False if self.fdf128_expand else True
78
- e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
79
- im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
80
- im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
81
-
82
- return im.to(device=orig_device)
83
-
84
- def get_crop(self, idx: int, im):
85
- assert idx < len(self)
86
- box = self.boxes[idx].numpy()
87
- simple_expand = False if self.fdf128_expand else True
88
- expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand=simple_expand)
89
- im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
90
- area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
91
-
92
- # Find the square mask corresponding to box.
93
- box_mask = box.copy().astype(float)
94
- box_mask[[0, 2]] -= expanded_boxes[0]
95
- box_mask[[1, 3]] -= expanded_boxes[1]
96
-
97
- width = expanded_boxes[2] - expanded_boxes[0]
98
- resize_factor = self.target_imsize[0] / width
99
- box_mask = (box_mask * resize_factor).astype(int)
100
- mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
101
- crop_box(mask, box_mask).fill_(0)
102
- return dict(
103
- img=im[None], mask=mask[None],
104
- boxes=torch.from_numpy(expanded_boxes).view(1, -1))
105
-
106
- def __len__(self):
107
- return len(self.boxes)
108
-
109
- @staticmethod
110
- def from_state_dict(state_dict, **kwargs):
111
- return FaceDetection(state_dict["boxes"].cpu(), **kwargs)
112
-
113
- def state_dict(self, **kwargs):
114
- return dict(boxes=self.boxes, cls=self.__class__)
115
-
116
- def pre_process(self):
117
- pass
118
-
119
-
120
- def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
121
- """
122
- Dilation happens after padding, which could place dilation in the padded area.
123
- Remove this.
124
- """
125
- x0, y0, x1, y1 = exp_box
126
- H, W = orig_imshape
127
- # Padding in original image space
128
- p_y0 = max(0, -y0)
129
- p_y1 = max(y1 - H, 0)
130
- p_x0 = max(0, -x0)
131
- p_x1 = max(x1 - W, 0)
132
- resize_ratio = mask.shape[-2] / (y1-y0)
133
- p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
134
- mask[..., :p_y0, :] = 0
135
- mask[..., :p_x0] = 0
136
- mask[..., mask.shape[-2] - p_y1:, :] = 0
137
- mask[..., mask.shape[-1] - p_x1:] = 0
138
-
139
-
140
- class CSEPersonDetection:
141
-
142
- def __init__(self,
143
- segmentation, cse_dets,
144
- target_imsize,
145
- exp_bbox_cfg, exp_bbox_filter,
146
- dilation_percentage: float,
147
- embed_map: torch.Tensor,
148
- orig_imshape_CHW,
149
- normalize_embedding: bool) -> None:
150
- self.segmentation = segmentation
151
- self.cse_dets = cse_dets
152
- self.target_imsize = list(target_imsize)
153
- self.pre_processed = False
154
- self.exp_bbox_cfg = exp_bbox_cfg
155
- self.exp_bbox_filter = exp_bbox_filter
156
- self.dilation_percentage = dilation_percentage
157
- self.embed_map = embed_map
158
- self.normalize_embedding = normalize_embedding
159
- if self.normalize_embedding:
160
- embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
161
- embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
162
- self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
163
- self.orig_imshape_CHW = orig_imshape_CHW
164
-
165
- @torch.no_grad()
166
- def pre_process(self):
167
- if self.pre_processed:
168
- return
169
- boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
170
- expanded_boxes = []
171
- included_boxes = []
172
- for i in range(len(boxes)):
173
- exp_box = get_expanded_bbox(
174
- boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
175
- target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
176
- if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
177
- continue
178
- included_boxes.append(i)
179
- expanded_boxes.append(exp_box)
180
- expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
181
- self.segmentation = self.segmentation[included_boxes]
182
- self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}
183
-
184
- self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
185
- area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
186
- for i, box in enumerate(expanded_boxes):
187
- self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
188
-
189
- dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
190
- self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
191
- self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
192
- [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
193
- self.boxes = expanded_boxes.cpu()
194
- self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
195
-
196
- self.pre_processed = True
197
- self.n_detections = len(self.boxes)
198
- self.mask = self.mask.logical_not()
199
-
200
- E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
201
- self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
202
- for i in range(self.n_detections):
203
- E_, E_mask[i] = transform_embedding(
204
- self.cse_dets["instance_embedding"][i],
205
- self.cse_dets["instance_segmentation"][i],
206
- self.boxes[i],
207
- self.cse_dets["bbox_XYXY"][i].cpu(),
208
- self.target_imsize
209
- )
210
- self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
211
- self.E_mask = E_mask
212
-
213
- sorted_idx = torch.argsort(area, descending=False)
214
- self.mask = self.mask[sorted_idx]
215
- self.boxes = self.boxes[sorted_idx.cpu()]
216
- self.vertices = self.vertices[sorted_idx]
217
- self.E_mask = self.E_mask[sorted_idx]
218
- self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
219
-
220
- def get_crop(self, idx: int, im):
221
- self.pre_process()
222
- assert idx < len(self)
223
- box = self.boxes[idx]
224
- mask = self.mask[idx]
225
- im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
226
-
227
- vertices_ = self.vertices[idx]
228
- E_mask_ = self.E_mask[idx].float()
229
- if self.normalize_embedding:
230
- embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
231
- else:
232
- embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
233
-
234
- return dict(
235
- img=im,
236
- mask=mask.float()[None],
237
- boxes=box.reshape(1, -1),
238
- E_mask=E_mask_[None],
239
- vertices=vertices_[None],
240
- embed_map=self.embed_map,
241
- embedding=embedding[None],
242
- maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
243
- )
244
-
245
- def __len__(self):
246
- self.pre_process()
247
- return self.n_detections
248
-
249
- def state_dict(self, after_preprocess=False):
250
- """
251
- The processed annotations occupy more space than the original detections.
252
- """
253
- if not after_preprocess:
254
- return {
255
- "combined_segmentation": self.segmentation.bool(),
256
- "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
257
- "cse_instance_embedding": self.cse_dets["instance_embedding"],
258
- "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
259
- "cls": self.__class__,
260
- "orig_imshape_CHW": self.orig_imshape_CHW
261
- }
262
- self.pre_process()
263
- return dict(
264
- E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())),
265
- mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())),
266
- maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())),
267
- vertices=self.vertices.to(torch.int16).cpu(),
268
- cls=self.__class__,
269
- boxes=self.boxes,
270
- orig_imshape_CHW=self.orig_imshape_CHW,
271
- )
272
-
273
- @staticmethod
274
- def from_state_dict(
275
- state_dict, embed_map,
276
- post_process_cfg, **kwargs):
277
- after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
278
- if after_preprocess:
279
- detection = CSEPersonDetection(
280
- segmentation=None, cse_dets=None, embed_map=embed_map,
281
- orig_imshape_CHW=state_dict["orig_imshape_CHW"],
282
- **post_process_cfg)
283
- detection.vertices = tops.to_cuda(state_dict["vertices"].long())
284
- numel = np.prod(detection.vertices.shape)
285
- detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
286
- detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape)
287
- detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
288
- detection.n_detections = len(detection.mask)
289
- detection.pre_processed = True
290
-
291
- if isinstance(state_dict["boxes"], np.ndarray):
292
- state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
293
- detection.boxes = state_dict["boxes"]
294
- return detection
295
-
296
- cse_dets = dict(
297
- instance_segmentation=state_dict["cse_instance_segmentation"],
298
- instance_embedding=state_dict["cse_instance_embedding"],
299
- embed_map=embed_map,
300
- bbox_XYXY=state_dict["cse_bbox_XYXY"])
301
- cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}
302
-
303
- segmentation = state_dict["combined_segmentation"]
304
- return CSEPersonDetection(
305
- segmentation, cse_dets, embed_map=embed_map,
306
- orig_imshape_CHW=state_dict["orig_imshape_CHW"],
307
- **post_process_cfg)
308
-
309
- def visualize(self, im):
310
- self.pre_process()
311
- if len(self) == 0:
312
- return im
313
- im = vis_utils.draw_cropped_masks(
314
- im.clone(), self.mask, self.boxes, visualize_instances=False)
315
- E = self.embed_map[self.vertices.long()].squeeze(1).permute(0,3, 1, 2)
316
- im = im.to(E.device)
317
- im = vis_utils.draw_cse_all(
318
- E, self.E_mask.squeeze(1).bool(), im,
319
- self.boxes, self.embed_map)
320
- im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
321
- return im
322
-
323
-
324
- def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
325
- keypoints = keypoints.clone()
326
- N = boxes.shape[0]
327
- tops.assert_shape(keypoints, (N, None, 3))
328
- tops.assert_shape(boxes, (N, 4))
329
- x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
330
-
331
- w = x1 - x0
332
- h = y1 - y0
333
- keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
334
- keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h
335
- check_outside = lambda x: (x < 0).logical_or(x > 1)
336
- is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
337
- keypoints[:, :, 2] = keypoints[:, :, 2] >= 0
338
- keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
339
- return keypoints
340
-
341
-
342
- class PersonDetection:
343
-
344
- def __init__(
345
- self,
346
- segmentation,
347
- target_imsize,
348
- exp_bbox_cfg, exp_bbox_filter,
349
- dilation_percentage: float,
350
- orig_imshape_CHW,
351
- keypoints=None,
352
- **kwargs) -> None:
353
- self.segmentation = segmentation
354
- self.target_imsize = list(target_imsize)
355
- self.pre_processed = False
356
- self.exp_bbox_cfg = exp_bbox_cfg
357
- self.exp_bbox_filter = exp_bbox_filter
358
- self.dilation_percentage = dilation_percentage
359
- self.orig_imshape_CHW = orig_imshape_CHW
360
- self.keypoints = keypoints
361
-
362
- @torch.no_grad()
363
- def pre_process(self):
364
- if self.pre_processed:
365
- return
366
- boxes = masks_to_boxes(self.segmentation).cpu()
367
- expanded_boxes = []
368
- included_boxes = []
369
- for i in range(len(boxes)):
370
- exp_box = get_expanded_bbox(
371
- boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
372
- target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
373
- if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
374
- continue
375
- included_boxes.append(i)
376
- expanded_boxes.append(exp_box)
377
- expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
378
- self.segmentation = self.segmentation[included_boxes]
379
- if self.keypoints is not None:
380
- self.keypoints = self.keypoints[included_boxes]
381
- area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
382
- self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
383
- for i, box in enumerate(expanded_boxes):
384
- self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
385
- if self.keypoints is not None:
386
- self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)
387
- dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
388
- self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
389
- self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
390
-
391
- [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
392
- self.boxes = expanded_boxes
393
- self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
394
-
395
- self.pre_processed = True
396
- self.n_detections = len(self.boxes)
397
- self.mask = self.mask.logical_not()
398
-
399
- sorted_idx = torch.argsort(area, descending=False)
400
- self.mask = self.mask[sorted_idx]
401
- self.boxes = self.boxes[sorted_idx.cpu()]
402
- self.segmentation = self.segmentation[sorted_idx]
403
- self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
404
- if self.keypoints is not None:
405
- self.keypoints = self.keypoints[sorted_idx]
406
-
407
- def get_crop(self, idx: int, im: torch.Tensor):
408
- assert idx < len(self)
409
- self.pre_process()
410
- box = self.boxes[idx]
411
- mask = self.mask[idx][None].float()
412
- im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
413
- batch = dict(
414
- img=im, mask=mask, boxes=box.reshape(1, -1),
415
- maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
416
- if self.keypoints is not None:
417
- batch["keypoints"] = self.keypoints[idx:idx+1]
418
- return batch
419
-
420
- def __len__(self):
421
- self.pre_process()
422
- return self.n_detections
423
-
424
- def state_dict(self, **kwargs):
425
- return dict(
426
- segmentation=self.segmentation.bool(),
427
- cls=self.__class__,
428
- orig_imshape_CHW=self.orig_imshape_CHW,
429
- keypoints=self.keypoints
430
- )
431
-
432
- @staticmethod
433
- def from_state_dict(
434
- state_dict,
435
- post_process_cfg, **kwargs):
436
- return PersonDetection(
437
- state_dict["segmentation"],
438
- orig_imshape_CHW=state_dict["orig_imshape_CHW"],
439
- **post_process_cfg,
440
- keypoints=state_dict["keypoints"])
441
-
442
- def visualize(self, im):
443
- self.pre_process()
444
- im = im.cpu()
445
- if len(self) == 0:
446
- return im
447
- im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False)
448
- im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
449
- return im
450
-
451
-
452
- def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
453
- """
454
- mask: resized mask
455
- """
456
- assert exp_bbox.shape[0] == mask.shape[0]
457
- boxes = masks_to_boxes(mask.squeeze(1)).cpu()
458
- H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
459
- boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
460
- boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
461
- boxes[:, [0, 2]] += exp_bbox[:, 0:1]
462
- boxes[:, [1, 3]] += exp_bbox[:, 1:2]
463
- return boxes
464
-
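For reference, the state_dict methods above shrink boolean masks before caching by bit-packing them with numpy, and from_state_dict reverses it using the stored shape. The same round trip in isolation, with a toy mask:

import numpy as np
import torch

segmentation = torch.rand(2, 5, 7) > 0.5                             # toy boolean masks
packed = torch.from_numpy(np.packbits(segmentation.cpu().numpy()))   # uint8, 8 mask pixels per byte
numel = int(np.prod(segmentation.shape))
restored = torch.from_numpy(np.unpackbits(packed.numpy(), count=numel))
restored = restored.view(segmentation.shape).bool()
assert torch.equal(restored, segmentation)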