gaur3009 committed on
Commit
de79343
·
verified ·
1 Parent(s): 984b1c3

Upload 44 files

Files changed (45)
  1. .gitattributes +25 -0
  2. cloth_segmentation/LICENSE +21 -0
  3. cloth_segmentation/README.md +56 -0
  4. cloth_segmentation/assets/000.png +3 -0
  5. cloth_segmentation/assets/001.png +3 -0
  6. cloth_segmentation/assets/002.png +3 -0
  7. cloth_segmentation/assets/003.png +3 -0
  8. cloth_segmentation/assets/004.png +3 -0
  9. cloth_segmentation/assets/005.png +3 -0
  10. cloth_segmentation/assets/006.png +3 -0
  11. cloth_segmentation/assets/007.png +3 -0
  12. cloth_segmentation/assets/008.png +3 -0
  13. cloth_segmentation/assets/009.png +3 -0
  14. cloth_segmentation/assets/010.png +3 -0
  15. cloth_segmentation/assets/011.png +3 -0
  16. cloth_segmentation/assets/012.png +3 -0
  17. cloth_segmentation/assets/013.png +3 -0
  18. cloth_segmentation/assets/014.png +3 -0
  19. cloth_segmentation/assets/015.png +3 -0
  20. cloth_segmentation/assets/016.png +3 -0
  21. cloth_segmentation/assets/017.png +3 -0
  22. cloth_segmentation/assets/018.png +3 -0
  23. cloth_segmentation/assets/019.png +3 -0
  24. cloth_segmentation/assets/020.png +3 -0
  25. cloth_segmentation/assets/021.png +3 -0
  26. cloth_segmentation/assets/022.png +3 -0
  27. cloth_segmentation/assets/023.png +3 -0
  28. cloth_segmentation/assets/024.png +3 -0
  29. cloth_segmentation/assets/label_descriptions.json +842 -0
  30. cloth_segmentation/data/aligned_dataset.py +169 -0
  31. cloth_segmentation/data/base_data_loader.py +10 -0
  32. cloth_segmentation/data/base_dataset.py +189 -0
  33. cloth_segmentation/data/custom_dataset_data_loader.py +71 -0
  34. cloth_segmentation/data/data_loader.py +7 -0
  35. cloth_segmentation/data/image_folder.py +81 -0
  36. cloth_segmentation/infer.py +86 -0
  37. cloth_segmentation/model_surgery.py +51 -0
  38. cloth_segmentation/networks/__init__.py +1 -0
  39. cloth_segmentation/networks/u2net.py +565 -0
  40. cloth_segmentation/options/base_options.py +38 -0
  41. cloth_segmentation/samples.md +33 -0
  42. cloth_segmentation/train.py +190 -0
  43. cloth_segmentation/utils/distributed.py +47 -0
  44. cloth_segmentation/utils/saving_utils.py +45 -0
  45. cloth_segmentation/utils/tensorboard_utils.py +54 -0
.gitattributes CHANGED
@@ -33,3 +33,28 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ cloth_segmentation/assets/000.png filter=lfs diff=lfs merge=lfs -text
37
+ cloth_segmentation/assets/001.png filter=lfs diff=lfs merge=lfs -text
38
+ cloth_segmentation/assets/002.png filter=lfs diff=lfs merge=lfs -text
39
+ cloth_segmentation/assets/003.png filter=lfs diff=lfs merge=lfs -text
40
+ cloth_segmentation/assets/004.png filter=lfs diff=lfs merge=lfs -text
41
+ cloth_segmentation/assets/005.png filter=lfs diff=lfs merge=lfs -text
42
+ cloth_segmentation/assets/006.png filter=lfs diff=lfs merge=lfs -text
43
+ cloth_segmentation/assets/007.png filter=lfs diff=lfs merge=lfs -text
44
+ cloth_segmentation/assets/008.png filter=lfs diff=lfs merge=lfs -text
45
+ cloth_segmentation/assets/009.png filter=lfs diff=lfs merge=lfs -text
46
+ cloth_segmentation/assets/010.png filter=lfs diff=lfs merge=lfs -text
47
+ cloth_segmentation/assets/011.png filter=lfs diff=lfs merge=lfs -text
48
+ cloth_segmentation/assets/012.png filter=lfs diff=lfs merge=lfs -text
49
+ cloth_segmentation/assets/013.png filter=lfs diff=lfs merge=lfs -text
50
+ cloth_segmentation/assets/014.png filter=lfs diff=lfs merge=lfs -text
51
+ cloth_segmentation/assets/015.png filter=lfs diff=lfs merge=lfs -text
52
+ cloth_segmentation/assets/016.png filter=lfs diff=lfs merge=lfs -text
53
+ cloth_segmentation/assets/017.png filter=lfs diff=lfs merge=lfs -text
54
+ cloth_segmentation/assets/018.png filter=lfs diff=lfs merge=lfs -text
55
+ cloth_segmentation/assets/019.png filter=lfs diff=lfs merge=lfs -text
56
+ cloth_segmentation/assets/020.png filter=lfs diff=lfs merge=lfs -text
57
+ cloth_segmentation/assets/021.png filter=lfs diff=lfs merge=lfs -text
58
+ cloth_segmentation/assets/022.png filter=lfs diff=lfs merge=lfs -text
59
+ cloth_segmentation/assets/023.png filter=lfs diff=lfs merge=lfs -text
60
+ cloth_segmentation/assets/024.png filter=lfs diff=lfs merge=lfs -text
cloth_segmentation/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Levin Dabhi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
cloth_segmentation/README.md ADDED
@@ -0,0 +1,56 @@
1
+ # Clothes Segmentation using U2NET #
2
+
3
+ ![Python 3.8](https://img.shields.io/badge/python-3.8-green.svg)
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
5
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EhEy3uQh-5oOSagUotVOJAf8m7Vqn0D6?usp=sharing)
6
+
7
+ This repo contains training code, inference code and a pre-trained model for clothes parsing from human portraits.<br>
8
+ Here clothes are parsed into 3 categories: Upper body (red), Lower body (green) and Full body (yellow).
9
+
10
+ ![Sample 000](assets/000.png)
11
+ ![Sample 024](assets/024.png)
12
+ ![Sample 018](assets/018.png)
13
+
14
+ This model works well with any background and almost all poses. For more samples, visit [samples.md](samples.md).
15
+
16
+ # Technical details
17
+
18
+ * **U2NET**: This project uses the excellent [U2NET](https://arxiv.org/abs/2005.09007) as its deep learning model. Instead of the single-channel output u2net produces for the typical salient object detection task, this model outputs 4 channels, each representing upper body cloth, lower body cloth, full body cloth and background. Only categorical cross-entropy loss is used for this version of the checkpoint.
19
+
20
+ * **Dataset**: U2net is trained on 45k images from the [iMaterialist (Fashion) 2019 at FGVC6](https://www.kaggle.com/c/imaterialist-fashion-2019-FGVC6/data) dataset. To reduce complexity, I have clubbed the original 42 categories from the dataset labels into 3 categories (upper body, lower body and full body). All images are resized into a square `¯\_(ツ)_/¯` 768 x 768 px for training. (This experiment was conducted at 768 px, but around 384 px will work fine too if one is retraining on another dataset.)
21
+
22
+ # Training
23
+
24
+ - For training, this project requires:
25
+ <ul>
26
+ <ul>
27
+ <li>&nbsp; PyTorch > 1.3.0</li>
28
+ <li>&nbsp; tensorboardX</li>
29
+ <li>&nbsp; gdown</li>
30
+ </ul>
31
+ </ul>
32
+
33
+ - Download dataset from this [link](https://www.kaggle.com/c/imaterialist-fashion-2019-FGVC6/data), extract all items.
34
+ - Set the path of the `train` folder (which contains the training images) and of `train.csv` (the label CSV file) in `options/base_options.py`.
35
+ - To port the original u2net weights for all layers except the last one, run `python setup_model_weights.py`; it will generate the weights after model surgery in the `prev_checkpoints` folder.
36
+ - You can explore various options in `options/base_options.py`, such as the checkpoint saving folder, the logs folder, etc.
37
+ - For a single GPU, set `distributed = False` in `options/base_options.py`; for multiple GPUs, set it to `True`.
38
+ - For a single GPU, run `python train.py`
39
+ - For multiple GPUs, run <br>
40
+ &nbsp;`python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=4 --use_env train.py` <br>
41
+ The above command is for a single node with 4 GPUs. Tested only on a single node.
42
+ - You can watch loss graphs and samples in TensorBoard by running the `tensorboard` command in the logs folder.
43
+
44
+
45
+ # Testing/Inference
46
+ - Download the pretrained model from this [link](https://drive.google.com/file/d/1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ/view?usp=sharing) (165 MB) into the `trained_checkpoint` folder.
47
+ - Put input images in the `input_images` folder.
48
+ - Run `python infer.py` for inference.
49
+ - Output will be saved in the `output_images` folder.
50
+ ### OR
51
+ - Run inference in Colab from here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EhEy3uQh-5oOSagUotVOJAf8m7Vqn0D6?usp=sharing)
52
+
53
+ # Acknowledgements
54
+ - The U2net model is from the original [u2net repo](https://github.com/xuebinqin/U-2-Net). Thanks to Xuebin Qin for the amazing repo.
55
+ - The complete repo follows the structure of the [Pix2pixHD repo](https://github.com/NVIDIA/pix2pixHD).
56
+
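For reference, the 42 → 3 category clubbing described in the README boils down to the grouping below. This is an illustrative sketch (not part of the uploaded files), consistent with the `upperbody`/`lowerbody`/`wholebody` lists in `data/aligned_dataset.py` and the category ids in `assets/label_descriptions.json`; the helper name is hypothetical.

```python
# 0 = background, 1 = upper body, 2 = lower body, 3 = full body
UPPER_BODY = {0, 1, 2, 3, 4, 5}   # shirt/blouse, top/t-shirt, sweater, cardigan, jacket, vest
LOWER_BODY = {6, 7, 8}            # pants, shorts, skirt
FULL_BODY = {9, 10, 11, 12}       # coat, dress, jumpsuit, cape

def club_category(category_id: int) -> int:
    """Map an iMaterialist category id to the 3-class training label."""
    if category_id in UPPER_BODY:
        return 1
    if category_id in LOWER_BODY:
        return 2
    if category_id in FULL_BODY:
        return 3
    return 0  # accessories, garment parts, decorations, etc. fall back to background
```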
cloth_segmentation/assets/000.png ADDED

Git LFS Details

  • SHA256: 95dc35a24a82c5ba65cec55a3fab818e07a376fc422d8221e3ac2e37d2ba1131
  • Pointer size: 131 Bytes
  • Size of remote file: 569 kB
cloth_segmentation/assets/001.png ADDED

Git LFS Details

  • SHA256: 43d04ddeeb7b1f259d2ffa381f6dbffafb2abc77d4cb24d4634c232ce263d17b
  • Pointer size: 131 Bytes
  • Size of remote file: 516 kB
cloth_segmentation/assets/002.png ADDED

Git LFS Details

  • SHA256: fd999d19e2f3b1e77a6231926fb7ff12bbe88695eda0aef13db1a10422105c24
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
cloth_segmentation/assets/003.png ADDED

Git LFS Details

  • SHA256: 864737c941d7d7a79d5fd25be988d85b7808cd27ed643237ce0ca0078a95b6b4
  • Pointer size: 131 Bytes
  • Size of remote file: 807 kB
cloth_segmentation/assets/004.png ADDED

Git LFS Details

  • SHA256: 26bb66a599aa290c7b4be212c542366f3a7a621ffb23624a7aeaf1f86837d9c7
  • Pointer size: 131 Bytes
  • Size of remote file: 800 kB
cloth_segmentation/assets/005.png ADDED

Git LFS Details

  • SHA256: 7698549ea369352e35918dcd5be747446a4b5dac60ac9859aadb820bc4488d7c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
cloth_segmentation/assets/006.png ADDED

Git LFS Details

  • SHA256: 35bad76aef2023866ac5347d37e6af643625c7fc0d51e20383b528411ca93c40
  • Pointer size: 131 Bytes
  • Size of remote file: 769 kB
cloth_segmentation/assets/007.png ADDED

Git LFS Details

  • SHA256: 722e58b747faec25a75bb82ca210a797d03073edba23b1842663c85cc19db49b
  • Pointer size: 131 Bytes
  • Size of remote file: 945 kB
cloth_segmentation/assets/008.png ADDED

Git LFS Details

  • SHA256: 351fca02f5fc888300f3de8c0c1876f223ab6542399904184552c95be565893f
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB
cloth_segmentation/assets/009.png ADDED

Git LFS Details

  • SHA256: b6395e2c6945e56f1c1adcdd61f63dfc901f10c46b11f399cc3f7d011264e0fc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
cloth_segmentation/assets/010.png ADDED

Git LFS Details

  • SHA256: 726f073ad20eee924d7e32264055a395d6741776108dd9a20462b0f83aed55e9
  • Pointer size: 131 Bytes
  • Size of remote file: 611 kB
cloth_segmentation/assets/011.png ADDED

Git LFS Details

  • SHA256: d526026686794bff283c75a8fcd84a557553b3dd2fe6eb905b9ffd1f4b62cf9d
  • Pointer size: 131 Bytes
  • Size of remote file: 476 kB
cloth_segmentation/assets/012.png ADDED

Git LFS Details

  • SHA256: 5bdadd9fb0fdf4e753170a42f6926fb4459a98d91a63c395f9c65513eba5beef
  • Pointer size: 131 Bytes
  • Size of remote file: 782 kB
cloth_segmentation/assets/013.png ADDED

Git LFS Details

  • SHA256: f871a7ab042fdac3aacea7a2f0c0ffdfbb482e0d55b4e9613f30e3a2485d61f3
  • Pointer size: 131 Bytes
  • Size of remote file: 723 kB
cloth_segmentation/assets/014.png ADDED

Git LFS Details

  • SHA256: 289aa7c5b407effbdeb9ac1880a7fbc37e3cbf584f368bcc3173ee69763d15a6
  • Pointer size: 131 Bytes
  • Size of remote file: 515 kB
cloth_segmentation/assets/015.png ADDED

Git LFS Details

  • SHA256: 01183092fa367f89464ef3099cdb38a9373351a239fb1ba38ae3c1e1401e1901
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
cloth_segmentation/assets/016.png ADDED

Git LFS Details

  • SHA256: f4e2641b803c1f5c7caffec847dc09df1ebd0fcd2cd83d337512701f51fa41b9
  • Pointer size: 131 Bytes
  • Size of remote file: 975 kB
cloth_segmentation/assets/017.png ADDED

Git LFS Details

  • SHA256: a259e6289542ccc821de0180a3f9179586bb70da2522b0fdbeaf394289523ca8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
cloth_segmentation/assets/018.png ADDED

Git LFS Details

  • SHA256: e30fb6c6a85f867dd2ed6efb6927352237a8718401ceede41ed51f28ec122d19
  • Pointer size: 131 Bytes
  • Size of remote file: 654 kB
cloth_segmentation/assets/019.png ADDED

Git LFS Details

  • SHA256: 5cdb54ef58ede969193901fbbcd9b2b9c09cd51238f613acb95633e1723ccd6c
  • Pointer size: 131 Bytes
  • Size of remote file: 542 kB
cloth_segmentation/assets/020.png ADDED

Git LFS Details

  • SHA256: 59ee622dfe7aaf0e7ecc386ee02714e1e4384df401719e9f9c4d168e877728e4
  • Pointer size: 131 Bytes
  • Size of remote file: 601 kB
cloth_segmentation/assets/021.png ADDED

Git LFS Details

  • SHA256: 2a91ded9738b6b4dd3b3904150286674697a2cc8b2962998615c3c61217739b4
  • Pointer size: 131 Bytes
  • Size of remote file: 469 kB
cloth_segmentation/assets/022.png ADDED

Git LFS Details

  • SHA256: 9672b17bc62ab1c98fe406e3d0ea86e86a11d9215d32d13ca20809a250b19498
  • Pointer size: 131 Bytes
  • Size of remote file: 671 kB
cloth_segmentation/assets/023.png ADDED

Git LFS Details

  • SHA256: c9407451886d5930108db1ffe0c6206b9f0c40111878ca1035f9305511a21f44
  • Pointer size: 131 Bytes
  • Size of remote file: 460 kB
cloth_segmentation/assets/024.png ADDED

Git LFS Details

  • SHA256: 31323f79974dcd19c6f894d56595c2ba41a0c596bdb57392ddb6141e0d0115a8
  • Pointer size: 131 Bytes
  • Size of remote file: 942 kB
cloth_segmentation/assets/label_descriptions.json ADDED
@@ -0,0 +1,842 @@
1
+ {
2
+ "info": {
3
+ "year": 2019,
4
+ "version": "1.0",
5
+ "description": "The 2019 FGVC^6 iMaterialist Competition - Fashion track dataset.",
6
+ "contributor": "iMaterialist Fashion Competition group",
7
+ "url": "https://github.com/visipedia/imat_comp",
8
+ "date_created": "2019-04-19 12:38:27.493919"
9
+ },
10
+ "categories": [
11
+ {
12
+ "id": 0,
13
+ "name": "shirt, blouse",
14
+ "supercategory": "upperbody",
15
+ "level": 2
16
+ },
17
+ {
18
+ "id": 1,
19
+ "name": "top, t-shirt, sweatshirt",
20
+ "supercategory": "upperbody",
21
+ "level": 2
22
+ },
23
+ {
24
+ "id": 2,
25
+ "name": "sweater",
26
+ "supercategory": "upperbody",
27
+ "level": 2
28
+ },
29
+ {
30
+ "id": 3,
31
+ "name": "cardigan",
32
+ "supercategory": "upperbody",
33
+ "level": 2
34
+ },
35
+ {
36
+ "id": 4,
37
+ "name": "jacket",
38
+ "supercategory": "upperbody",
39
+ "level": 2
40
+ },
41
+ {
42
+ "id": 5,
43
+ "name": "vest",
44
+ "supercategory": "upperbody",
45
+ "level": 2
46
+ },
47
+ {
48
+ "id": 6,
49
+ "name": "pants",
50
+ "supercategory": "lowerbody",
51
+ "level": 2
52
+ },
53
+ {
54
+ "id": 7,
55
+ "name": "shorts",
56
+ "supercategory": "lowerbody",
57
+ "level": 2
58
+ },
59
+ {
60
+ "id": 8,
61
+ "name": "skirt",
62
+ "supercategory": "lowerbody",
63
+ "level": 2
64
+ },
65
+ {
66
+ "id": 9,
67
+ "name": "coat",
68
+ "supercategory": "wholebody",
69
+ "level": 2
70
+ },
71
+ {
72
+ "id": 10,
73
+ "name": "dress",
74
+ "supercategory": "wholebody",
75
+ "level": 2
76
+ },
77
+ {
78
+ "id": 11,
79
+ "name": "jumpsuit",
80
+ "supercategory": "wholebody",
81
+ "level": 2
82
+ },
83
+ {
84
+ "id": 12,
85
+ "name": "cape",
86
+ "supercategory": "wholebody",
87
+ "level": 2
88
+ },
89
+ {
90
+ "id": 13,
91
+ "name": "glasses",
92
+ "supercategory": "head",
93
+ "level": 2
94
+ },
95
+ {
96
+ "id": 14,
97
+ "name": "hat",
98
+ "supercategory": "head",
99
+ "level": 2
100
+ },
101
+ {
102
+ "id": 15,
103
+ "name": "headband, head covering, hair accessory",
104
+ "supercategory": "head",
105
+ "level": 2
106
+ },
107
+ {
108
+ "id": 16,
109
+ "name": "tie",
110
+ "supercategory": "neck",
111
+ "level": 2
112
+ },
113
+ {
114
+ "id": 17,
115
+ "name": "glove",
116
+ "supercategory": "arms and hands",
117
+ "level": 2
118
+ },
119
+ {
120
+ "id": 18,
121
+ "name": "watch",
122
+ "supercategory": "arms and hands",
123
+ "level": 2
124
+ },
125
+ {
126
+ "id": 19,
127
+ "name": "belt",
128
+ "supercategory": "waist",
129
+ "level": 2
130
+ },
131
+ {
132
+ "id": 20,
133
+ "name": "leg warmer",
134
+ "supercategory": "legs and feet",
135
+ "level": 2
136
+ },
137
+ {
138
+ "id": 21,
139
+ "name": "tights, stockings",
140
+ "supercategory": "legs and feet",
141
+ "level": 2
142
+ },
143
+ {
144
+ "id": 22,
145
+ "name": "sock",
146
+ "supercategory": "legs and feet",
147
+ "level": 2
148
+ },
149
+ {
150
+ "id": 23,
151
+ "name": "shoe",
152
+ "supercategory": "legs and feet",
153
+ "level": 2
154
+ },
155
+ {
156
+ "id": 24,
157
+ "name": "bag, wallet",
158
+ "supercategory": "others",
159
+ "level": 2
160
+ },
161
+ {
162
+ "id": 25,
163
+ "name": "scarf",
164
+ "supercategory": "others",
165
+ "level": 2
166
+ },
167
+ {
168
+ "id": 26,
169
+ "name": "umbrella",
170
+ "supercategory": "others",
171
+ "level": 2
172
+ },
173
+ {
174
+ "id": 27,
175
+ "name": "hood",
176
+ "supercategory": "garment parts",
177
+ "level": 2
178
+ },
179
+ {
180
+ "id": 28,
181
+ "name": "collar",
182
+ "supercategory": "garment parts",
183
+ "level": 2
184
+ },
185
+ {
186
+ "id": 29,
187
+ "name": "lapel",
188
+ "supercategory": "garment parts",
189
+ "level": 2
190
+ },
191
+ {
192
+ "id": 30,
193
+ "name": "epaulette",
194
+ "supercategory": "garment parts",
195
+ "level": 2
196
+ },
197
+ {
198
+ "id": 31,
199
+ "name": "sleeve",
200
+ "supercategory": "garment parts",
201
+ "level": 2
202
+ },
203
+ {
204
+ "id": 32,
205
+ "name": "pocket",
206
+ "supercategory": "garment parts",
207
+ "level": 2
208
+ },
209
+ {
210
+ "id": 33,
211
+ "name": "neckline",
212
+ "supercategory": "garment parts",
213
+ "level": 2
214
+ },
215
+ {
216
+ "id": 34,
217
+ "name": "buckle",
218
+ "supercategory": "closures",
219
+ "level": 2
220
+ },
221
+ {
222
+ "id": 35,
223
+ "name": "zipper",
224
+ "supercategory": "closures",
225
+ "level": 2
226
+ },
227
+ {
228
+ "id": 36,
229
+ "name": "applique",
230
+ "supercategory": "decorations",
231
+ "level": 2
232
+ },
233
+ {
234
+ "id": 37,
235
+ "name": "bead",
236
+ "supercategory": "decorations",
237
+ "level": 2
238
+ },
239
+ {
240
+ "id": 38,
241
+ "name": "bow",
242
+ "supercategory": "decorations",
243
+ "level": 2
244
+ },
245
+ {
246
+ "id": 39,
247
+ "name": "flower",
248
+ "supercategory": "decorations",
249
+ "level": 2
250
+ },
251
+ {
252
+ "id": 40,
253
+ "name": "fringe",
254
+ "supercategory": "decorations",
255
+ "level": 2
256
+ },
257
+ {
258
+ "id": 41,
259
+ "name": "ribbon",
260
+ "supercategory": "decorations",
261
+ "level": 2
262
+ },
263
+ {
264
+ "id": 42,
265
+ "name": "rivet",
266
+ "supercategory": "decorations",
267
+ "level": 2
268
+ },
269
+ {
270
+ "id": 43,
271
+ "name": "ruffle",
272
+ "supercategory": "decorations",
273
+ "level": 2
274
+ },
275
+ {
276
+ "id": 44,
277
+ "name": "sequin",
278
+ "supercategory": "decorations",
279
+ "level": 2
280
+ },
281
+ {
282
+ "id": 45,
283
+ "name": "tassel",
284
+ "supercategory": "decorations",
285
+ "level": 2
286
+ }
287
+ ],
288
+ "attributes": [
289
+ {
290
+ "id": 0,
291
+ "name": "above-the-hip (length)",
292
+ "supercategory": "length",
293
+ "level": 1
294
+ },
295
+ {
296
+ "id": 1,
297
+ "name": "hip (length)",
298
+ "supercategory": "length",
299
+ "level": 1
300
+ },
301
+ {
302
+ "id": 2,
303
+ "name": "micro (length)",
304
+ "supercategory": "length",
305
+ "level": 1
306
+ },
307
+ {
308
+ "id": 3,
309
+ "name": "mini (length)",
310
+ "supercategory": "length",
311
+ "level": 1
312
+ },
313
+ {
314
+ "id": 4,
315
+ "name": "above-the-knee (length)",
316
+ "supercategory": "length",
317
+ "level": 1
318
+ },
319
+ {
320
+ "id": 5,
321
+ "name": "knee (length)",
322
+ "supercategory": "length",
323
+ "level": 1
324
+ },
325
+ {
326
+ "id": 6,
327
+ "name": "below the knee (length)",
328
+ "supercategory": "length",
329
+ "level": 1
330
+ },
331
+ {
332
+ "id": 7,
333
+ "name": "midi",
334
+ "supercategory": "length",
335
+ "level": 1
336
+ },
337
+ {
338
+ "id": 8,
339
+ "name": "maxi (length)",
340
+ "supercategory": "length",
341
+ "level": 1
342
+ },
343
+ {
344
+ "id": 9,
345
+ "name": "floor (length)",
346
+ "supercategory": "length",
347
+ "level": 1
348
+ },
349
+ {
350
+ "id": 10,
351
+ "name": "single breasted",
352
+ "supercategory": "opening type",
353
+ "level": 1
354
+ },
355
+ {
356
+ "id": 11,
357
+ "name": "double breasted",
358
+ "supercategory": "opening type",
359
+ "level": 1
360
+ },
361
+ {
362
+ "id": 12,
363
+ "name": "lace up",
364
+ "supercategory": "opening type",
365
+ "level": 1
366
+ },
367
+ {
368
+ "id": 13,
369
+ "name": "wrapping",
370
+ "supercategory": "opening type",
371
+ "level": 1
372
+ },
373
+ {
374
+ "id": 14,
375
+ "name": "zip-up",
376
+ "supercategory": "opening type",
377
+ "level": 1
378
+ },
379
+ {
380
+ "id": 15,
381
+ "name": "fly (opening)",
382
+ "supercategory": "opening type",
383
+ "level": 1
384
+ },
385
+ {
386
+ "id": 16,
387
+ "name": "buckled (opening)",
388
+ "supercategory": "opening type",
389
+ "level": 1
390
+ },
391
+ {
392
+ "id": 17,
393
+ "name": "toggled (opening)",
394
+ "supercategory": "opening type",
395
+ "level": 1
396
+ },
397
+ {
398
+ "id": 18,
399
+ "name": "no opening",
400
+ "supercategory": "opening type",
401
+ "level": 1
402
+ },
403
+ {
404
+ "id": 19,
405
+ "name": "asymmetrical",
406
+ "supercategory": "silhouette",
407
+ "level": 1
408
+ },
409
+ {
410
+ "id": 20,
411
+ "name": "symmetrical",
412
+ "supercategory": "silhouette",
413
+ "level": 1
414
+ },
415
+ {
416
+ "id": 21,
417
+ "name": "peplum",
418
+ "supercategory": "silhouette",
419
+ "level": 1
420
+ },
421
+ {
422
+ "id": 22,
423
+ "name": "circle",
424
+ "supercategory": "silhouette",
425
+ "level": 1
426
+ },
427
+ {
428
+ "id": 23,
429
+ "name": "flare",
430
+ "supercategory": "silhouette",
431
+ "level": 1
432
+ },
433
+ {
434
+ "id": 24,
435
+ "name": "fit and flare",
436
+ "supercategory": "silhouette",
437
+ "level": 1
438
+ },
439
+ {
440
+ "id": 25,
441
+ "name": "trumpet",
442
+ "supercategory": "silhouette",
443
+ "level": 1
444
+ },
445
+ {
446
+ "id": 26,
447
+ "name": "mermaid",
448
+ "supercategory": "silhouette",
449
+ "level": 1
450
+ },
451
+ {
452
+ "id": 27,
453
+ "name": "balloon",
454
+ "supercategory": "silhouette",
455
+ "level": 1
456
+ },
457
+ {
458
+ "id": 28,
459
+ "name": "bell",
460
+ "supercategory": "silhouette",
461
+ "level": 1
462
+ },
463
+ {
464
+ "id": 29,
465
+ "name": "bell bottom",
466
+ "supercategory": "silhouette",
467
+ "level": 1
468
+ },
469
+ {
470
+ "id": 30,
471
+ "name": "bootcut",
472
+ "supercategory": "silhouette",
473
+ "level": 1
474
+ },
475
+ {
476
+ "id": 31,
477
+ "name": "peg",
478
+ "supercategory": "silhouette",
479
+ "level": 1
480
+ },
481
+ {
482
+ "id": 32,
483
+ "name": "pencil",
484
+ "supercategory": "silhouette",
485
+ "level": 1
486
+ },
487
+ {
488
+ "id": 33,
489
+ "name": "straight",
490
+ "supercategory": "silhouette",
491
+ "level": 1
492
+ },
493
+ {
494
+ "id": 34,
495
+ "name": "a-line",
496
+ "supercategory": "silhouette",
497
+ "level": 1
498
+ },
499
+ {
500
+ "id": 35,
501
+ "name": "tent",
502
+ "supercategory": "silhouette",
503
+ "level": 1
504
+ },
505
+ {
506
+ "id": 36,
507
+ "name": "baggy",
508
+ "supercategory": "silhouette",
509
+ "level": 1
510
+ },
511
+ {
512
+ "id": 37,
513
+ "name": "wide leg",
514
+ "supercategory": "silhouette",
515
+ "level": 1
516
+ },
517
+ {
518
+ "id": 38,
519
+ "name": "high low",
520
+ "supercategory": "silhouette",
521
+ "level": 1
522
+ },
523
+ {
524
+ "id": 39,
525
+ "name": "curved (fit)",
526
+ "supercategory": "silhouette",
527
+ "level": 1
528
+ },
529
+ {
530
+ "id": 40,
531
+ "name": "tight (fit)",
532
+ "supercategory": "silhouette",
533
+ "level": 1
534
+ },
535
+ {
536
+ "id": 41,
537
+ "name": "regular (fit)",
538
+ "supercategory": "silhouette",
539
+ "level": 1
540
+ },
541
+ {
542
+ "id": 42,
543
+ "name": "loose (fit)",
544
+ "supercategory": "silhouette",
545
+ "level": 1
546
+ },
547
+ {
548
+ "id": 43,
549
+ "name": "oversized",
550
+ "supercategory": "silhouette",
551
+ "level": 1
552
+ },
553
+ {
554
+ "id": 44,
555
+ "name": "burnout",
556
+ "supercategory": "textile finishing, manufacturing techniques",
557
+ "level": 1
558
+ },
559
+ {
560
+ "id": 45,
561
+ "name": "distressed",
562
+ "supercategory": "textile finishing, manufacturing techniques",
563
+ "level": 1
564
+ },
565
+ {
566
+ "id": 46,
567
+ "name": "washed",
568
+ "supercategory": "textile finishing, manufacturing techniques",
569
+ "level": 1
570
+ },
571
+ {
572
+ "id": 47,
573
+ "name": "embossed",
574
+ "supercategory": "textile finishing, manufacturing techniques",
575
+ "level": 1
576
+ },
577
+ {
578
+ "id": 48,
579
+ "name": "frayed",
580
+ "supercategory": "textile finishing, manufacturing techniques",
581
+ "level": 1
582
+ },
583
+ {
584
+ "id": 49,
585
+ "name": "printed",
586
+ "supercategory": "textile finishing, manufacturing techniques",
587
+ "level": 1
588
+ },
589
+ {
590
+ "id": 50,
591
+ "name": "ruched",
592
+ "supercategory": "textile finishing, manufacturing techniques",
593
+ "level": 1
594
+ },
595
+ {
596
+ "id": 51,
597
+ "name": "quilted",
598
+ "supercategory": "textile finishing, manufacturing techniques",
599
+ "level": 1
600
+ },
601
+ {
602
+ "id": 52,
603
+ "name": "pleat",
604
+ "supercategory": "textile finishing, manufacturing techniques",
605
+ "level": 1
606
+ },
607
+ {
608
+ "id": 53,
609
+ "name": "gathering",
610
+ "supercategory": "textile finishing, manufacturing techniques",
611
+ "level": 1
612
+ },
613
+ {
614
+ "id": 54,
615
+ "name": "smocking",
616
+ "supercategory": "textile finishing, manufacturing techniques",
617
+ "level": 1
618
+ },
619
+ {
620
+ "id": 55,
621
+ "name": "tiered",
622
+ "supercategory": "textile finishing, manufacturing techniques",
623
+ "level": 1
624
+ },
625
+ {
626
+ "id": 56,
627
+ "name": "cutout",
628
+ "supercategory": "textile finishing, manufacturing techniques",
629
+ "level": 1
630
+ },
631
+ {
632
+ "id": 57,
633
+ "name": "slit",
634
+ "supercategory": "textile finishing, manufacturing techniques",
635
+ "level": 1
636
+ },
637
+ {
638
+ "id": 58,
639
+ "name": "perforated",
640
+ "supercategory": "textile finishing, manufacturing techniques",
641
+ "level": 1
642
+ },
643
+ {
644
+ "id": 59,
645
+ "name": "lining",
646
+ "supercategory": "textile finishing, manufacturing techniques",
647
+ "level": 1
648
+ },
649
+ {
650
+ "id": 60,
651
+ "name": "no special manufacturing technique",
652
+ "supercategory": "textile finishing, manufacturing techniques",
653
+ "level": 1
654
+ },
655
+ {
656
+ "id": 61,
657
+ "name": "plain (pattern)",
658
+ "supercategory": "textile pattern",
659
+ "level": 1
660
+ },
661
+ {
662
+ "id": 62,
663
+ "name": "abstract",
664
+ "supercategory": "textile pattern",
665
+ "level": 1
666
+ },
667
+ {
668
+ "id": 63,
669
+ "name": "cartoon",
670
+ "supercategory": "textile pattern",
671
+ "level": 1
672
+ },
673
+ {
674
+ "id": 64,
675
+ "name": "letters, numbers",
676
+ "supercategory": "textile pattern",
677
+ "level": 1
678
+ },
679
+ {
680
+ "id": 65,
681
+ "name": "camouflage",
682
+ "supercategory": "textile pattern",
683
+ "level": 1
684
+ },
685
+ {
686
+ "id": 66,
687
+ "name": "check",
688
+ "supercategory": "textile pattern",
689
+ "level": 1
690
+ },
691
+ {
692
+ "id": 67,
693
+ "name": "dot",
694
+ "supercategory": "textile pattern",
695
+ "level": 1
696
+ },
697
+ {
698
+ "id": 68,
699
+ "name": "fair isle",
700
+ "supercategory": "textile pattern",
701
+ "level": 1
702
+ },
703
+ {
704
+ "id": 69,
705
+ "name": "floral",
706
+ "supercategory": "textile pattern",
707
+ "level": 1
708
+ },
709
+ {
710
+ "id": 70,
711
+ "name": "geometric",
712
+ "supercategory": "textile pattern",
713
+ "level": 1
714
+ },
715
+ {
716
+ "id": 71,
717
+ "name": "paisley",
718
+ "supercategory": "textile pattern",
719
+ "level": 1
720
+ },
721
+ {
722
+ "id": 72,
723
+ "name": "stripe",
724
+ "supercategory": "textile pattern",
725
+ "level": 1
726
+ },
727
+ {
728
+ "id": 73,
729
+ "name": "houndstooth (pattern)",
730
+ "supercategory": "textile pattern",
731
+ "level": 1
732
+ },
733
+ {
734
+ "id": 74,
735
+ "name": "herringbone (pattern)",
736
+ "supercategory": "textile pattern",
737
+ "level": 1
738
+ },
739
+ {
740
+ "id": 75,
741
+ "name": "chevron",
742
+ "supercategory": "textile pattern",
743
+ "level": 1
744
+ },
745
+ {
746
+ "id": 76,
747
+ "name": "argyle",
748
+ "supercategory": "textile pattern",
749
+ "level": 1
750
+ },
751
+ {
752
+ "id": 77,
753
+ "name": "leopard",
754
+ "supercategory": "animal",
755
+ "level": 2
756
+ },
757
+ {
758
+ "id": 78,
759
+ "name": "snakeskin (pattern)",
760
+ "supercategory": "animal",
761
+ "level": 2
762
+ },
763
+ {
764
+ "id": 79,
765
+ "name": "cheetah",
766
+ "supercategory": "animal",
767
+ "level": 2
768
+ },
769
+ {
770
+ "id": 80,
771
+ "name": "peacock",
772
+ "supercategory": "animal",
773
+ "level": 2
774
+ },
775
+ {
776
+ "id": 81,
777
+ "name": "zebra",
778
+ "supercategory": "animal",
779
+ "level": 2
780
+ },
781
+ {
782
+ "id": 82,
783
+ "name": "giraffe",
784
+ "supercategory": "animal",
785
+ "level": 2
786
+ },
787
+ {
788
+ "id": 83,
789
+ "name": "toile de jouy",
790
+ "supercategory": "textile pattern",
791
+ "level": 1
792
+ },
793
+ {
794
+ "id": 84,
795
+ "name": "plant",
796
+ "supercategory": "textile pattern",
797
+ "level": 1
798
+ },
799
+ {
800
+ "id": 85,
801
+ "name": "empire waistline",
802
+ "supercategory": "waistline",
803
+ "level": 1
804
+ },
805
+ {
806
+ "id": 86,
807
+ "name": "dropped waistline",
808
+ "supercategory": "waistline",
809
+ "level": 1
810
+ },
811
+ {
812
+ "id": 87,
813
+ "name": "high waist",
814
+ "supercategory": "waistline",
815
+ "level": 1
816
+ },
817
+ {
818
+ "id": 88,
819
+ "name": "normal waist",
820
+ "supercategory": "waistline",
821
+ "level": 1
822
+ },
823
+ {
824
+ "id": 89,
825
+ "name": "low waist",
826
+ "supercategory": "waistline",
827
+ "level": 1
828
+ },
829
+ {
830
+ "id": 90,
831
+ "name": "basque (wasitline)",
832
+ "supercategory": "waistline",
833
+ "level": 1
834
+ },
835
+ {
836
+ "id": 91,
837
+ "name": "no waistline",
838
+ "supercategory": "waistline",
839
+ "level": 1
840
+ }
841
+ ]
842
+ }
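As an illustration (not part of the uploaded files), the JSON above can be inspected like this to see how many raw categories fall under each supercategory; only `upperbody` (6 entries), `lowerbody` (3) and `wholebody` (4) are used for training, the rest are ignored.

```python
import json
from collections import defaultdict

with open("cloth_segmentation/assets/label_descriptions.json") as f:
    label_desc = json.load(f)

# Group the 46 raw category names by their supercategory.
by_super = defaultdict(list)
for cat in label_desc["categories"]:
    by_super[cat["supercategory"]].append(cat["name"])

for supercategory, names in sorted(by_super.items()):
    print(f"{supercategory:15s} {len(names):2d}  {', '.join(names[:3])} ...")
```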
cloth_segmentation/data/aligned_dataset.py ADDED
@@ -0,0 +1,169 @@
1
+ from data.base_dataset import BaseDataset, Rescale_fixed, Normalize_image
2
+ from data.image_folder import make_dataset, make_dataset_test
3
+
4
+ import os
5
+ import cv2
6
+ import json
7
+ import itertools
8
+ import collections
9
+ from tqdm import tqdm
10
+
11
+ import pandas as pd
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ import torch
16
+ import torchvision.transforms as transforms
17
+
18
+
19
+ class AlignedDataset(BaseDataset):
20
+ def initialize(self, opt):
21
+ self.opt = opt
22
+ self.image_dir = opt.image_folder
23
+ self.df_path = opt.df_path
24
+ self.width = opt.fine_width
25
+ self.height = opt.fine_height
26
+
27
+ # for rgb imgs
28
+
29
+ transforms_list = []
30
+ transforms_list += [transforms.ToTensor()]
31
+ transforms_list += [Normalize_image(opt.mean, opt.std)]
32
+ self.transform_rgb = transforms.Compose(transforms_list)
33
+
34
+ self.df = pd.read_csv(self.df_path)
35
+ self.image_info = collections.defaultdict(dict)
36
+ self.df["CategoryId"] = self.df.ClassId.apply(lambda x: str(x).split("_")[0])
37
+ temp_df = (
38
+ self.df.groupby("ImageId")[["EncodedPixels", "CategoryId"]]
39
+ .agg(lambda x: list(x))
40
+ .reset_index()
41
+ )
42
+ size_df = self.df.groupby("ImageId")[["Height", "Width"]].mean().reset_index()
43
+ temp_df = temp_df.merge(size_df, on="ImageId", how="left")
44
+ for index, row in tqdm(temp_df.iterrows(), total=len(temp_df)):
45
+ image_id = row["ImageId"]
46
+ image_path = os.path.join(self.image_dir, image_id)
47
+ self.image_info[index]["image_id"] = image_id
48
+ self.image_info[index]["image_path"] = image_path
49
+ self.image_info[index]["width"] = self.width
50
+ self.image_info[index]["height"] = self.height
51
+ self.image_info[index]["labels"] = row["CategoryId"]
52
+ self.image_info[index]["orig_height"] = row["Height"]
53
+ self.image_info[index]["orig_width"] = row["Width"]
54
+ self.image_info[index]["annotations"] = row["EncodedPixels"]
55
+
56
+ self.dataset_size = len(self.image_info)
57
+
58
+ def __getitem__(self, index):
59
+ # load images and masks
60
+ idx = index
61
+ img_path = self.image_info[idx]["image_path"]
62
+ img = Image.open(img_path).convert("RGB")
63
+ img = img.resize((self.width, self.height), resample=Image.BICUBIC)
64
+ image_tensor = self.transform_rgb(img)
65
+
66
+ info = self.image_info[idx]
67
+ mask = np.zeros(
68
+ (len(info["annotations"]), self.width, self.height), dtype=np.uint8
69
+ )
70
+ labels = []
71
+ for m, (annotation, label) in enumerate(
72
+ zip(info["annotations"], info["labels"])
73
+ ):
74
+ sub_mask = self.rle_decode(
75
+ annotation, (info["orig_height"], info["orig_width"])
76
+ )
77
+ sub_mask = Image.fromarray(sub_mask)
78
+ sub_mask = sub_mask.resize(
79
+ (self.width, self.height), resample=Image.BICUBIC
80
+ )
81
+ mask[m, :, :] = sub_mask
82
+ labels.append(int(label) + 1)
83
+
84
+ num_objs = len(labels)
85
+ boxes = []
86
+ new_labels = []
87
+ new_masks = []
88
+
89
+ for i in range(num_objs):
90
+ try:
91
+ pos = np.where(mask[i, :, :])
92
+ xmin = np.min(pos[1])
93
+ xmax = np.max(pos[1])
94
+ ymin = np.min(pos[0])
95
+ ymax = np.max(pos[0])
96
+ if abs(xmax - xmin) >= 20 and abs(ymax - ymin) >= 20:
97
+ boxes.append([xmin, ymin, xmax, ymax])
98
+ new_labels.append(labels[i])
99
+ new_masks.append(mask[i, :, :])
100
+ except ValueError:
101
+ continue
102
+
103
+ if len(new_labels) == 0:
104
+ boxes.append([0, 0, 20, 20])
105
+ new_labels.append(0)
106
+ new_masks.append(mask[0, :, :])
107
+
108
+ nmx = np.zeros((len(new_masks), self.width, self.height), dtype=np.uint8)
109
+ for i, n in enumerate(new_masks):
110
+ nmx[i, :, :] = n
111
+
112
+ boxes = torch.as_tensor(boxes, dtype=torch.float32)
113
+ labels = torch.as_tensor(new_labels, dtype=torch.int64)
114
+ masks = torch.as_tensor(nmx, dtype=torch.uint8)
115
+
116
+ final_label = np.zeros((self.width, self.height), dtype=np.uint8)
117
+ first_channel = np.zeros((self.width, self.height), dtype=np.uint8)
118
+ second_channel = np.zeros((self.width, self.height), dtype=np.uint8)
119
+ third_channel = np.zeros((self.width, self.height), dtype=np.uint8)
120
+
121
+ upperbody = [0, 1, 2, 3, 4, 5]
122
+ lowerbody = [6, 7, 8]
123
+ wholebody = [9, 10, 11, 12]
124
+
125
+ for i in range(len(labels)):
126
+ if labels[i] in upperbody:
127
+ first_channel += new_masks[i]
128
+ elif labels[i] in lowerbody:
129
+ second_channel += new_masks[i]
130
+ elif labels[i] in wholebody:
131
+ third_channel += new_masks[i]
132
+
133
+ first_channel = (first_channel > 0).astype("uint8")
134
+ second_channel = (second_channel > 0).astype("uint8")
135
+ third_channel = (third_channel > 0).astype("uint8")
136
+
137
+ final_label = first_channel + second_channel * 2 + third_channel * 3
138
+ conflict_mask = (final_label <= 3).astype("uint8")
139
+ final_label = (conflict_mask) * final_label + (1 - conflict_mask) * 1
140
+ target_tensor = torch.as_tensor(final_label, dtype=torch.int64)
141
+
142
+ return image_tensor, target_tensor
143
+
144
+ def __len__(self):
145
+ return len(self.image_info)
146
+
147
+ def name(self):
148
+ return "AlignedDataset"
149
+
150
+ def rle_decode(self, mask_rle, shape):
151
+ """
152
+ mask_rle: run-length as string formated: [start0] [length0] [start1] [length1]... in 1d array
153
+ shape: (height,width) of array to return
154
+ Returns numpy array according to the shape, 1 - mask, 0 - background
155
+ """
156
+ shape = (shape[1], shape[0])
157
+ s = mask_rle.split()
158
+ # gets starts & lengths 1d arrays
159
+ starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
160
+ starts -= 1
161
+ # gets ends 1d array
162
+ ends = starts + lengths
163
+ # creates blank mask image 1d array
164
+ img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
165
+ # sets mask pixels
166
+ for lo, hi in zip(starts, ends):
167
+ img[lo:hi] = 1
168
+ # reshape as a 2d mask image
169
+ return img.reshape(shape).T # Needed to align to RLE direction
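To make the run-length encoding concrete, here is a tiny worked example of `rle_decode` (illustrative values only; `rle_decode` does not touch any dataset state, so an uninitialised `AlignedDataset` is enough):

```python
import numpy as np
from data.aligned_dataset import AlignedDataset

# Kaggle-style RLE: "start length start length ...", 1-indexed,
# with the pixel index running down the columns (column-major).
mask_rle = "2 3 7 2"   # pixels 2-4 and 7-8 are foreground
shape = (4, 3)         # (height, width)

mask = AlignedDataset().rle_decode(mask_rle, shape)
print(mask)
# [[0 0 0]
#  [1 0 0]
#  [1 1 0]
#  [1 1 0]]
```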
cloth_segmentation/data/base_data_loader.py ADDED
@@ -0,0 +1,10 @@
1
+ class BaseDataLoader:
2
+ def __init__(self):
3
+ pass
4
+
5
+ def initialize(self, opt):
6
+ self.opt = opt
7
+ pass
8
+
9
+ def load_data(self):
10
+ return None
cloth_segmentation/data/base_dataset.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ from PIL import Image
3
+ import cv2
4
+ import numpy as np
5
+ import random
6
+
7
+ import torch
8
+ import torch.utils.data as data
9
+ import torchvision.transforms as transforms
10
+
11
+
12
+ class BaseDataset(data.Dataset):
13
+ def __init__(self):
14
+ super(BaseDataset, self).__init__()
15
+
16
+ def name(self):
17
+ return "BaseDataset"
18
+
19
+ def initialize(self, opt):
20
+ pass
21
+
22
+
23
+ class Rescale_fixed(object):
24
+ """Rescale the input image into given size.
25
+
26
+ Args:
27
+ (w,h) (tuple): output size, or x (int), in which case the image is resized to (x, x).
28
+ """
29
+
30
+ def __init__(self, output_size):
31
+ self.output_size = output_size
32
+
33
+ def __call__(self, image):
34
+ return image.resize(self.output_size, Image.BICUBIC)
35
+
36
+
37
+ class Rescale_custom(object):
38
+ """Rescale the input image and target image into randomly selected size with lower bound of min_size arg.
39
+
40
+ Args:
41
+ min_size (int): Minimum desired output size.
42
+ """
43
+
44
+ def __init__(self, min_size, max_size):
45
+ assert isinstance(min_size, (int, float))
46
+ self.min_size = min_size
47
+ self.max_size = max_size
48
+
49
+ def __call__(self, sample):
50
+
51
+ input_image, target_image = sample["input_image"], sample["target_image"]
52
+
53
+ assert input_image.size == target_image.size
54
+ w, h = input_image.size
55
+
56
+ # Randomly select size to resize
57
+ if min(self.max_size, h, w) > self.min_size:
58
+ self.output_size = np.random.randint(
59
+ self.min_size, min(self.max_size, h, w)
60
+ )
61
+ else:
62
+ self.output_size = self.min_size
63
+
64
+ # calculate new size by keeping aspect ratio same
65
+ if h > w:
66
+ new_h, new_w = self.output_size * h / w, self.output_size
67
+ else:
68
+ new_h, new_w = self.output_size, self.output_size * w / h
69
+
70
+ new_w, new_h = int(new_w), int(new_h)
71
+ input_image = input_image.resize((new_w, new_h), Image.BICUBIC)
72
+ target_image = target_image.resize((new_w, new_h), Image.BICUBIC)
73
+ return {"input_image": input_image, "target_image": target_image}
74
+
75
+
76
+ class ToTensor(object):
77
+ """Convert ndarrays in sample to Tensors."""
78
+
79
+ def __init__(self):
80
+ self.totensor = transforms.ToTensor()
81
+
82
+ def __call__(self, sample):
83
+ input_image, target_image = sample["input_image"], sample["target_image"]
84
+
85
+ return {
86
+ "input_image": self.totensor(input_image),
87
+ "target_image": self.totensor(target_image),
88
+ }
89
+
90
+
91
+ class RandomCrop_custom(object):
92
+ """Crop randomly the image in a sample.
93
+
94
+ Args:
95
+ output_size (tuple or int): Desired output size. If int, square crop
96
+ is made.
97
+ """
98
+
99
+ def __init__(self, output_size):
100
+ assert isinstance(output_size, (int, tuple))
101
+ if isinstance(output_size, int):
102
+ self.output_size = (output_size, output_size)
103
+ else:
104
+ assert len(output_size) == 2
105
+ self.output_size = output_size
106
+
107
+ self.randomcrop = transforms.RandomCrop(self.output_size)
108
+
109
+ def __call__(self, sample):
110
+ input_image, target_image = sample["input_image"], sample["target_image"]
111
+ cropped_imgs = self.randomcrop(torch.cat((input_image, target_image)))
112
+
113
+ return {
114
+ "input_image": cropped_imgs[
115
+ :3,
116
+ :,
117
+ ],
118
+ "target_image": cropped_imgs[
119
+ 3:,
120
+ :,
121
+ ],
122
+ }
123
+
124
+
125
+ class Normalize_custom(object):
126
+ """Normalize given dict into given mean and standard dev
127
+
128
+ Args:
129
+ mean (tuple or int): Desired mean to subtract from the dict's tensors
130
+ std (tuple or int): Desired std to divide the dict's tensors by
131
+ """
132
+
133
+ def __init__(self, mean, std):
134
+ assert isinstance(mean, (float, tuple))
135
+ if isinstance(mean, float):
136
+ self.mean = (mean, mean, mean)
137
+ else:
138
+ assert len(mean) == 3
139
+ self.mean = mean
140
+
141
+ if isinstance(std, float):
142
+ self.std = (std, std, std)
143
+ else:
144
+ assert len(std) == 3
145
+ self.std = std
146
+
147
+ self.normalize = transforms.Normalize(self.mean, self.std)
148
+
149
+ def __call__(self, sample):
150
+ input_image, target_image = sample["input_image"], sample["target_image"]
151
+
152
+ return {
153
+ "input_image": self.normalize(input_image),
154
+ "target_image": self.normalize(target_image),
155
+ }
156
+
157
+
158
+ class Normalize_image(object):
159
+ """Normalize given tensor into given mean and standard dev
160
+
161
+ Args:
162
+ mean (float): Desired mean to subtract from tensors
163
+ std (float): Desired std to divide tensors by
164
+ """
165
+
166
+ def __init__(self, mean, std):
167
+ assert isinstance(mean, (float))
168
+ if isinstance(mean, float):
169
+ self.mean = mean
170
+
171
+ if isinstance(std, float):
172
+ self.std = std
173
+
174
+ self.normalize_1 = transforms.Normalize(self.mean, self.std)
175
+ self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
176
+ self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)
177
+
178
+ def __call__(self, image_tensor):
179
+ if image_tensor.shape[0] == 1:
180
+ return self.normalize_1(image_tensor)
181
+
182
+ elif image_tensor.shape[0] == 3:
183
+ return self.normalize_3(image_tensor)
184
+
185
+ elif image_tensor.shape[0] == 18:
186
+ return self.normalize_18(image_tensor)
187
+
188
+ else:
189
+ assert "Please set proper channels! Normlization implemented only for 1, 3 and 18"
cloth_segmentation/data/custom_dataset_data_loader.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch.utils.data
2
+ from data.base_data_loader import BaseDataLoader
3
+
4
+
5
+ def CreateDataset(opt):
6
+ dataset = None
7
+ from data.aligned_dataset import AlignedDataset
8
+ dataset = AlignedDataset()
9
+
10
+ print("dataset [%s] was created" % (dataset.name()))
11
+ dataset.initialize(opt)
12
+ return dataset
13
+
14
+
15
+ class CustomDatasetDataLoader(BaseDataLoader):
16
+ def name(self):
17
+ return 'CustomDatasetDataLoader'
18
+
19
+ def initialize(self, opt):
20
+ BaseDataLoader.initialize(self, opt)
21
+ self.dataset = CreateDataset(opt)
22
+ self.dataloader = torch.utils.data.DataLoader(
23
+ self.dataset,
24
+ batch_size=opt.batchSize,
25
+ sampler=data_sampler(self.dataset,
26
+ not opt.serial_batches, opt.distributed),
27
+ num_workers=int(opt.nThreads),
28
+ pin_memory=True)
29
+
30
+ def get_loader(self):
31
+ return self.dataloader
32
+
33
+ def __len__(self):
34
+ return min(len(self.dataset), self.opt.max_dataset_size)
35
+
36
+
37
+ def data_sampler(dataset, shuffle, distributed):
38
+ if distributed:
39
+ return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
40
+
41
+ if shuffle:
42
+ return torch.utils.data.RandomSampler(dataset)
43
+
44
+ else:
45
+ return torch.utils.data.SequentialSampler(dataset)
46
+
47
+
48
+ def sample_data(loader):
49
+ while True:
50
+ for batch in loader:
51
+ yield batch
52
+
53
+
54
+ class CustomTestDataLoader(BaseDataLoader):
55
+ def name(self):
56
+ return 'CustomDatasetDataLoader'
57
+
58
+ def initialize(self, opt):
59
+ BaseDataLoader.initialize(self, opt)
60
+ self.dataset = CreateDataset(opt)
61
+ self.dataloader = torch.utils.data.DataLoader(
62
+ self.dataset,
63
+ batch_size=opt.batchSize,
64
+ num_workers=int(opt.nThreads),
65
+ pin_memory=True)
66
+
67
+ def get_loader(self):
68
+ return self.dataloader
69
+
70
+ def __len__(self):
71
+ return min(len(self.dataset), self.opt.max_dataset_size)
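Roughly, this loader is consumed as sketched below (placeholder option values; the real defaults live in `options/base_options.py`, and `train.py` builds `opt` from there):

```python
from types import SimpleNamespace
from data.custom_dataset_data_loader import CustomDatasetDataLoader, sample_data

# Placeholder options covering only the fields the dataset and loader read.
opt = SimpleNamespace(
    image_folder="train", df_path="train.csv",          # iMaterialist images + label csv
    fine_width=768, fine_height=768, mean=0.5, std=0.5,
    batchSize=4, nThreads=2,
    serial_batches=False, distributed=False,
    max_dataset_size=float("inf"),
)

loader_wrapper = CustomDatasetDataLoader()
loader_wrapper.initialize(opt)
loader = loader_wrapper.get_loader()

batches = sample_data(loader)    # endless iterator over the DataLoader
images, labels = next(batches)   # images: (B, 3, 768, 768), labels: (B, 768, 768)
```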
cloth_segmentation/data/data_loader.py ADDED
@@ -0,0 +1,7 @@
1
+ def CreateDataLoader(opt):
2
+ from data.custom_dataset_data_loader import CustomDatasetDataLoader
3
+
4
+ data_loader = CustomDatasetDataLoader()
5
+ print(data_loader.name())
6
+ data_loader.initialize(opt)
7
+ return data_loader
cloth_segmentation/data/image_folder.py ADDED
@@ -0,0 +1,81 @@
1
+ ###############################################################################
2
+ # Code from
3
+ # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4
+ # Modified the original code so that it also loads images from the current
5
+ # directory as well as the subdirectories
6
+ ###############################################################################
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import os
10
+
11
+ IMG_EXTENSIONS = [
12
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
13
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
14
+ ]
15
+
16
+
17
+ def is_image_file(filename):
18
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
19
+
20
+
21
+ def make_dataset(dir):
22
+ images = []
23
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
24
+
25
+ f = dir.split('/')[-1].split('_')[-1]
26
+ print(dir, f)
27
+ dirs = os.listdir(dir)
28
+ for img in dirs:
29
+ path = os.path.join(dir, img)
30
+ images.append(path)
31
+ return images
32
+
33
+
34
+ def make_dataset_test(dir):
35
+ images = []
36
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
37
+
38
+ f = dir.split('/')[-1].split('_')[-1]
39
+ for i in range(len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])):
40
+ if f == 'label' or f == 'labelref':
41
+ img = str(i) + '.png'
42
+ else:
43
+ img = str(i) + '.jpg'
44
+ path = os.path.join(dir, img)
45
+ # print(path)
46
+ images.append(path)
47
+ return images
48
+
49
+
50
+ def default_loader(path):
51
+ return Image.open(path).convert('RGB')
52
+
53
+
54
+ class ImageFolder(data.Dataset):
55
+
56
+ def __init__(self, root, transform=None, return_paths=False,
57
+ loader=default_loader):
58
+ imgs = make_dataset(root)
59
+ if len(imgs) == 0:
60
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
61
+ "Supported image extensions are: " +
62
+ ",".join(IMG_EXTENSIONS)))
63
+
64
+ self.root = root
65
+ self.imgs = imgs
66
+ self.transform = transform
67
+ self.return_paths = return_paths
68
+ self.loader = loader
69
+
70
+ def __getitem__(self, index):
71
+ path = self.imgs[index]
72
+ img = self.loader(path)
73
+ if self.transform is not None:
74
+ img = self.transform(img)
75
+ if self.return_paths:
76
+ return img, path
77
+ else:
78
+ return img
79
+
80
+ def __len__(self):
81
+ return len(self.imgs)
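A small usage sketch for the `ImageFolder` helper above (the directory name is a placeholder):

```python
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

dataset = ImageFolder(root="input_images", transform=transforms.ToTensor(), return_paths=True)
img, path = dataset[0]                      # a (3, H, W) tensor and the file path it came from
print(len(dataset), path, tuple(img.shape))
```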
cloth_segmentation/infer.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ import warnings
8
+
9
+ warnings.filterwarnings("ignore", category=FutureWarning)
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+
16
+ from data.base_dataset import Normalize_image
17
+ from utils.saving_utils import load_checkpoint_mgpu
18
+
19
+ from networks import U2NET
20
+
21
+ device = "cuda"
22
+
23
+ image_dir = "input_images"
24
+ result_dir = "output_images"
25
+ checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
26
+ do_palette = True
27
+
28
+
29
+ def get_palette(num_cls):
30
+ """Returns the color map for visualizing the segmentation mask.
31
+ Args:
32
+ num_cls: Number of classes
33
+ Returns:
34
+ The color map
35
+ """
36
+ n = num_cls
37
+ palette = [0] * (n * 3)
38
+ for j in range(0, n):
39
+ lab = j
40
+ palette[j * 3 + 0] = 0
41
+ palette[j * 3 + 1] = 0
42
+ palette[j * 3 + 2] = 0
43
+ i = 0
44
+ while lab:
45
+ palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
46
+ palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
47
+ palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
48
+ i += 1
49
+ lab >>= 3
50
+ return palette
51
+
52
+
53
+ transforms_list = []
54
+ transforms_list += [transforms.ToTensor()]
55
+ transforms_list += [Normalize_image(0.5, 0.5)]
56
+ transform_rgb = transforms.Compose(transforms_list)
57
+
58
+ net = U2NET(in_ch=3, out_ch=4)
59
+ net = load_checkpoint_mgpu(net, checkpoint_path)
60
+ net = net.to(device)
61
+ net = net.eval()
62
+
63
+ palette = get_palette(4)
64
+
65
+ images_list = sorted(os.listdir(image_dir))
66
+ pbar = tqdm(total=len(images_list))
67
+ for image_name in images_list:
68
+ img = Image.open(os.path.join(image_dir, image_name)).convert("RGB")
69
+ image_tensor = transform_rgb(img)
70
+ image_tensor = torch.unsqueeze(image_tensor, 0)
71
+
72
+ output_tensor = net(image_tensor.to(device))
73
+ output_tensor = F.log_softmax(output_tensor[0], dim=1)
74
+ output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
75
+ output_tensor = torch.squeeze(output_tensor, dim=0)
76
+ output_tensor = torch.squeeze(output_tensor, dim=0)
77
+ output_arr = output_tensor.cpu().numpy()
78
+
79
+ output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
80
+ if do_palette:
81
+ output_img.putpalette(palette)
82
+ output_img.save(os.path.join(result_dir, image_name[:-3] + "png"))
83
+
84
+ pbar.update(1)
85
+
86
+ pbar.close()
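For the 4-class case used here, the palette reduces to the colours mentioned in the README: background black, upper body red, lower body green, full body yellow. A quick check, assuming it is appended at the end of `infer.py` where `palette = get_palette(4)` has already been computed:

```python
# Appended at the end of infer.py, where `palette` is already defined.
print(palette[:12])
# [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0]
#  class 0: black (background), class 1: red (upper body),
#  class 2: green (lower body), class 3: yellow (full body)
```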
cloth_segmentation/model_surgery.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ import gdown
3
+ import torch
4
+
5
+ from networks import U2NET
6
+ from utils.saving_utils import save_checkpoint
7
+
8
+ os.makedirs("prev_checkpoints", exist_ok=True)
9
+ gdown.download(
10
+ "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
11
+ "./prev_checkpoints/u2net.pth",
12
+ quiet=False,
13
+ )
14
+
15
+ u_net = U2NET(in_ch=3, out_ch=4)
16
+ save_checkpoint(u_net, os.path.join("prev_checkpoints", "u2net_random.pth"))
17
+
18
+ # u2net.pth contains trained weights
19
+ trained_net_pth = os.path.join("prev_checkpoints", "u2net.pth")
20
+ # u2net_random.pth contains random weights
21
+ custom_net_pth = os.path.join("prev_checkpoints", "u2net_random.pth")
22
+
23
+ net_state_dict = torch.load(trained_net_pth)
24
+ count = 0
25
+ for k, v in net_state_dict.items():
26
+ count += 1
27
+ print("Total number of layers in trained model are: {}".format(count))
28
+
29
+ custom_state_dict = torch.load(custom_net_pth)
30
+ count = 0
31
+ for k, v in custom_state_dict.items():
32
+ count += 1
33
+ print("Total number of layers in trained model are: {}".format(count))
34
+
35
+ total_count = 0
36
+ update_count = 0
37
+ for k, v in net_state_dict.items():
38
+ total_count += 1
39
+ if custom_state_dict[k].shape == v.shape:
40
+ update_count += 1
41
+ custom_state_dict[k] = v
42
+
43
+ print(
44
+ "Out of {} layers in custom network, {} layers weights are recovered from trained model".format(
45
+ total_count, update_count
46
+ )
47
+ )
48
+ torch.save(
49
+ custom_state_dict, os.path.join("prev_checkpoints", "cloth_segm_unet_surgery.pth")
50
+ )
51
+ print("cloth_segm_unet_surgery.pth is generated in prev_checkpoints directory!")
cloth_segmentation/networks/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .u2net import U2NET
cloth_segmentation/networks/u2net.py ADDED
@@ -0,0 +1,565 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
8
+ super(REBNCONV, self).__init__()
9
+
10
+ self.conv_s1 = nn.Conv2d(
11
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
12
+ )
13
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
14
+ self.relu_s1 = nn.ReLU(inplace=True)
15
+
16
+ def forward(self, x):
17
+
18
+ hx = x
19
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20
+
21
+ return xout
22
+
23
+
24
+ ## upsample tensor 'src' to the same spatial size as tensor 'tar'
25
+ def _upsample_like(src, tar):
26
+
27
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
28
+
29
+ return src
30
+
31
+
32
+ ### RSU-7 ###
33
+ class RSU7(nn.Module): # UNet07DRES(nn.Module):
34
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
35
+ super(RSU7, self).__init__()
36
+
37
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
38
+
39
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
40
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
41
+
42
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
43
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+
56
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
57
+
58
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
59
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
60
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
61
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
62
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
63
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
64
+
65
+ def forward(self, x):
66
+
67
+ hx = x
68
+ hxin = self.rebnconvin(hx)
69
+
70
+ hx1 = self.rebnconv1(hxin)
71
+ hx = self.pool1(hx1)
72
+
73
+ hx2 = self.rebnconv2(hx)
74
+ hx = self.pool2(hx2)
75
+
76
+ hx3 = self.rebnconv3(hx)
77
+ hx = self.pool3(hx3)
78
+
79
+ hx4 = self.rebnconv4(hx)
80
+ hx = self.pool4(hx4)
81
+
82
+ hx5 = self.rebnconv5(hx)
83
+ hx = self.pool5(hx5)
84
+
85
+ hx6 = self.rebnconv6(hx)
86
+
87
+ hx7 = self.rebnconv7(hx6)
88
+
89
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
90
+ hx6dup = _upsample_like(hx6d, hx5)
91
+
92
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
93
+ hx5dup = _upsample_like(hx5d, hx4)
94
+
95
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
96
+ hx4dup = _upsample_like(hx4d, hx3)
97
+
98
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
99
+ hx3dup = _upsample_like(hx3d, hx2)
100
+
101
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
102
+ hx2dup = _upsample_like(hx2d, hx1)
103
+
104
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
105
+
106
+ """
107
+ del hx1, hx2, hx3, hx4, hx5, hx6, hx7
108
+ del hx6d, hx5d, hx3d, hx2d
109
+ del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
110
+ """
111
+
112
+ return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
+ class RSU6(nn.Module): # UNet06DRES(nn.Module):
117
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
118
+ super(RSU6, self).__init__()
119
+
120
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
121
+
122
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
123
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
124
+
125
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
126
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+
136
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
137
+
138
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
143
+
144
+ def forward(self, x):
145
+
146
+ hx = x
147
+
148
+ hxin = self.rebnconvin(hx)
149
+
150
+ hx1 = self.rebnconv1(hxin)
151
+ hx = self.pool1(hx1)
152
+
153
+ hx2 = self.rebnconv2(hx)
154
+ hx = self.pool2(hx2)
155
+
156
+ hx3 = self.rebnconv3(hx)
157
+ hx = self.pool3(hx3)
158
+
159
+ hx4 = self.rebnconv4(hx)
160
+ hx = self.pool4(hx4)
161
+
162
+ hx5 = self.rebnconv5(hx)
163
+
164
+ hx6 = self.rebnconv6(hx5)
165
+
166
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
167
+ hx5dup = _upsample_like(hx5d, hx4)
168
+
169
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
170
+ hx4dup = _upsample_like(hx4d, hx3)
171
+
172
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
173
+ hx3dup = _upsample_like(hx3d, hx2)
174
+
175
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
176
+ hx2dup = _upsample_like(hx2d, hx1)
177
+
178
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
179
+
180
+ """
181
+ del hx1, hx2, hx3, hx4, hx5, hx6
182
+ del hx5d, hx4d, hx3d, hx2d
183
+ del hx2dup, hx3dup, hx4dup, hx5dup
184
+ """
185
+
186
+ return hx1d + hxin
187
+
188
+
189
+ ### RSU-5 ###
190
+ class RSU5(nn.Module): # UNet05DRES(nn.Module):
191
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
192
+ super(RSU5, self).__init__()
193
+
194
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
195
+
196
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
197
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
198
+
199
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
200
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
201
+
202
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
203
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
204
+
205
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
206
+
207
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
208
+
209
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
211
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
212
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
213
+
214
+ def forward(self, x):
215
+
216
+ hx = x
217
+
218
+ hxin = self.rebnconvin(hx)
219
+
220
+ hx1 = self.rebnconv1(hxin)
221
+ hx = self.pool1(hx1)
222
+
223
+ hx2 = self.rebnconv2(hx)
224
+ hx = self.pool2(hx2)
225
+
226
+ hx3 = self.rebnconv3(hx)
227
+ hx = self.pool3(hx3)
228
+
229
+ hx4 = self.rebnconv4(hx)
230
+
231
+ hx5 = self.rebnconv5(hx4)
232
+
233
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
234
+ hx4dup = _upsample_like(hx4d, hx3)
235
+
236
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
237
+ hx3dup = _upsample_like(hx3d, hx2)
238
+
239
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
240
+ hx2dup = _upsample_like(hx2d, hx1)
241
+
242
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
243
+
244
+ """
245
+ del hx1, hx2, hx3, hx4, hx5
246
+ del hx4d, hx3d, hx2d
247
+ del hx2dup, hx3dup, hx4dup
248
+ """
249
+
250
+ return hx1d + hxin
251
+
252
+
253
+ ### RSU-4 ###
254
+ class RSU4(nn.Module): # UNet04DRES(nn.Module):
255
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
256
+ super(RSU4, self).__init__()
257
+
258
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
259
+
260
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
261
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
262
+
263
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
264
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
265
+
266
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
267
+
268
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
269
+
270
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
271
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
272
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
273
+
274
+ def forward(self, x):
275
+
276
+ hx = x
277
+
278
+ hxin = self.rebnconvin(hx)
279
+
280
+ hx1 = self.rebnconv1(hxin)
281
+ hx = self.pool1(hx1)
282
+
283
+ hx2 = self.rebnconv2(hx)
284
+ hx = self.pool2(hx2)
285
+
286
+ hx3 = self.rebnconv3(hx)
287
+
288
+ hx4 = self.rebnconv4(hx3)
289
+
290
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
291
+ hx3dup = _upsample_like(hx3d, hx2)
292
+
293
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
294
+ hx2dup = _upsample_like(hx2d, hx1)
295
+
296
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
297
+
298
+ """
299
+ del hx1, hx2, hx3, hx4
300
+ del hx3d, hx2d
301
+ del hx2dup, hx3dup
302
+ """
303
+
304
+ return hx1d + hxin
305
+
306
+
307
+ ### RSU-4F ###
308
+ class RSU4F(nn.Module): # UNet04FRES(nn.Module):
309
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
310
+ super(RSU4F, self).__init__()
311
+
312
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
313
+
314
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
315
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
316
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
317
+
318
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
319
+
320
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
321
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
322
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
323
+
324
+ def forward(self, x):
325
+
326
+ hx = x
327
+
328
+ hxin = self.rebnconvin(hx)
329
+
330
+ hx1 = self.rebnconv1(hxin)
331
+ hx2 = self.rebnconv2(hx1)
332
+ hx3 = self.rebnconv3(hx2)
333
+
334
+ hx4 = self.rebnconv4(hx3)
335
+
336
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
337
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
338
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
339
+
340
+ """
341
+ del hx1, hx2, hx3, hx4
342
+ del hx3d, hx2d
343
+ """
344
+
345
+ return hx1d + hxin
346
+
347
+
348
+ ##### U^2-Net ####
349
+ class U2NET(nn.Module):
350
+ def __init__(self, in_ch=3, out_ch=1):
351
+ super(U2NET, self).__init__()
352
+
353
+ self.stage1 = RSU7(in_ch, 32, 64)
354
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
355
+
356
+ self.stage2 = RSU6(64, 32, 128)
357
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage3 = RSU5(128, 64, 256)
360
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage4 = RSU4(256, 128, 512)
363
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage5 = RSU4F(512, 256, 512)
366
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage6 = RSU4F(512, 256, 512)
369
+
370
+ # decoder
371
+ self.stage5d = RSU4F(1024, 256, 512)
372
+ self.stage4d = RSU4(1024, 128, 256)
373
+ self.stage3d = RSU5(512, 64, 128)
374
+ self.stage2d = RSU6(256, 32, 64)
375
+ self.stage1d = RSU7(128, 16, 64)
376
+
377
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
378
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
379
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
380
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
381
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
382
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
383
+
384
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
385
+
386
+ def forward(self, x):
387
+
388
+ hx = x
389
+
390
+ # stage 1
391
+ hx1 = self.stage1(hx)
392
+ hx = self.pool12(hx1)
393
+
394
+ # stage 2
395
+ hx2 = self.stage2(hx)
396
+ hx = self.pool23(hx2)
397
+
398
+ # stage 3
399
+ hx3 = self.stage3(hx)
400
+ hx = self.pool34(hx3)
401
+
402
+ # stage 4
403
+ hx4 = self.stage4(hx)
404
+ hx = self.pool45(hx4)
405
+
406
+ # stage 5
407
+ hx5 = self.stage5(hx)
408
+ hx = self.pool56(hx5)
409
+
410
+ # stage 6
411
+ hx6 = self.stage6(hx)
412
+ hx6up = _upsample_like(hx6, hx5)
413
+
414
+ # -------------------- decoder --------------------
415
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
416
+ hx5dup = _upsample_like(hx5d, hx4)
417
+
418
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
419
+ hx4dup = _upsample_like(hx4d, hx3)
420
+
421
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
422
+ hx3dup = _upsample_like(hx3d, hx2)
423
+
424
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
425
+ hx2dup = _upsample_like(hx2d, hx1)
426
+
427
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
428
+
429
+ # side output
430
+ d1 = self.side1(hx1d)
431
+
432
+ d2 = self.side2(hx2d)
433
+ d2 = _upsample_like(d2, d1)
434
+
435
+ d3 = self.side3(hx3d)
436
+ d3 = _upsample_like(d3, d1)
437
+
438
+ d4 = self.side4(hx4d)
439
+ d4 = _upsample_like(d4, d1)
440
+
441
+ d5 = self.side5(hx5d)
442
+ d5 = _upsample_like(d5, d1)
443
+
444
+ d6 = self.side6(hx6)
445
+ d6 = _upsample_like(d6, d1)
446
+
447
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
448
+
449
+ """
450
+ del hx1, hx2, hx3, hx4, hx5, hx6
451
+ del hx5d, hx4d, hx3d, hx2d, hx1d
452
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
453
+ """
454
+
455
+ return d0, d1, d2, d3, d4, d5, d6
456
+
457
+
458
+ ### U^2-Net small ###
459
+ class U2NETP(nn.Module):
460
+ def __init__(self, in_ch=3, out_ch=1):
461
+ super(U2NETP, self).__init__()
462
+
463
+ self.stage1 = RSU7(in_ch, 16, 64)
464
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
465
+
466
+ self.stage2 = RSU6(64, 16, 64)
467
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
468
+
469
+ self.stage3 = RSU5(64, 16, 64)
470
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
471
+
472
+ self.stage4 = RSU4(64, 16, 64)
473
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
474
+
475
+ self.stage5 = RSU4F(64, 16, 64)
476
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
477
+
478
+ self.stage6 = RSU4F(64, 16, 64)
479
+
480
+ # decoder
481
+ self.stage5d = RSU4F(128, 16, 64)
482
+ self.stage4d = RSU4(128, 16, 64)
483
+ self.stage3d = RSU5(128, 16, 64)
484
+ self.stage2d = RSU6(128, 16, 64)
485
+ self.stage1d = RSU7(128, 16, 64)
486
+
487
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
488
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
489
+ self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
490
+ self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
491
+ self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
492
+ self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
493
+
494
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
495
+
496
+ def forward(self, x):
497
+
498
+ hx = x
499
+
500
+ # stage 1
501
+ hx1 = self.stage1(hx)
502
+ hx = self.pool12(hx1)
503
+
504
+ # stage 2
505
+ hx2 = self.stage2(hx)
506
+ hx = self.pool23(hx2)
507
+
508
+ # stage 3
509
+ hx3 = self.stage3(hx)
510
+ hx = self.pool34(hx3)
511
+
512
+ # stage 4
513
+ hx4 = self.stage4(hx)
514
+ hx = self.pool45(hx4)
515
+
516
+ # stage 5
517
+ hx5 = self.stage5(hx)
518
+ hx = self.pool56(hx5)
519
+
520
+ # stage 6
521
+ hx6 = self.stage6(hx)
522
+ hx6up = _upsample_like(hx6, hx5)
523
+
524
+ # decoder
525
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
526
+ hx5dup = _upsample_like(hx5d, hx4)
527
+
528
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
529
+ hx4dup = _upsample_like(hx4d, hx3)
530
+
531
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
532
+ hx3dup = _upsample_like(hx3d, hx2)
533
+
534
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
535
+ hx2dup = _upsample_like(hx2d, hx1)
536
+
537
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
538
+
539
+ # side output
540
+ d1 = self.side1(hx1d)
541
+
542
+ d2 = self.side2(hx2d)
543
+ d2 = _upsample_like(d2, d1)
544
+
545
+ d3 = self.side3(hx3d)
546
+ d3 = _upsample_like(d3, d1)
547
+
548
+ d4 = self.side4(hx4d)
549
+ d4 = _upsample_like(d4, d1)
550
+
551
+ d5 = self.side5(hx5d)
552
+ d5 = _upsample_like(d5, d1)
553
+
554
+ d6 = self.side6(hx6)
555
+ d6 = _upsample_like(d6, d1)
556
+
557
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
558
+
559
+ """
560
+ del hx1, hx2, hx3, hx4, hx5, hx6
561
+ del hx5d, hx4d, hx3d, hx2d, hx1d
562
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
563
+ """
564
+
565
+ return d0, d1, d2, d3, d4, d5, d6
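A small smoke test, written as a sketch against the classes above, that instantiates the cloth-segmentation variant (`out_ch=4`) and inspects the seven outputs returned by `U2NET.forward`; the input size of 320 is arbitrary (training elsewhere in this repo uses 768 = 192 * 4):

```python
import torch

from networks import U2NET

net = U2NET(in_ch=3, out_ch=4).eval()
x = torch.randn(1, 3, 320, 320)  # 320 is arbitrary; ceil_mode pooling tolerates other sizes too

with torch.no_grad():
    d0, d1, d2, d3, d4, d5, d6 = net(x)

# d0 is the fused prediction from outconv; d1..d6 are the per-stage side outputs,
# all upsampled back to the input resolution with 4 class channels.
for name, d in zip(["d0", "d1", "d2", "d3", "d4", "d5", "d6"], [d0, d1, d2, d3, d4, d5, d6]):
    print(name, tuple(d.shape))  # expected: (1, 4, 320, 320)
```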
cloth_segmentation/options/base_options.py ADDED
@@ -0,0 +1,38 @@
1
+ import os.path as osp
2
+ import os
3
+
4
+
5
+ class parser(object):
6
+ def __init__(self):
7
+ self.name = "training_cloth_segm_u2net_exp1" # Experiment name
8
+ self.image_folder = "../imaterialist/train/" # image folder path
9
+ self.df_path = "../imaterialist/train.csv" # label csv path
10
+ self.distributed = False # True for multi gpu training
11
+ self.isTrain = True
12
+
13
+ self.fine_width = 192 * 4
14
+ self.fine_height = 192 * 4
15
+
16
+ # Mean std params
17
+ self.mean = 0.5
18
+ self.std = 0.5
19
+
20
+ self.batchSize = 2 # 12
21
+ self.nThreads = 2 # 3
22
+ self.max_dataset_size = float("inf")
23
+
24
+ self.serial_batches = False
25
+ self.continue_train = True
26
+ if self.continue_train:
27
+ self.unet_checkpoint = "prev_checkpoints/cloth_segm_unet_surgery.pth"
28
+
29
+ self.save_freq = 1000
30
+ self.print_freq = 10
31
+ self.image_log_freq = 100
32
+
33
+ self.iter = 100000
34
+ self.lr = 0.0002
35
+ self.clip_grad = 5
36
+
37
+ self.logs_dir = osp.join("logs", self.name)
38
+ self.save_dir = osp.join("results", self.name)
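Since the options are a plain Python class rather than an argparse parser, experiments are configured by editing these fields or overriding them after construction; a hypothetical override (paths and values are illustrative) might look like:

```python
from options.base_options import parser

opt = parser()

# Point the loader at a local copy of the iMaterialist data and shrink the run.
opt.image_folder = "/data/imaterialist/train/"  # illustrative path
opt.df_path = "/data/imaterialist/train.csv"    # illustrative path
opt.batchSize = 8
opt.continue_train = False  # train from scratch instead of the surgery checkpoint
opt.iter = 10000
```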
cloth_segmentation/samples.md ADDED
@@ -0,0 +1,33 @@
1
+ ![Sample 001](assets/001.png)
2
+ ![Sample 002](assets/002.png)
3
+ ![Sample 003](assets/003.png)
4
+ ![Sample 004](assets/004.png)
5
+ ![Sample 007](assets/007.png)
6
+ ![Sample 008](assets/008.png)
7
+ ![Sample 009](assets/009.png)
8
+ ![Sample 010](assets/010.png)
9
+ ![Sample 011](assets/011.png)
10
+ ![Sample 012](assets/012.png)
11
+ ![Sample 013](assets/013.png)
12
+ ![Sample 014](assets/014.png)
13
+ ![Sample 015](assets/015.png)
14
+ ![Sample 016](assets/016.png)
15
+ ![Sample 017](assets/017.png)
16
+ ![Sample 022](assets/022.png)
17
+ ![Sample 023](assets/023.png)
18
+
19
+ ## With different poses
20
+
21
+ This model works well with different kinds of poses too.
22
+ ![Sample 019](assets/019.png)
23
+ ![Sample 021](assets/021.png)
24
+
25
+ ## Limitations
26
+
27
+ This model doesn't work in the following conditions:
28
+ - Images containing multiple people
29
+ - Dress or clothing styles that are extremely different from the training dataset <br>
30
+
31
+ ![Sample 005](assets/005.png)
32
+ ![Sample 006](assets/006.png)
33
+ ![Sample 020](assets/020.png)
cloth_segmentation/train.py ADDED
@@ -0,0 +1,190 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import yaml
5
+ import cv2
6
+ import pprint
7
+ import traceback
8
+ import numpy as np
9
+
10
+ import warnings
11
+
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+ from torch.autograd import Variable
20
+ import torch.distributed as dist
21
+ import torch.multiprocessing as mp
22
+ from torch.cuda.amp import autocast
23
+ from torch.nn.parallel import DistributedDataParallel as DDP
24
+ from torch.utils.tensorboard import SummaryWriter
25
+ from torchvision import models
26
+
27
+ from data.custom_dataset_data_loader import CustomDatasetDataLoader, sample_data
28
+
29
+
30
+ from options.base_options import parser
31
+ from utils.tensorboard_utils import board_add_images
32
+ from utils.saving_utils import save_checkpoints
33
+ from utils.saving_utils import load_checkpoint, load_checkpoint_mgpu
34
+ from utils.distributed import get_world_size, set_seed, synchronize, cleanup
35
+
36
+ from networks import U2NET
37
+
38
+
39
+ def options_printing_saving(opt):
40
+ os.makedirs(opt.logs_dir, exist_ok=True)
41
+ os.makedirs(opt.save_dir, exist_ok=True)
42
+ os.makedirs(os.path.join(opt.save_dir, "images"), exist_ok=True)
43
+ os.makedirs(os.path.join(opt.save_dir, "checkpoints"), exist_ok=True)
44
+
45
+ # Saving options in yml file
46
+ option_dict = vars(opt)
47
+ with open(os.path.join(opt.save_dir, "training_options.yml"), "w") as outfile:
48
+ yaml.dump(option_dict, outfile)
49
+
50
+ for key, value in option_dict.items():
51
+ print(key, value)
52
+
53
+
54
+ def training_loop(opt):
55
+
56
+ if opt.distributed:
57
+ local_rank = int(os.environ.get("LOCAL_RANK"))
58
+ # Unique only on individual node.
59
+ device = torch.device(f"cuda:{local_rank}")
60
+ else:
61
+ device = torch.device("cuda:0")
62
+ local_rank = 0
63
+
64
+ u_net = U2NET(in_ch=3, out_ch=4)
65
+ if opt.continue_train:
66
+ u_net = load_checkpoint(u_net, opt.unet_checkpoint)
67
+ u_net = u_net.to(device)
68
+ u_net.train()
69
+
70
+ if local_rank == 0:
71
+ with open(os.path.join(opt.save_dir, "networks.txt"), "w") as outfile:
72
+ print("<----U-2-Net---->", file=outfile)
73
+ print(u_net, file=outfile)
74
+
75
+ if opt.distributed:
76
+ u_net = nn.parallel.DistributedDataParallel(
77
+ u_net,
78
+ device_ids=[local_rank],
79
+ output_device=local_rank,
80
+ broadcast_buffers=False,
81
+ )
82
+ print("Going super fast with DistributedDataParallel")
83
+
84
+ # initialize optimizer
85
+ optimizer = optim.Adam(
86
+ u_net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
87
+ )
88
+
89
+ custom_dataloader = CustomDatasetDataLoader()
90
+ custom_dataloader.initialize(opt)
91
+ loader = custom_dataloader.get_loader()
92
+
93
+ if local_rank == 0:
94
+ dataset_size = len(custom_dataloader)
95
+ print("Total number of images avaliable for training: %d" % dataset_size)
96
+ writer = SummaryWriter(opt.logs_dir)
97
+ print("Entering training loop!")
98
+
99
+ # loss function
100
+ weights = np.array([1, 1.5, 1.5, 1.5], dtype=np.float32)
101
+ weights = torch.from_numpy(weights).to(device)
102
+ loss_CE = nn.CrossEntropyLoss(weight=weights).to(device)
103
+
104
+ pbar = range(opt.iter)
105
+ get_data = sample_data(loader)
106
+
107
+ start_time = time.time()
108
+ # Main training loop
109
+ for itr in pbar:
110
+ data_batch = next(get_data)
111
+ image, label = data_batch
112
+ image = Variable(image.to(device))
113
+ label = label.type(torch.long)
114
+ label = Variable(label.to(device))
115
+
116
+ d0, d1, d2, d3, d4, d5, d6 = u_net(image)
117
+
118
+ loss0 = loss_CE(d0, label)
119
+ loss1 = loss_CE(d1, label)
120
+ loss2 = loss_CE(d2, label)
121
+ loss3 = loss_CE(d3, label)
122
+ loss4 = loss_CE(d4, label)
123
+ loss5 = loss_CE(d5, label)
124
+ loss6 = loss_CE(d6, label)
125
+ del d1, d2, d3, d4, d5, d6
126
+
127
+ total_loss = loss0 * 1.5 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
128
+
129
+ for param in u_net.parameters():
130
+ param.grad = None
131
+
132
+ total_loss.backward()
133
+ if opt.clip_grad != 0:
134
+ nn.utils.clip_grad_norm_(u_net.parameters(), opt.clip_grad)
135
+ optimizer.step()
136
+
137
+ if local_rank == 0:
138
+ # printing and saving work
139
+ if itr % opt.print_freq == 0:
140
+ pprint.pprint(
141
+ "[step-{:08d}] [time-{:.3f}] [total_loss-{:.6f}] [loss0-{:.6f}]".format(
142
+ itr, time.time() - start_time, total_loss, loss0
143
+ )
144
+ )
145
+
146
+ if itr % opt.image_log_freq == 0:
147
+ d0 = F.log_softmax(d0, dim=1)
148
+ d0 = torch.max(d0, dim=1, keepdim=True)[1]
149
+ visuals = [[image, torch.unsqueeze(label, dim=1) * 85, d0 * 85]]
150
+ board_add_images(writer, "grid", visuals, itr)
151
+
152
+ writer.add_scalar("total_loss", total_loss, itr)
153
+ writer.add_scalar("loss0", loss0, itr)
154
+
155
+ if itr % opt.save_freq == 0:
156
+ save_checkpoints(opt, itr, u_net)
157
+
158
+ print("Training done!")
159
+ if local_rank == 0:
160
+ itr += 1
161
+ save_checkpoints(opt, itr, u_net)
162
+
163
+
164
+ if __name__ == "__main__":
165
+
166
+ opt = parser()
167
+
168
+ if opt.distributed:
169
+ if int(os.environ.get("LOCAL_RANK")) == 0:
170
+ options_printing_saving(opt)
171
+ else:
172
+ options_printing_saving(opt)
173
+
174
+ try:
175
+ if opt.distributed:
176
+ print("Initialize Process Group...")
177
+ torch.distributed.init_process_group(backend="nccl", init_method="env://")
178
+ synchronize()
179
+
180
+ set_seed(1000)
181
+ training_loop(opt)
182
+ cleanup(opt.distributed)
183
+ print("Exiting..............")
184
+
185
+ except KeyboardInterrupt:
186
+ cleanup(opt.distributed)
187
+
188
+ except Exception:
189
+ traceback.print_exc(file=sys.stdout)
190
+ cleanup(opt.distributed)
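To make the deep-supervision objective in the loop above concrete, here is a toy-sized, stand-alone reproduction of the loss computation (the class weights and the 1.5x weight on the fused output mirror the script; the tensors are random stand-ins):

```python
import torch
import torch.nn as nn

num_classes, h, w = 4, 32, 32
class_weights = torch.tensor([1.0, 1.5, 1.5, 1.5])  # background weighted lower than the cloth classes
loss_CE = nn.CrossEntropyLoss(weight=class_weights)

# d0 is the fused map, d1..d6 are the side outputs; all are scored against the same labels.
outputs = [torch.randn(2, num_classes, h, w) for _ in range(7)]
label = torch.randint(0, num_classes, (2, h, w))

losses = [loss_CE(d, label) for d in outputs]
total_loss = losses[0] * 1.5 + sum(losses[1:])  # the fused output gets the extra weight
print(float(total_loss))
```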
cloth_segmentation/utils/distributed.py ADDED
@@ -0,0 +1,47 @@
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import pickle
6
+ import torch
7
+ from torch import distributed as dist
8
+ from torch.utils.data.sampler import Sampler
9
+
10
+
11
+ def set_seed(seed):
12
+ torch.manual_seed(seed)
13
+ torch.cuda.manual_seed_all(seed)
14
+ torch.backends.cudnn.deterministic = True
15
+ torch.backends.cudnn.benchmark = False  # benchmarking can pick non-deterministic kernels
16
+ np.random.seed(seed)
17
+ random.seed(seed)
18
+ os.environ['PYTHONHASHSEED'] = str(seed)
19
+
20
+
21
+ def synchronize():
22
+ if not dist.is_available():
23
+ return
24
+
25
+ if not dist.is_initialized():
26
+ return
27
+
28
+ world_size = dist.get_world_size()
29
+ if world_size == 1:
30
+ return
31
+
32
+ dist.barrier()
33
+
34
+
35
+ def cleanup(distributed):
36
+ if distributed:
37
+ dist.destroy_process_group()
38
+
39
+
40
+ def get_world_size():
41
+ if not dist.is_available():
42
+ return 1
43
+
44
+ if not dist.is_initialized():
45
+ return 1
46
+
47
+ return dist.get_world_size()
cloth_segmentation/utils/saving_utils.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import copy
3
+ import cv2
4
+ import numpy as np
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+
9
+
10
+ def load_checkpoint(model, checkpoint_path):
11
+ if not os.path.exists(checkpoint_path):
12
+ print("----No checkpoints at given path----")
13
+ return
14
+ model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device("cpu")))
15
+ print("----checkpoints loaded from path: {}----".format(checkpoint_path))
16
+ return model
17
+
18
+
19
+ def load_checkpoint_mgpu(model, checkpoint_path):
20
+ if not os.path.exists(checkpoint_path):
21
+ print("----No checkpoints at given path----")
22
+ return
23
+ model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
24
+ new_state_dict = OrderedDict()
25
+ for k, v in model_state_dict.items():
26
+ name = k[7:] # remove `module.`
27
+ new_state_dict[name] = v
28
+
29
+ model.load_state_dict(new_state_dict)
30
+ print("----checkpoints loaded from path: {}----".format(checkpoint_path))
31
+ return model
32
+
33
+
34
+ def save_checkpoint(model, save_path):
35
+ print(save_path)
36
+ if not os.path.exists(os.path.dirname(save_path)):
37
+ os.makedirs(os.path.dirname(save_path))
38
+ torch.save(model.state_dict(), save_path)
39
+
40
+
41
+ def save_checkpoints(opt, itr, net):
42
+ save_checkpoint(
43
+ net,
44
+ os.path.join(opt.save_dir, "checkpoints", "itr_{:08d}_u2net.pth".format(itr)),
45
+ )
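`load_checkpoint_mgpu` exists because checkpoints saved from a wrapped model (`DistributedDataParallel` or `DataParallel`) carry a `module.` prefix on every key; a quick round-trip illustration (the wrapper and paths here are only for demonstration):

```python
import torch.nn as nn

from networks import U2NET
from utils.saving_utils import save_checkpoint, load_checkpoint_mgpu

net = U2NET(in_ch=3, out_ch=4)
wrapped = nn.DataParallel(net)  # prefixes every state-dict key with "module.", like DDP does

# Keys now look like "module.stage1.rebnconvin.conv_s1.weight".
save_checkpoint(wrapped, "tmp_checkpoints/u2net_wrapped.pth")

# load_checkpoint_mgpu strips the prefix so the bare network accepts the weights again.
restored = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), "tmp_checkpoints/u2net_wrapped.pth")
```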
cloth_segmentation/utils/tensorboard_utils.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from torch.utils.tensorboard import SummaryWriter
7
+
8
+ # Helpers for adding image grids to TensorBoard
9
+
10
+
11
+ def tensor_for_board(img_tensor):
12
+ # map into [0,1]
13
+ tensor = (img_tensor.clone()+1) * 0.5
14
+ tensor = tensor.cpu().clamp(0, 1)
15
+
16
+ if tensor.size(1) == 1:
17
+ tensor = tensor.repeat(1, 3, 1, 1)
18
+
19
+ return tensor
20
+
21
+
22
+ def tensor_list_for_board(img_tensors_list):
23
+ grid_h = len(img_tensors_list)
24
+ grid_w = max(len(img_tensors) for img_tensors in img_tensors_list)
25
+
26
+ batch_size, channel, height, width = tensor_for_board(
27
+ img_tensors_list[0][0]).size()
28
+ canvas_h = grid_h * height
29
+ canvas_w = grid_w * width
30
+ canvas = torch.FloatTensor(
31
+ batch_size, channel, canvas_h, canvas_w).fill_(0.5)
32
+ for i, img_tensors in enumerate(img_tensors_list):
33
+ for j, img_tensor in enumerate(img_tensors):
34
+ offset_h = i * height
35
+ offset_w = j * width
36
+ tensor = tensor_for_board(img_tensor)
37
+ canvas[:, :, offset_h: offset_h + height,
38
+ offset_w: offset_w + width].copy_(tensor)
39
+
40
+ return canvas
41
+
42
+
43
+ def board_add_image(board, tag_name, img_tensor, step_count):
44
+ tensor = tensor_for_board(img_tensor)
45
+
46
+ for i, img in enumerate(tensor):
47
+ board.add_image('%s/%03d' % (tag_name, i), img, step_count)
48
+
49
+
50
+ def board_add_images(board, tag_name, img_tensors_list, step_count):
51
+ tensor = tensor_list_for_board(img_tensors_list)
52
+
53
+ for i, img in enumerate(tensor):
54
+ board.add_image('%s/%03d' % (tag_name, i), img, step_count)
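For reference, a minimal sketch of how these helpers are driven from the training loop (random placeholder tensors, arbitrary log directory); each row of the grid holds the input image, the ground-truth mask and the prediction, scaled the same way train.py scales them:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

from utils.tensorboard_utils import board_add_images

writer = SummaryWriter("logs/demo")  # arbitrary log directory

# Input images live in [-1, 1]; masks are class indices scaled by 85 for visibility.
image = torch.rand(2, 3, 96, 96) * 2 - 1
label = torch.randint(0, 4, (2, 1, 96, 96)).float() * 85
pred = torch.randint(0, 4, (2, 1, 96, 96)).float() * 85

board_add_images(writer, "grid", [[image, label, pred]], step_count=0)
writer.close()
```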