lch01 committed on
Commit f4ba42f · 1 Parent(s): 28c1b3e

add dependencies

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. croco/LICENSE +52 -0
  2. croco/NOTICE +21 -0
  3. croco/README.MD +124 -0
  4. croco/assets/arch.jpg +0 -0
  5. croco/croco-stereo-flow-demo.ipynb +191 -0
  6. croco/datasets/__init__.py +0 -0
  7. croco/datasets/crops/README.MD +104 -0
  8. croco/datasets/crops/extract_crops_from_images.py +183 -0
  9. croco/datasets/habitat_sim/README.MD +76 -0
  10. croco/datasets/habitat_sim/__init__.py +0 -0
  11. croco/datasets/habitat_sim/generate_from_metadata.py +125 -0
  12. croco/datasets/habitat_sim/generate_from_metadata_files.py +36 -0
  13. croco/datasets/habitat_sim/generate_multiview_images.py +231 -0
  14. croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +501 -0
  15. croco/datasets/habitat_sim/pack_metadata_files.py +80 -0
  16. croco/datasets/habitat_sim/paths.py +179 -0
  17. croco/datasets/pairs_dataset.py +162 -0
  18. croco/datasets/transforms.py +135 -0
  19. croco/interactive_demo.ipynb +271 -0
  20. croco/models/__pycache__/blocks.cpython-310.pyc +0 -0
  21. croco/models/__pycache__/blocks.cpython-311.pyc +0 -0
  22. croco/models/__pycache__/blocks.cpython-312.pyc +0 -0
  23. croco/models/__pycache__/croco.cpython-310.pyc +0 -0
  24. croco/models/__pycache__/croco.cpython-311.pyc +0 -0
  25. croco/models/__pycache__/croco.cpython-312.pyc +0 -0
  26. croco/models/__pycache__/dpt_block.cpython-310.pyc +0 -0
  27. croco/models/__pycache__/dpt_block.cpython-311.pyc +0 -0
  28. croco/models/__pycache__/dpt_block.cpython-312.pyc +0 -0
  29. croco/models/__pycache__/masking.cpython-310.pyc +0 -0
  30. croco/models/__pycache__/masking.cpython-311.pyc +0 -0
  31. croco/models/__pycache__/masking.cpython-312.pyc +0 -0
  32. croco/models/__pycache__/pos_embed.cpython-310.pyc +0 -0
  33. croco/models/__pycache__/pos_embed.cpython-311.pyc +0 -0
  34. croco/models/__pycache__/pos_embed.cpython-312.pyc +0 -0
  35. croco/models/blocks.py +385 -0
  36. croco/models/criterion.py +38 -0
  37. croco/models/croco.py +330 -0
  38. croco/models/croco_downstream.py +141 -0
  39. croco/models/curope/__init__.py +4 -0
  40. croco/models/curope/__pycache__/__init__.cpython-310.pyc +0 -0
  41. croco/models/curope/__pycache__/__init__.cpython-311.pyc +0 -0
  42. croco/models/curope/__pycache__/__init__.cpython-312.pyc +0 -0
  43. croco/models/curope/__pycache__/curope2d.cpython-310.pyc +0 -0
  44. croco/models/curope/__pycache__/curope2d.cpython-311.pyc +0 -0
  45. croco/models/curope/__pycache__/curope2d.cpython-312.pyc +0 -0
  46. croco/models/curope/curope.cpp +69 -0
  47. croco/models/curope/curope2d.py +40 -0
  48. croco/models/curope/kernels.cu +108 -0
  49. croco/models/curope/setup.py +34 -0
  50. croco/models/dpt_block.py +513 -0
croco/LICENSE ADDED
@@ -0,0 +1,52 @@
+ CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
+
+ A summary of the CC BY-NC-SA 4.0 license is located here:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ The CC BY-NC-SA 4.0 license is located here:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
+
+
+ SEE NOTICE BELOW WITH RESPECT TO THE FILES: models/pos_embed.py, models/blocks.py
+
+ ***************************
+
+ NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
+
+ This software is being redistributed in a modified form. The original form is available here:
+
+ https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+ The software in this file incorporates parts of the following software, available here:
+
+ Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
+ available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
+
+ MoCo v3: https://github.com/facebookresearch/moco-v3
+ available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
+
+ DeiT: https://github.com/facebookresearch/deit
+ available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
+
+
+ THE ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE ARE REPRODUCED BELOW:
+
+ https://github.com/facebookresearch/mae/blob/main/LICENSE
+
+ Attribution-NonCommercial 4.0 International
+
+ ***************************
+
+ NOTICE WITH RESPECT TO THE FILE: models/blocks.py
+
+ This software is being redistributed in a modified form. The original form is available here:
+
+ https://github.com/rwightman/pytorch-image-models
+
+ THE ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE ARE REPRODUCED BELOW:
+
+ https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
croco/NOTICE ADDED
@@ -0,0 +1,21 @@
+ CroCo
+ Copyright 2022-present NAVER Corp.
+
+ This project contains subcomponents with separate copyright notices and license terms.
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
+
+ ====
+
+ facebookresearch/mae
+ https://github.com/facebookresearch/mae
+
+ Attribution-NonCommercial 4.0 International
+
+ ====
+
+ rwightman/pytorch-image-models
+ https://github.com/rwightman/pytorch-image-models
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
croco/README.MD ADDED
@@ -0,0 +1,124 @@
+ # CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
+
+ [[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
+
+ This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:
+
+ ![image](assets/arch.jpg)
+
+ ```bibtex
+ @inproceedings{croco,
+   title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
+   author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud, J\'er\^ome}},
+   booktitle={{NeurIPS}},
+   year={2022}
+ }
+
+ @inproceedings{croco_v2,
+   title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
+   author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
+   booktitle={ICCV},
+   year={2023}
+ }
+ ```
+
+ ## License
+
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
+ Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
+ Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
+
+ ## Preparation
+
+ 1. Install dependencies on a machine with an NVIDIA GPU, using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can skip the line installing it and use a more recent Python version.
+
+ ```bash
+ conda create -n croco python=3.7 cmake=3.14.0
+ conda activate croco
+ conda install habitat-sim headless -c conda-forge -c aihabitat
+ conda install pytorch torchvision -c pytorch
+ conda install notebook ipykernel matplotlib
+ conda install ipywidgets widgetsnbextension
+ conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
+
+ ```
+
+ 2. Compile CUDA kernels for RoPE
+
+ CroCo v2 relies on RoPE positional embeddings, for which you need to compile some CUDA kernels:
+ ```bash
+ cd models/curope/
+ python setup.py build_ext --inplace
+ cd ../../
+ ```
+
+ This can take a while, as we compile for all CUDA architectures; feel free to update L9 of `models/curope/setup.py` to compile only for specific architectures.
+ You might also need to set the environment variable `CUDA_HOME` if you use a custom CUDA installation.
+
+ If you cannot compile the CUDA kernels, a slower PyTorch implementation is also provided and will be loaded automatically.
+
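As a quick sanity check after building, the hedged sketch below simply tries to import the compiled extension; the module and class names (`models.curope`, `cuRoPE2D`) are assumptions based on the `models/curope/` files added in this commit, and if the import fails the slower PyTorch fallback mentioned above is what the model will end up using.

```python
# Hedged sketch (not part of the repository): run from the repository root after
# `python setup.py build_ext --inplace` to check whether the compiled kernels import.
try:
    # Module/class names below are assumptions based on models/curope/__init__.py.
    from models.curope import cuRoPE2D
    print("Compiled cuRoPE kernels available:", cuRoPE2D)
except ImportError as err:
    print("Compiled kernels unavailable, the slower PyTorch RoPE fallback will be used:", err)
```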
+ 3. Download a pre-trained model
+
+ We provide several pre-trained models:
+
+ | modelname | pre-training data | pos. embed. | Encoder | Decoder |
+ |------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
+ | [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
+ | [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
+ | [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
+ | [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
+
+ To download a specific model, e.g. the first one (`CroCo.pth`):
+ ```bash
+ mkdir -p pretrained_models/
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
+ ```
+
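Once downloaded, a checkpoint can be inspected and loaded roughly as follows; this is only a minimal sketch, and the key names `'model'` and `'croco_kwargs'` as well as the `CroCoNet` import path are assumptions, so check `demo.py` for the exact loading code used by the repository.

```python
# Minimal sketch (assumptions: the checkpoint is a dict storing the weights under
# 'model' and the constructor keyword arguments under 'croco_kwargs'; see demo.py
# for the loading code actually used by the repository).
import torch
from models.croco import CroCoNet

ckpt = torch.load('pretrained_models/CroCo.pth', map_location='cpu')
model = CroCoNet(**ckpt.get('croco_kwargs', {}))
load_msg = model.load_state_dict(ckpt['model'], strict=True)
print(load_msg)
model.eval()
```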
+ ## Reconstruction example
+
+ After downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or updating the corresponding line in `demo.py`), simply run:
+ ```bash
+ python demo.py
+ ```
+
+ ## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
+
+ First download the test scene from Habitat:
+ ```bash
+ python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
+ ```
+
+ Then, run the notebook demo `interactive_demo.ipynb`.
+
+ In this demo, you can sample a random reference viewpoint from a [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change the viewpoint and select a masked target view to reconstruct using CroCo.
+ ![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)
+
+ ## Pre-training
+
+ ### CroCo
+
+ To pre-train CroCo, first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD), and then run the following command:
+ ```
+ torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
+ ```
+
+ Our CroCo pre-training was launched on a single server with 4 GPUs.
+ The 400 pre-training epochs should take around 10 days on A100 GPUs or 15 days on V100 GPUs, but decent performance is obtained earlier in training.
+ Note that, while the code applies the same learning-rate scaling rule as MAE when the effective batch size changes, we did not verify experimentally that it remains valid in our case.
+ The first run can take a few minutes to start, as it parses all available pre-training pairs.
+
+ ### CroCo v2
+
+ For CroCo v2 pre-training, in addition to generating the pre-training data from the Habitat simulator as above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
+ Then, run the following command for the largest model (ViT-L encoder, Base decoder):
+ ```
+ torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
+ ```
+
+ Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per GPU in all cases.
+ The largest model should take around 12 days on A100 GPUs.
+ Note that, while the code applies the same learning-rate scaling rule as MAE when the effective batch size changes, we did not verify experimentally that it remains valid in our case.
+
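For reference, the MAE-style scaling rule mentioned above is linear in the effective batch size; the sketch below is illustrative only, with variable names and the example base learning rate being assumptions rather than values taken from `pretrain.py`.

```python
# Illustrative only: MAE-style linear learning-rate scaling with the effective batch size.
def scaled_lr(base_lr, batch_size_per_gpu, num_gpus, accum_iter=1):
    effective_batch_size = batch_size_per_gpu * num_gpus * accum_iter
    return base_lr * effective_batch_size / 256

# e.g. 8 GPUs x 64 images per GPU -> effective batch size 512 -> the base lr is doubled
print(scaled_lr(base_lr=1.5e-4, batch_size_per_gpu=64, num_gpus=8))
```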
+ ## Stereo matching and Optical flow downstream tasks
+
+ For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
croco/assets/arch.jpg ADDED
croco/croco-stereo-flow-demo.ipynb ADDED
@@ -0,0 +1,191 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9bca0f41",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Simple inference example with CroCo-Stereo or CroCo-Flow"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "80653ef7",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
19
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "4f033862",
25
+ "metadata": {},
26
+ "source": [
27
+ "First download the model(s) of your choice by running\n",
28
+ "```\n",
29
+ "bash stereoflow/download_model.sh crocostereo.pth\n",
30
+ "bash stereoflow/download_model.sh crocoflow.pth\n",
31
+ "```"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "1fb2e392",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
+ "import torch\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
+ "import matplotlib.pylab as plt"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "e0e25d77",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "from stereoflow.test import _load_model_and_criterion\n",
55
+ "from stereoflow.engine import tiled_pred\n",
56
+ "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
57
+ "from stereoflow.datasets_flow import flowToColor\n",
58
+ "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "id": "86a921f5",
64
+ "metadata": {},
65
+ "source": [
66
+ "### CroCo-Stereo example"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "64e483cb",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
77
+ "image2 = np.asarray(Image.open('<path_to_right_image>'))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "f0d04303",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "47dc14b5",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
98
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
99
+ "with torch.inference_mode():\n",
100
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
101
+ "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "583b9f16",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "plt.imshow(vis_disparity(pred))\n",
112
+ "plt.axis('off')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "id": "d2df5d70",
118
+ "metadata": {},
119
+ "source": [
120
+ "### CroCo-Flow example"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "9ee257a7",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
131
+ "image2 = np.asarray(Image.open('<path_to_second_image>'))"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "d5edccf0",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "b19692c3",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
152
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
153
+ "with torch.inference_mode():\n",
154
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
155
+ "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "26f79db3",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "plt.imshow(flowToColor(pred))\n",
166
+ "plt.axis('off')"
167
+ ]
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "Python 3 (ipykernel)",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.9.7"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 5
191
+ }
croco/datasets/__init__.py ADDED
File without changes
croco/datasets/crops/README.MD ADDED
@@ -0,0 +1,104 @@
+ ## Generation of crops from the real datasets
+
+ The instructions below allow you to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
+
+ ### Download the metadata of the crops to generate
+
+ First, download the metadata and put it in `./data/`:
+ ```
+ mkdir -p data
+ cd data/
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
+ unzip crop_metadata.zip
+ rm crop_metadata.zip
+ cd ..
+ ```
+
+ ### Prepare the original datasets
+
+ Second, download the original datasets into `./data/original_datasets/`:
+ ```
+ mkdir -p data/original_datasets
+ ```
+
+ ##### ARKitScenes
+
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
+ The resulting file structure should look like:
+ ```
+ ./data/original_datasets/ARKitScenes/
+ └───Training
+     └───40753679
+     │   │ ultrawide
+     │   │ ...
+     └───40753686
+
+     ...
+ ```
+
+ ##### MegaDepth
+
+ Download the `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
+ The resulting file structure should look like:
+
+ ```
+ ./data/original_datasets/MegaDepth/
+ └───0000
+ │   └───images
+ │   │   │ 1000557903_87fa96b8a4_o.jpg
+ │   │   └ ...
+ │   └─── ...
+ └───0001
+ │   │
+ │   └ ...
+ └─── ...
+ ```
+
+ ##### 3DStreetView
+
+ Download the `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
+ The resulting file structure should look like:
+
+ ```
+ ./data/original_datasets/3DStreetView/
+ └───dataset_aligned
+ │   └───0002
+ │   │   │ 0000002_0000001_0000002_0000001.jpg
+ │   │   └ ...
+ │   └─── ...
+ └───dataset_unaligned
+ │   └───0003
+ │   │   │ 0000003_0000001_0000002_0000001.jpg
+ │   │   └ ...
+ │   └─── ...
+ ```
+
+ ##### IndoorVL
+
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
+
+ ```
+ pip install kapture
+ mkdir -p ./data/original_datasets/IndoorVL
+ cd ./data/original_datasets/IndoorVL
+ kapture_download_dataset.py update
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
+ kapture_download_dataset.py install "GangnamStation_*"
+ cd -
+ ```
+
+ ### Extract the crops
+
+ Now, extract the crops for each of the datasets:
+ ```
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
+ do
+   python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
+ done
+ ```
+
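For reference, the crop list files passed via `--crops` are plain text; the layout below is inferred from `load_crop_file` in `datasets/crops/extract_crops_from_images.py` (comment lines start with `#`, an image-pair line gives the two relative image paths and a rotation in degrees, and each following line of eight integers describes one crop as left/right/top/bottom bounds in the first and second image). The file paths and numbers shown are purely illustrative.

```
# pair line:  <img1 relative path>, <img2 relative path>, <rotation in degrees>
# crop line:  l1, r1, t1, b1, l2, r2, t2, b2   (bounds in image 1 and image 2)
Training/40753679/ultrawide/frame_000001.jpg, Training/40753679/ultrawide/frame_000042.jpg, 0
128, 384, 96, 352, 140, 396, 80, 336
512, 768, 200, 456, 500, 756, 190, 446
```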
+ ##### Note for IndoorVL
+
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
+ To compensate for this in terms of the number of pre-training iterations, the pre-training command in this repository uses 125 training epochs (including 12 warm-up epochs) and a cosine learning-rate schedule over 250 epochs, instead of 100, 10 and 200 respectively.
+ The impact on the performance is negligible.
croco/datasets/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,183 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Extracting crops for pre-training
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ import functools
13
+ from multiprocessing import Pool
14
+ import math
15
+
16
+
17
+ def arg_parser():
18
+ parser = argparse.ArgumentParser(
19
+ "Generate cropped image pairs from image crop list"
20
+ )
21
+
22
+ parser.add_argument("--crops", type=str, required=True, help="crop file")
23
+ parser.add_argument("--root-dir", type=str, required=True, help="root directory")
24
+ parser.add_argument(
25
+ "--output-dir", type=str, required=True, help="output directory"
26
+ )
27
+ parser.add_argument("--imsize", type=int, default=256, help="size of the crops")
28
+ parser.add_argument(
29
+ "--nthread", type=int, required=True, help="number of simultaneous threads"
30
+ )
31
+ parser.add_argument(
32
+ "--max-subdir-levels",
33
+ type=int,
34
+ default=5,
35
+ help="maximum number of subdirectories",
36
+ )
37
+ parser.add_argument(
38
+ "--ideal-number-pairs-in-dir",
39
+ type=int,
40
+ default=500,
41
+ help="number of pairs stored in a dir",
42
+ )
43
+ return parser
44
+
45
+
46
+ def main(args):
47
+ listing_path = os.path.join(args.output_dir, "listing.txt")
48
+
49
+ print(f"Loading list of crops ... ({args.nthread} threads)")
50
+ crops, num_crops_to_generate = load_crop_file(args.crops)
51
+
52
+ print(f"Preparing jobs ({len(crops)} candidate image pairs)...")
53
+ num_levels = min(
54
+ math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)),
55
+ args.max_subdir_levels,
56
+ )
57
+ num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels))
58
+
59
+ jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
60
+ del crops
61
+
62
+ os.makedirs(args.output_dir, exist_ok=True)
63
+ mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
64
+ call = functools.partial(save_image_crops, args)
65
+
66
+ print(f"Generating cropped images to {args.output_dir} ...")
67
+ with open(listing_path, "w") as listing:
68
+ listing.write("# pair_path\n")
69
+ for results in tqdm(mmap(call, jobs), total=len(jobs)):
70
+ for path in results:
71
+ listing.write(f"{path}\n")
72
+ print("Finished writing listing to", listing_path)
73
+
74
+
75
+ def load_crop_file(path):
76
+ data = open(path).read().splitlines()
77
+ pairs = []
78
+ num_crops_to_generate = 0
79
+ for line in tqdm(data):
80
+ if line.startswith("#"):
81
+ continue
82
+ line = line.split(", ")
83
+ if len(line) < 8:
84
+ img1, img2, rotation = line
85
+ pairs.append((img1, img2, int(rotation), []))
86
+ else:
87
+ l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
88
+ rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
89
+ pairs[-1][-1].append((rect1, rect2))
90
+ num_crops_to_generate += 1
91
+ return pairs, num_crops_to_generate
92
+
93
+
94
+ def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
95
+ jobs = []
96
+ powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
97
+
98
+ def get_path(idx):
99
+ idx_array = []
100
+ d = idx
101
+ for level in range(num_levels - 1):
102
+ idx_array.append(idx // powers[level])
103
+ idx = idx % powers[level]
104
+ idx_array.append(d)
105
+ return "/".join(map(lambda x: hex(x)[2:], idx_array))
106
+
107
+ idx = 0
108
+ for pair_data in tqdm(pairs):
109
+ img1, img2, rotation, crops = pair_data
110
+ if -60 <= rotation and rotation <= 60:
111
+ rotation = 0 # most likely not a true rotation
112
+ paths = [get_path(idx + k) for k in range(len(crops))]
113
+ idx += len(crops)
114
+ jobs.append(((img1, img2), rotation, crops, paths))
115
+ return jobs
116
+
117
+
118
+ def load_image(path):
119
+ try:
120
+ return Image.open(path).convert("RGB")
121
+ except Exception as e:
122
+ print("skipping", path, e)
123
+ raise OSError()
124
+
125
+
126
+ def save_image_crops(args, data):
127
+ # load images
128
+ img_pair, rot, crops, paths = data
129
+ try:
130
+ img1, img2 = [
131
+ load_image(os.path.join(args.root_dir, impath)) for impath in img_pair
132
+ ]
133
+ except OSError as e:
134
+ return []
135
+
136
+ def area(sz):
137
+ return sz[0] * sz[1]
138
+
139
+ tgt_size = (args.imsize, args.imsize)
140
+
141
+ def prepare_crop(img, rect, rot=0):
142
+ # actual crop
143
+ img = img.crop(rect)
144
+
145
+ # resize to desired size
146
+ interp = (
147
+ Image.Resampling.LANCZOS
148
+ if area(img.size) > 4 * area(tgt_size)
149
+ else Image.Resampling.BICUBIC
150
+ )
151
+ img = img.resize(tgt_size, resample=interp)
152
+
153
+ # rotate the image
154
+ rot90 = (round(rot / 90) % 4) * 90
155
+ if rot90 == 90:
156
+ img = img.transpose(Image.Transpose.ROTATE_90)
157
+ elif rot90 == 180:
158
+ img = img.transpose(Image.Transpose.ROTATE_180)
159
+ elif rot90 == 270:
160
+ img = img.transpose(Image.Transpose.ROTATE_270)
161
+ return img
162
+
163
+ results = []
164
+ for (rect1, rect2), path in zip(crops, paths):
165
+ crop1 = prepare_crop(img1, rect1)
166
+ crop2 = prepare_crop(img2, rect2, rot)
167
+
168
+ fullpath1 = os.path.join(args.output_dir, path + "_1.jpg")
169
+ fullpath2 = os.path.join(args.output_dir, path + "_2.jpg")
170
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
171
+
172
+ assert not os.path.isfile(fullpath1), fullpath1
173
+ assert not os.path.isfile(fullpath2), fullpath2
174
+ crop1.save(fullpath1)
175
+ crop2.save(fullpath2)
176
+ results.append(path)
177
+
178
+ return results
179
+
180
+
181
+ if __name__ == "__main__":
182
+ args = arg_parser().parse_args()
183
+ main(args)
croco/datasets/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
+ ## Generation of synthetic image pairs using Habitat-Sim
+
+ These instructions allow you to generate pre-training pairs from the Habitat simulator.
+ As we did not save the metadata of the pairs used in the original paper, the pairs are not strictly the same, but these data use the same settings and are equivalent.
+
+ ### Download Habitat-Sim scenes
+ Download Habitat-Sim scenes:
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCAD and ScanNet datasets.
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or manually update the paths in `paths.py`.
+ ```
+ ./data/
+ └──habitat-sim-data/
+    └──scene_datasets/
+       ├──hm3d/
+       ├──gibson/
+       ├──habitat-test-scenes/
+       ├──replica_cad_baked_lighting/
+       ├──replica_cad/
+       ├──ReplicaDataset/
+       └──scannet/
+ ```
+
+ ### Image pairs generation
+ We provide metadata to generate reproducible image pairs for pre-training and validation.
+ Experiments described in the paper used similar data, but their generation was not reproducible at the time.
+
+ Specifications:
+ - 256x256 resolution images, with a 60-degree field of view.
+ - Up to 1000 image pairs per scene.
+ - Number of scenes considered / number of image pairs per dataset:
+   - ScanNet: 1097 scenes / 985,209 pairs
+   - HM3D:
+     - hm3d/train: 800 scenes / 800k pairs
+     - hm3d/val: 100 scenes / 100k pairs
+     - hm3d/minival: 10 scenes / 10k pairs
+   - habitat-test-scenes: 3 scenes / 3k pairs
+   - replica_cad_baked_lighting: 13 scenes / 13k pairs
+
+ - Scenes from the hm3d/val and hm3d/minival pairs were not used for pre-training but kept for validation purposes.
+
+ Download the metadata and extract it:
+ ```bash
+ mkdir -p data/habitat_release_metadata/
+ cd data/habitat_release_metadata/
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
+ tar -xvf multiview_habitat_metadata.tar.gz
+ cd ../..
+ # Location of the metadata
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
+ ```
+
+ Generate image pairs from the metadata:
+ - The following command will print a list of command lines to generate image pairs for each scene (see the sketch after this list for reading a generated pair back):
+ ```bash
+ # Target output directory
+ PAIRS_DATASET_DIR="./data/habitat_release/"
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
+ ```
+ - One can launch several such commands in parallel, e.g. using GNU Parallel:
+ ```bash
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
+ ```
+
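As a hedged sketch of what the generated data looks like on disk: the file naming and JSON keys below are taken from the generator scripts added in this commit, while the example pair path is purely illustrative, and recent OpenCV builds may require `OPENCV_IO_ENABLE_OPENEXR=1` to read EXR files.

```python
# Hedged sketch (not part of the repository): read back one generated view.
# The generators write <pair_id>_<view>.jpeg, <pair_id>_<view>_depth.exr and
# <pair_id>_<view>_camera_params.json; the path below is illustrative only.
import json, os
import numpy as np

os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")  # may be needed for EXR support
import cv2
from PIL import Image

prefix = "./data/habitat_release/some_scene/00000000_1"  # adjust to an actual generated view
color = np.asarray(Image.open(prefix + ".jpeg"))
depth = cv2.imread(prefix + "_depth.exr", cv2.IMREAD_ANYDEPTH)
with open(prefix + "_camera_params.json") as f:
    params = json.load(f)
K = np.asarray(params["camera_intrinsics"])   # 3x3 pinhole intrinsics
R = np.asarray(params["R_cam2world"])         # camera-to-world rotation (OpenCV convention)
t = np.asarray(params["t_cam2world"])         # camera position in world coordinates
print(color.shape, depth.shape, K)
```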
+ ## Metadata generation
+
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
+ ```bash
+ # Print command lines to generate image pairs from the different scenes available.
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
+
+ # Once a dataset is generated, pack the metadata files for reproducibility.
+ METADATA_DIR=MY_CUSTOM_PATH
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
+ ```
croco/datasets/habitat_sim/__init__.py ADDED
File without changes
croco/datasets/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
6
+ """
7
+ import os
8
+ from datasets.habitat_sim.multiview_habitat_sim_generator import (
9
+ MultiviewHabitatSimGenerator,
10
+ )
11
+ from datasets.habitat_sim.paths import SCENES_DATASET
12
+ import argparse
13
+ import quaternion
14
+ import PIL.Image
15
+ import cv2
16
+ import json
17
+ from tqdm import tqdm
18
+
19
+
20
+ def generate_multiview_images_from_metadata(
21
+ metadata_filename,
22
+ output_dir,
23
+ overload_params=dict(),
24
+ scene_datasets_paths=None,
25
+ exist_ok=False,
26
+ ):
27
+ """
28
+ Generate images from a metadata file for reproducibility purposes.
29
+ """
30
+ # Reorder paths by decreasing label length, to avoid collisions when testing whether a string starts with such a label
31
+ if scene_datasets_paths is not None:
32
+ scene_datasets_paths = dict(
33
+ sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True)
34
+ )
35
+
36
+ with open(metadata_filename, "r") as f:
37
+ input_metadata = json.load(f)
38
+ metadata = dict()
39
+ for key, value in input_metadata.items():
40
+ # Optionally replace some paths
41
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
42
+ if scene_datasets_paths is not None:
43
+ for dataset_label, dataset_path in scene_datasets_paths.items():
44
+ if value.startswith(dataset_label):
45
+ value = os.path.normpath(
46
+ os.path.join(
47
+ dataset_path, os.path.relpath(value, dataset_label)
48
+ )
49
+ )
50
+ break
51
+ metadata[key] = value
52
+
53
+ # Overload some parameters
54
+ for key, value in overload_params.items():
55
+ metadata[key] = value
56
+
57
+ generation_entries = dict(
58
+ [
59
+ (key, value)
60
+ for key, value in metadata.items()
61
+ if not (key in ("multiviews", "output_dir", "generate_depth"))
62
+ ]
63
+ )
64
+ generate_depth = metadata["generate_depth"]
65
+
66
+ os.makedirs(output_dir, exist_ok=exist_ok)
67
+
68
+ generator = MultiviewHabitatSimGenerator(**generation_entries)
69
+
70
+ # Generate views
71
+ for idx_label, data in tqdm(metadata["multiviews"].items()):
72
+ positions = data["positions"]
73
+ orientations = data["orientations"]
74
+ n = len(positions)
75
+ for oidx in range(n):
76
+ observation = generator.render_viewpoint(
77
+ positions[oidx], quaternion.from_float_array(orientations[oidx])
78
+ )
79
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
80
+ # Color image saved using PIL
81
+ img = PIL.Image.fromarray(observation["color"][:, :, :3])
82
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
83
+ img.save(filename)
84
+ if generate_depth:
85
+ # Depth image as EXR file
86
+ filename = os.path.join(
87
+ output_dir, f"{idx_label}_{observation_label}_depth.exr"
88
+ )
89
+ cv2.imwrite(
90
+ filename,
91
+ observation["depth"],
92
+ [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
93
+ )
94
+ # Camera parameters
95
+ camera_params = dict(
96
+ [
97
+ (key, observation[key].tolist())
98
+ for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")
99
+ ]
100
+ )
101
+ filename = os.path.join(
102
+ output_dir, f"{idx_label}_{observation_label}_camera_params.json"
103
+ )
104
+ with open(filename, "w") as f:
105
+ json.dump(camera_params, f)
106
+ # Save metadata
107
+ with open(os.path.join(output_dir, "metadata.json"), "w") as f:
108
+ json.dump(metadata, f)
109
+
110
+ generator.close()
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser()
115
+ parser.add_argument("--metadata_filename", required=True)
116
+ parser.add_argument("--output_dir", required=True)
117
+ args = parser.parse_args()
118
+
119
+ generate_multiview_images_from_metadata(
120
+ metadata_filename=args.metadata_filename,
121
+ output_dir=args.output_dir,
122
+ scene_datasets_paths=SCENES_DATASET,
123
+ overload_params=dict(),
124
+ exist_ok=True,
125
+ )
croco/datasets/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,36 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ """
+ Script generating command lines to produce image pairs from metadata files.
+ """
+ import os
+ import glob
+ from tqdm import tqdm
+ import argparse
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--input_dir", required=True)
+     parser.add_argument("--output_dir", required=True)
+     parser.add_argument(
+         "--prefix",
+         default="",
+         help="Command-line prefix, useful e.g. to set up the environment.",
+     )
+     args = parser.parse_args()
+
+     input_metadata_filenames = glob.iglob(
+         f"{args.input_dir}/**/metadata.json", recursive=True
+     )
+
+     for metadata_filename in tqdm(input_metadata_filenames):
+         output_dir = os.path.join(
+             args.output_dir,
+             os.path.relpath(os.path.dirname(metadata_filename), args.input_dir),
+         )
+         # Do not process the scene if the metadata file already exists
+         if os.path.exists(os.path.join(output_dir, "metadata.json")):
+             continue
+         commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
+         print(commandline)
croco/datasets/habitat_sim/generate_multiview_images.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import PIL.Image
8
+ import numpy as np
9
+ import json
10
+ from datasets.habitat_sim.multiview_habitat_sim_generator import (
11
+ MultiviewHabitatSimGenerator,
12
+ NoNaviguableSpaceError,
13
+ )
14
+ from datasets.habitat_sim.paths import list_scenes_available
15
+ import cv2
16
+ import quaternion
17
+ import shutil
18
+
19
+
20
+ def generate_multiview_images_for_scene(
21
+ scene_dataset_config_file,
22
+ scene,
23
+ navmesh,
24
+ output_dir,
25
+ views_count,
26
+ size,
27
+ exist_ok=False,
28
+ generate_depth=False,
29
+ **kwargs,
30
+ ):
31
+ """
32
+ Generate tuples of overlapping views for a given scene.
33
+ generate_depth: generate depth images and camera parameters.
34
+ """
35
+ if os.path.exists(output_dir) and not exist_ok:
36
+ print(f"Scene {scene}: data already generated. Ignoring generation.")
37
+ return
38
+ try:
39
+ print(f"Scene {scene}: {size} multiview acquisitions to generate...")
40
+ os.makedirs(output_dir, exist_ok=exist_ok)
41
+
42
+ metadata_filename = os.path.join(output_dir, "metadata.json")
43
+
44
+ metadata_template = dict(
45
+ scene_dataset_config_file=scene_dataset_config_file,
46
+ scene=scene,
47
+ navmesh=navmesh,
48
+ views_count=views_count,
49
+ size=size,
50
+ generate_depth=generate_depth,
51
+ **kwargs,
52
+ )
53
+ metadata_template["multiviews"] = dict()
54
+
55
+ if os.path.exists(metadata_filename):
56
+ print("Metadata file already exists:", metadata_filename)
57
+ print("Loading already generated metadata file...")
58
+ with open(metadata_filename, "r") as f:
59
+ metadata = json.load(f)
60
+
61
+ for key in metadata_template.keys():
62
+ if key != "multiviews":
63
+ assert (
64
+ metadata_template[key] == metadata[key]
65
+ ), f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
66
+ else:
67
+ print("No temporary file found. Starting generation from scratch...")
68
+ metadata = metadata_template
69
+
70
+ starting_id = len(metadata["multiviews"])
71
+ print(f"Starting generation from index {starting_id}/{size}...")
72
+ if starting_id >= size:
73
+ print("Generation already done.")
74
+ return
75
+
76
+ generator = MultiviewHabitatSimGenerator(
77
+ scene_dataset_config_file=scene_dataset_config_file,
78
+ scene=scene,
79
+ navmesh=navmesh,
80
+ views_count=views_count,
81
+ size=size,
82
+ **kwargs,
83
+ )
84
+
85
+ for idx in tqdm(range(starting_id, size)):
86
+ # Generate / re-generate the observations
87
+ try:
88
+ data = generator[idx]
89
+ observations = data["observations"]
90
+ positions = data["positions"]
91
+ orientations = data["orientations"]
92
+
93
+ idx_label = f"{idx:08}"
94
+ for oidx, observation in enumerate(observations):
95
+ observation_label = (
96
+ f"{oidx + 1}" # Leonid is indexing starting from 1
97
+ )
98
+ # Color image saved using PIL
99
+ img = PIL.Image.fromarray(observation["color"][:, :, :3])
100
+ filename = os.path.join(
101
+ output_dir, f"{idx_label}_{observation_label}.jpeg"
102
+ )
103
+ img.save(filename)
104
+ if generate_depth:
105
+ # Depth image as EXR file
106
+ filename = os.path.join(
107
+ output_dir, f"{idx_label}_{observation_label}_depth.exr"
108
+ )
109
+ cv2.imwrite(
110
+ filename,
111
+ observation["depth"],
112
+ [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
113
+ )
114
+ # Camera parameters
115
+ camera_params = dict(
116
+ [
117
+ (key, observation[key].tolist())
118
+ for key in (
119
+ "camera_intrinsics",
120
+ "R_cam2world",
121
+ "t_cam2world",
122
+ )
123
+ ]
124
+ )
125
+ filename = os.path.join(
126
+ output_dir,
127
+ f"{idx_label}_{observation_label}_camera_params.json",
128
+ )
129
+ with open(filename, "w") as f:
130
+ json.dump(camera_params, f)
131
+ metadata["multiviews"][idx_label] = {
132
+ "positions": positions.tolist(),
133
+ "orientations": orientations.tolist(),
134
+ "covisibility_ratios": data["covisibility_ratios"].tolist(),
135
+ "valid_fractions": data["valid_fractions"].tolist(),
136
+ "pairwise_visibility_ratios": data[
137
+ "pairwise_visibility_ratios"
138
+ ].tolist(),
139
+ }
140
+ except RecursionError:
141
+ print(
142
+ "Recursion error: unable to sample observations for this scene. We will stop there."
143
+ )
144
+ break
145
+
146
+ # Regularly save a temporary metadata file, in case we need to restart the generation
147
+ if idx % 10 == 0:
148
+ with open(metadata_filename, "w") as f:
149
+ json.dump(metadata, f)
150
+
151
+ # Save metadata
152
+ with open(metadata_filename, "w") as f:
153
+ json.dump(metadata, f)
154
+
155
+ generator.close()
156
+ except NoNaviguableSpaceError:
157
+ pass
158
+
159
+
160
+ def create_commandline(scene_data, generate_depth, exist_ok=False):
161
+ """
162
+ Create a commandline string to generate a scene.
163
+ """
164
+
165
+ def my_formatting(val):
166
+ if val is None or val == "":
167
+ return '""'
168
+ else:
169
+ return val
170
+
171
+ commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
172
+ --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
173
+ --navmesh {my_formatting(scene_data.navmesh)}
174
+ --output_dir {my_formatting(scene_data.output_dir)}
175
+ --generate_depth {int(generate_depth)}
176
+ --exist_ok {int(exist_ok)}
177
+ """
178
+ commandline = " ".join(commandline.split())
179
+ return commandline
180
+
181
+
182
+ if __name__ == "__main__":
183
+ os.umask(2)
184
+
185
+ parser = argparse.ArgumentParser(
186
+ description="""Example of use -- listing commands to generate data for scenes available:
187
+ > python datasets/habitat_sim/generate_multiview_images.py --list_commands
188
+ """
189
+ )
190
+
191
+ parser.add_argument("--output_dir", type=str, required=True)
192
+ parser.add_argument(
193
+ "--list_commands", action="store_true", help="list commandlines to run if true"
194
+ )
195
+ parser.add_argument("--scene", type=str, default="")
196
+ parser.add_argument("--scene_dataset_config_file", type=str, default="")
197
+ parser.add_argument("--navmesh", type=str, default="")
198
+
199
+ parser.add_argument("--generate_depth", type=int, default=1)
200
+ parser.add_argument("--exist_ok", type=int, default=0)
201
+
202
+ kwargs = dict(resolution=(256, 256), hfov=60, views_count=2, size=1000)
203
+
204
+ args = parser.parse_args()
205
+ generate_depth = bool(args.generate_depth)
206
+ exist_ok = bool(args.exist_ok)
207
+
208
+ if args.list_commands:
209
+ # Listing scenes available...
210
+ scenes_data = list_scenes_available(base_output_dir=args.output_dir)
211
+
212
+ for scene_data in scenes_data:
213
+ print(
214
+ create_commandline(
215
+ scene_data, generate_depth=generate_depth, exist_ok=exist_ok
216
+ )
217
+ )
218
+ else:
219
+ if args.scene == "" or args.output_dir == "":
220
+ print("Missing scene or output dir argument!")
221
+ print(parser.format_help())
222
+ else:
223
+ generate_multiview_images_for_scene(
224
+ scene=args.scene,
225
+ scene_dataset_config_file=args.scene_dataset_config_file,
226
+ navmesh=args.navmesh,
227
+ output_dir=args.output_dir,
228
+ exist_ok=exist_ok,
229
+ generate_depth=generate_depth,
230
+ **kwargs,
231
+ )
croco/datasets/habitat_sim/multiview_habitat_sim_generator.py ADDED
@@ -0,0 +1,501 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ import numpy as np
6
+ import quaternion
7
+ import habitat_sim
8
+ import json
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import cv2
11
+
12
+ # OpenCV to habitat camera convention transformation
13
+ R_OPENCV2HABITAT = np.stack(
14
+ (habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0
15
+ )
16
+ R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
17
+ DEG2RAD = np.pi / 180
18
+
19
+
20
+ def compute_camera_intrinsics(height, width, hfov):
21
+ f = width / 2 / np.tan(hfov / 2 * np.pi / 180)
22
+ cu, cv = width / 2, height / 2
23
+ return f, cu, cv
24
+
25
+
26
+ def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
27
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
28
+ t_cam2world = np.asarray(camera_position)
29
+ return R_cam2world, t_cam2world
30
+
31
+
32
+ def compute_pointmap(depthmap, hfov):
33
+ """Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
34
+ height, width = depthmap.shape
35
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
36
+ # Cast depth map to point
37
+ z_cam = depthmap
38
+ u, v = np.meshgrid(range(width), range(height))
39
+ x_cam = (u - cu) / f * z_cam
40
+ y_cam = (v - cv) / f * z_cam
41
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
42
+ return X_cam
43
+
44
+
45
+ def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
46
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
47
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(
48
+ camera_position, camera_rotation
49
+ )
50
+
51
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
52
+ valid_mask = X_cam[:, :, 2] != 0.0
53
+
54
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
55
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
56
+ return X_world
57
+
58
+
59
+ def compute_pointcloud_overlaps_scikit(
60
+ pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False
61
+ ):
62
+ """
63
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
64
+ """
65
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud2)
66
+ distances, indices = nbrs.kneighbors(pointcloud1)
67
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
68
+
69
+ data = {"intersection1": intersection1, "size1": len(pointcloud1)}
70
+ if compute_symmetric:
71
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud1)
72
+ distances, indices = nbrs.kneighbors(pointcloud2)
73
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
74
+ data["intersection2"] = intersection2
75
+ data["size2"] = len(pointcloud2)
76
+
77
+ return data
78
+
79
+
80
+ def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
81
+ """
82
+ Add camera parameters to the observation dictionary produced by Habitat-Sim.
83
+ In-place modifications.
84
+ """
85
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(
86
+ camera_location, camera_rotation
87
+ )
88
+ height, width = observation["depth"].shape
89
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
90
+ K = np.asarray([[f, 0, cu], [0, f, cv], [0, 0, 1.0]])
91
+ observation["camera_intrinsics"] = K
92
+ observation["t_cam2world"] = t_cam2world
93
+ observation["R_cam2world"] = R_cam2world
94
+
95
+
96
+ def look_at(eye, center, up, return_cam2world=True):
97
+ """
98
+ Return camera pose looking at a given center point.
99
+ Analogue of the gluLookAt function, using the OpenCV camera convention.
100
+ """
101
+ z = center - eye
102
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
103
+ y = -up
104
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
105
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
106
+ x = np.cross(y, z, axis=-1)
107
+
108
+ if return_cam2world:
109
+ R = np.stack((x, y, z), axis=-1)
110
+ t = eye
111
+ else:
112
+ # World to camera transformation
113
+ # Transposed matrix
114
+ R = np.stack((x, y, z), axis=-2)
115
+ t = -np.einsum("...ij, ...j", R, eye)
116
+ return R, t
117
+
118
+
119
+ def look_at_for_habitat(eye, center, up, return_cam2world=True):
120
+ R, t = look_at(eye, center, up)
121
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
122
+ return orientation, t
123
+
124
+
125
+ def generate_orientation_noise(pan_range, tilt_range, roll_range):
126
+ return (
127
+ quaternion.from_rotation_vector(
128
+ np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP
129
+ )
130
+ * quaternion.from_rotation_vector(
131
+ np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT
132
+ )
133
+ * quaternion.from_rotation_vector(
134
+ np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT
135
+ )
136
+ )
137
+
138
+
139
+ class NoNaviguableSpaceError(RuntimeError):
140
+ def __init__(self, *args):
141
+ super().__init__(*args)
142
+
143
+
144
+ class MultiviewHabitatSimGenerator:
145
+ def __init__(
146
+ self,
147
+ scene,
148
+ navmesh,
149
+ scene_dataset_config_file,
150
+ resolution=(240, 320),
151
+ views_count=2,
152
+ hfov=60,
153
+ gpu_id=0,
154
+ size=10000,
155
+ minimum_covisibility=0.5,
156
+ transform=None,
157
+ ):
158
+ self.scene = scene
159
+ self.navmesh = navmesh
160
+ self.scene_dataset_config_file = scene_dataset_config_file
161
+ self.resolution = resolution
162
+ self.views_count = views_count
163
+ assert self.views_count >= 1
164
+ self.hfov = hfov
165
+ self.gpu_id = gpu_id
166
+ self.size = size
167
+ self.transform = transform
168
+
169
+ # Noise added to camera orientation
170
+ self.pan_range = (-3, 3)
171
+ self.tilt_range = (-10, 10)
172
+ self.roll_range = (-5, 5)
173
+
174
+ # Height range to sample cameras
175
+ self.height_range = (1.2, 1.8)
176
+
177
+ # Random steps between the camera views
178
+ self.random_steps_count = 5
179
+ self.random_step_variance = 2.0
180
+
181
+ # Minimum fraction of the scene which should be valid (well defined depth)
182
+ self.minimum_valid_fraction = 0.7
183
+
184
+ # Distance threshold used to select pairs
185
+ self.distance_threshold = 0.05
186
+ # Minimum IoU of a view point cloud with respect to the reference view to be kept.
187
+ self.minimum_covisibility = minimum_covisibility
188
+
189
+ # Maximum number of retries.
190
+ self.max_attempts_count = 100
191
+
192
+ self.seed = None
193
+ self._lazy_initialization()
194
+
195
+ def _lazy_initialization(self):
196
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
197
+ if self.seed is None:
198
+ # Re-seed numpy generator
199
+ np.random.seed()
200
+ self.seed = np.random.randint(2**32 - 1)
201
+ sim_cfg = habitat_sim.SimulatorConfiguration()
202
+ sim_cfg.scene_id = self.scene
203
+ if (
204
+ self.scene_dataset_config_file is not None
205
+ and self.scene_dataset_config_file != ""
206
+ ):
207
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
208
+ sim_cfg.random_seed = self.seed
209
+ sim_cfg.load_semantic_mesh = False
210
+ sim_cfg.gpu_device_id = self.gpu_id
211
+
212
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
213
+ depth_sensor_spec.uuid = "depth"
214
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
215
+ depth_sensor_spec.resolution = self.resolution
216
+ depth_sensor_spec.hfov = self.hfov
217
+ depth_sensor_spec.position = [0.0, 0.0, 0]
218
+ depth_sensor_spec.orientation
219
+
220
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
221
+ rgb_sensor_spec.uuid = "color"
222
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
223
+ rgb_sensor_spec.resolution = self.resolution
224
+ rgb_sensor_spec.hfov = self.hfov
225
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
226
+ agent_cfg = habitat_sim.agent.AgentConfiguration(
227
+ sensor_specifications=[rgb_sensor_spec, depth_sensor_spec]
228
+ )
229
+
230
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
231
+ self.sim = habitat_sim.Simulator(cfg)
232
+ if self.navmesh is not None and self.navmesh != "":
233
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
234
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
235
+
236
+ if not self.sim.pathfinder.is_loaded:
237
+ # Try to compute a navmesh
238
+ navmesh_settings = habitat_sim.NavMeshSettings()
239
+ navmesh_settings.set_defaults()
240
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
241
+
242
+ # Ensure that the navmesh is not empty
243
+ if not self.sim.pathfinder.is_loaded:
244
+ raise NoNaviguableSpaceError(
245
+ f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})"
246
+ )
247
+
248
+ self.agent = self.sim.initialize_agent(agent_id=0)
249
+
250
+ def close(self):
251
+ self.sim.close()
252
+
253
+ def __del__(self):
254
+ self.sim.close()
255
+
256
+ def __len__(self):
257
+ return self.size
258
+
259
+ def sample_random_viewpoint(self):
260
+ """Sample a random viewpoint using the navmesh"""
261
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
262
+
263
+ # Sample a random viewpoint height
264
+ viewpoint_height = np.random.uniform(*self.height_range)
265
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
266
+ viewpoint_orientation = quaternion.from_rotation_vector(
267
+ np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP
268
+ ) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
269
+ return viewpoint_position, viewpoint_orientation, nav_point
270
+
271
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
272
+ """Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
273
+ other_nav_point = nav_point
274
+
275
+ walk_directions = self.random_step_variance * np.asarray([1, 0, 1])
276
+ for i in range(self.random_steps_count):
277
+ temp = self.sim.pathfinder.snap_point(
278
+ other_nav_point + walk_directions * np.random.normal(size=3)
279
+ )
280
+ # Snapping may return nan when it fails
281
+ if not np.isnan(temp[0]):
282
+ other_nav_point = temp
283
+
284
+ other_viewpoint_height = np.random.uniform(*self.height_range)
285
+ other_viewpoint_position = (
286
+ other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
287
+ )
288
+
289
+ # Set viewing direction towards the central point
290
+ rotation, position = look_at_for_habitat(
291
+ eye=other_viewpoint_position,
292
+ center=observed_point,
293
+ up=habitat_sim.geo.UP,
294
+ return_cam2world=True,
295
+ )
296
+ rotation = rotation * generate_orientation_noise(
297
+ self.pan_range, self.tilt_range, self.roll_range
298
+ )
299
+ return position, rotation, other_nav_point
300
+
301
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
302
+ """Check if a viewpoint is valid and overlaps significantly with a reference one."""
303
+ # Observation
304
+ pixels_count = self.resolution[0] * self.resolution[1]
305
+ valid_fraction = len(other_pointcloud) / pixels_count
306
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
307
+ overlap = compute_pointcloud_overlaps_scikit(
308
+ ref_pointcloud,
309
+ other_pointcloud,
310
+ self.distance_threshold,
311
+ compute_symmetric=True,
312
+ )
313
+ covisibility = min(
314
+ overlap["intersection1"] / pixels_count,
315
+ overlap["intersection2"] / pixels_count,
316
+ )
317
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (
318
+ covisibility >= self.minimum_covisibility
319
+ )
320
+ return is_valid, valid_fraction, covisibility
321
+
322
+ def is_other_viewpoint_overlapping(
323
+ self, ref_pointcloud, observation, position, rotation
324
+ ):
325
+ """Check if a viewpoint is valid and overlaps significantly with a reference one."""
326
+ # Observation
327
+ other_pointcloud = compute_pointcloud(
328
+ observation["depth"], self.hfov, position, rotation
329
+ )
330
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
331
+
332
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
333
+ agent_state = habitat_sim.AgentState()
334
+ agent_state.position = viewpoint_position
335
+ agent_state.rotation = viewpoint_orientation
336
+ self.agent.set_state(agent_state)
337
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
338
+ _append_camera_parameters(
339
+ viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation
340
+ )
341
+ return viewpoint_observations
342
+
343
+ def __getitem__(self, useless_idx):
344
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
345
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
346
+ # Extract point cloud
347
+ ref_pointcloud = compute_pointcloud(
348
+ depthmap=ref_observations["depth"],
349
+ hfov=self.hfov,
350
+ camera_position=ref_position,
351
+ camera_rotation=ref_orientation,
352
+ )
353
+
354
+ pixels_count = self.resolution[0] * self.resolution[1]
355
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
356
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
357
+ if ref_valid_fraction < self.minimum_valid_fraction:
358
+ # This should produce a recursion error at some point when something is very wrong.
359
+ return self[0]
360
+ # Pick a reference observed point in the point cloud
361
+ observed_point = np.mean(ref_pointcloud, axis=0)
362
+
363
+ # Add the first image as reference
364
+ viewpoints_observations = [ref_observations]
365
+ viewpoints_covisibility = [ref_valid_fraction]
366
+ viewpoints_positions = [ref_position]
367
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
368
+ viewpoints_clouds = [ref_pointcloud]
369
+ viewpoints_valid_fractions = [ref_valid_fraction]
370
+
371
+ for _ in range(self.views_count - 1):
372
+ # Generate another viewpoint using a short random walk
373
+ successful_sampling = False
374
+ for sampling_attempt in range(self.max_attempts_count):
375
+ position, rotation, _ = self.sample_other_random_viewpoint(
376
+ observed_point, nav_point
377
+ )
378
+ # Observation
379
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
380
+ other_pointcloud = compute_pointcloud(
381
+ other_viewpoint_observations["depth"], self.hfov, position, rotation
382
+ )
383
+
384
+ is_valid, valid_fraction, covisibility = (
385
+ self.is_other_pointcloud_overlapping(
386
+ ref_pointcloud, other_pointcloud
387
+ )
388
+ )
389
+ if is_valid:
390
+ successful_sampling = True
391
+ break
392
+ if not successful_sampling:
393
+ print("WARNING: Maximum number of attempts reached.")
394
+ # Dirty hack: retry with a new reference viewpoint
395
+ return self[0]
396
+ viewpoints_observations.append(other_viewpoint_observations)
397
+ viewpoints_covisibility.append(covisibility)
398
+ viewpoints_positions.append(position)
399
+ viewpoints_orientations.append(
400
+ quaternion.as_float_array(rotation)
401
+ ) # WXYZ convention for the quaternion encoding.
402
+ viewpoints_clouds.append(other_pointcloud)
403
+ viewpoints_valid_fractions.append(valid_fraction)
404
+
405
+ # Estimate relations between all pairs of images
406
+ pairwise_visibility_ratios = np.ones(
407
+ (len(viewpoints_observations), len(viewpoints_observations))
408
+ )
409
+ for i in range(len(viewpoints_observations)):
410
+ pairwise_visibility_ratios[i, i] = viewpoints_valid_fractions[i]
411
+ for j in range(i + 1, len(viewpoints_observations)):
412
+ overlap = compute_pointcloud_overlaps_scikit(
413
+ viewpoints_clouds[i],
414
+ viewpoints_clouds[j],
415
+ self.distance_threshold,
416
+ compute_symmetric=True,
417
+ )
418
+ pairwise_visibility_ratios[i, j] = (
419
+ overlap["intersection1"] / pixels_count
420
+ )
421
+ pairwise_visibility_ratios[j, i] = (
422
+ overlap["intersection2"] / pixels_count
423
+ )
424
+
425
+ # IoU is relative to the image 0
426
+ data = {
427
+ "observations": viewpoints_observations,
428
+ "positions": np.asarray(viewpoints_positions),
429
+ "orientations": np.asarray(viewpoints_orientations),
430
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
431
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
432
+ "pairwise_visibility_ratios": np.asarray(
433
+ pairwise_visibility_ratios, dtype=float
434
+ ),
435
+ }
436
+
437
+ if self.transform is not None:
438
+ data = self.transform(data)
439
+ return data
440
+
441
+ def generate_random_spiral_trajectory(
442
+ self,
443
+ images_count=100,
444
+ max_radius=0.5,
445
+ half_turns=5,
446
+ use_constant_orientation=False,
447
+ ):
448
+ """
449
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
450
+ Useful to generate nice visualisations.
451
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
452
+ """
453
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
454
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
455
+ ref_pointcloud = compute_pointcloud(
456
+ depthmap=ref_observations["depth"],
457
+ hfov=self.hfov,
458
+ camera_position=ref_position,
459
+ camera_rotation=ref_orientation,
460
+ )
461
+ pixels_count = self.resolution[0] * self.resolution[1]
462
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
463
+ # Dirty hack: ensure that the valid part of the image is significant
464
+ return self.generate_random_spiral_trajectory(
465
+ images_count, max_radius, half_turns, use_constant_orientation
466
+ )
467
+
468
+ # Pick an observed point in the point cloud
469
+ observed_point = np.mean(ref_pointcloud, axis=0)
470
+ ref_R, ref_t = compute_camera_pose_opencv_convention(
471
+ ref_position, ref_orientation
472
+ )
473
+
474
+ images = []
475
+ is_valid = []
476
+ # Spiral trajectory, optionally with a constant orientation
477
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
478
+ r = max_radius * np.abs(
479
+ np.sin(alpha * np.pi)
480
+ ) # Increase then decrease the radius
481
+ theta = alpha * half_turns * np.pi
482
+ x = r * np.cos(theta)
483
+ y = r * np.sin(theta)
484
+ z = 0.0
485
+ position = (
486
+ ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3, 1)).flatten()
487
+ )
488
+ if use_constant_orientation:
489
+ orientation = ref_orientation
490
+ else:
491
+ # trajectory looking at a mean point in front of the ref observation
492
+ orientation, position = look_at_for_habitat(
493
+ eye=position, center=observed_point, up=habitat_sim.geo.UP
494
+ )
495
+ observations = self.render_viewpoint(position, orientation)
496
+ images.append(observations["color"][..., :3])
497
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(
498
+ ref_pointcloud, observations, position, orientation
499
+ )
500
+ is_valid.append(_is_valid)
501
+ return images, np.all(is_valid)
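A minimal usage sketch for the generator above, assuming habitat_sim and the scene assets are installed; the scene and navmesh paths are placeholders, not files shipped with this commit.

# Run from the croco/ root so that the datasets package is importable.
from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator

generator = MultiviewHabitatSimGenerator(
    scene="path/to/scene.glb",           # placeholder path
    navmesh="path/to/scene.navmesh",     # placeholder path
    scene_dataset_config_file="",
    views_count=2,
    size=100,
    minimum_covisibility=0.5,
)
sample = generator[0]                         # the index is ignored; sampling is random
print(sample["positions"].shape)              # (views_count, 3)
print(sample["pairwise_visibility_ratios"])   # (views_count, views_count) matrix
generator.close()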
croco/datasets/habitat_sim/pack_metadata_files.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ """
4
+ Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
5
+ """
6
+ import os
7
+ import glob
8
+ from tqdm import tqdm
9
+ import shutil
10
+ import json
11
+ from datasets.habitat_sim.paths import *
12
+ import argparse
13
+ import collections
14
+
15
+ if __name__ == "__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("input_dir")
18
+ parser.add_argument("output_dir")
19
+ args = parser.parse_args()
20
+
21
+ input_dirname = args.input_dir
22
+ output_dirname = args.output_dir
23
+
24
+ input_metadata_filenames = glob.iglob(
25
+ f"{input_dirname}/**/metadata.json", recursive=True
26
+ )
27
+
28
+ images_count = collections.defaultdict(lambda: 0)
29
+
30
+ os.makedirs(output_dirname)
31
+ for input_filename in tqdm(input_metadata_filenames):
32
+ # Ignore metadata files with no views
33
+ with open(input_filename, "r") as f:
34
+ original_metadata = json.load(f)
35
+ if (
36
+ "multiviews" not in original_metadata
37
+ or len(original_metadata["multiviews"]) == 0
38
+ ):
39
+ print("No views in", input_filename)
40
+ continue
41
+
42
+ relpath = os.path.relpath(input_filename, input_dirname)
43
+ print(relpath)
44
+
45
+ # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
46
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern.
47
+ scenes_dataset_paths = dict(
48
+ sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True)
49
+ )
50
+ metadata = dict()
51
+ for key, value in original_metadata.items():
52
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
53
+ known_path = False
54
+ for dataset, dataset_path in scenes_dataset_paths.items():
55
+ if value.startswith(dataset_path):
56
+ value = os.path.join(
57
+ dataset, os.path.relpath(value, dataset_path)
58
+ )
59
+ known_path = True
60
+ break
61
+ if not known_path:
62
+ raise KeyError("Unknown path:" + value)
63
+ metadata[key] = value
64
+
65
+ # Compile some general statistics while packing data
66
+ scene_split = metadata["scene"].split("/")
67
+ upper_level = (
68
+ "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
69
+ )
70
+ images_count[upper_level] += len(metadata["multiviews"])
71
+
72
+ output_filename = os.path.join(output_dirname, relpath)
73
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
74
+ with open(output_filename, "w") as f:
75
+ json.dump(metadata, f)
76
+
77
+ # Print statistics
78
+ print("Images count:")
79
+ for upper_level, count in images_count.items():
80
+ print(f"- {upper_level}: {count}")
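The loop above rewrites absolute scene paths into dataset-relative keys so the packed metadata stays portable; the script itself is run from the croco/ root, e.g. python -m datasets.habitat_sim.pack_metadata_files INPUT_DIR OUTPUT_DIR. A small sketch of that remapping with a purely illustrative path:

import os

SCENES_DATASET = {"hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/"}
value = "./data/habitat-sim-data/scene_datasets/hm3d/val/sceneXYZ/sceneXYZ.basis.glb"  # placeholder
for dataset, dataset_path in SCENES_DATASET.items():
    if value.startswith(dataset_path):
        value = os.path.join(dataset, os.path.relpath(value, dataset_path))
print(value)  # hm3d/val/sceneXYZ/sceneXYZ.basis.glb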
croco/datasets/habitat_sim/paths.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Paths to Habitat-Sim scenes
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import collections
11
+ from tqdm import tqdm
12
+
13
+
14
+ # Hardcoded paths to the different scene datasets
15
+ SCENES_DATASET = {
16
+ "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
17
+ "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
18
+ "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
19
+ "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
20
+ "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
21
+ "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
22
+ "scannet": "./data/habitat-sim/scene_datasets/scannet/",
23
+ }
24
+
25
+ SceneData = collections.namedtuple(
26
+ "SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"]
27
+ )
28
+
29
+
30
+ def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
31
+ scene_dataset_config_file = os.path.join(
32
+ base_path, "replicaCAD.scene_dataset_config.json"
33
+ )
34
+ scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
35
+ navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + [
36
+ "empty_stage.navmesh"
37
+ ]
38
+ scenes_data = []
39
+ for idx in range(len(scenes)):
40
+ output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
41
+ # Add scene
42
+ data = SceneData(
43
+ scene_dataset_config_file=scene_dataset_config_file,
44
+ scene=scenes[idx] + ".scene_instance.json",
45
+ navmesh=os.path.join(base_path, navmeshes[idx]),
46
+ output_dir=output_dir,
47
+ )
48
+ scenes_data.append(data)
49
+ return scenes_data
50
+
51
+
52
+ def list_replica_cad_baked_lighting_scenes(
53
+ base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]
54
+ ):
55
+ scene_dataset_config_file = os.path.join(
56
+ base_path, "replicaCAD_baked.scene_dataset_config.json"
57
+ )
58
+ scenes = sum(
59
+ [[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], []
60
+ )
61
+ navmeshes = "" # [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
62
+ scenes_data = []
63
+ for idx in range(len(scenes)):
64
+ output_dir = os.path.join(
65
+ base_output_dir, "replica_cad_baked_lighting", scenes[idx]
66
+ )
67
+ data = SceneData(
68
+ scene_dataset_config_file=scene_dataset_config_file,
69
+ scene=scenes[idx],
70
+ navmesh="",
71
+ output_dir=output_dir,
72
+ )
73
+ scenes_data.append(data)
74
+ return scenes_data
75
+
76
+
77
+ def list_replica_scenes(base_output_dir, base_path):
78
+ scenes_data = []
79
+ for scene_id in os.listdir(base_path):
80
+ scene = os.path.join(base_path, scene_id, "mesh.ply")
81
+ navmesh = os.path.join(
82
+ base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh"
83
+ ) # Not sure if I should use it
84
+ scene_dataset_config_file = ""
85
+ output_dir = os.path.join(base_output_dir, scene_id)
86
+ # Add scene only if it does not exist already, or if exist_ok
87
+ data = SceneData(
88
+ scene_dataset_config_file=scene_dataset_config_file,
89
+ scene=scene,
90
+ navmesh=navmesh,
91
+ output_dir=output_dir,
92
+ )
93
+ scenes_data.append(data)
94
+ return scenes_data
95
+
96
+
97
+ def list_scenes(base_output_dir, base_path):
98
+ """
99
+ Generic method iterating through a base_path folder to find scenes.
100
+ """
101
+ scenes_data = []
102
+ for root, dirs, files in os.walk(base_path, followlinks=True):
103
+ folder_scenes_data = []
104
+ for file in files:
105
+ name, ext = os.path.splitext(file)
106
+ if ext == ".glb":
107
+ scene = os.path.join(root, name + ".glb")
108
+ navmesh = os.path.join(root, name + ".navmesh")
109
+ if not os.path.exists(navmesh):
110
+ navmesh = ""
111
+ relpath = os.path.relpath(root, base_path)
112
+ output_dir = os.path.abspath(
113
+ os.path.join(base_output_dir, relpath, name)
114
+ )
115
+ data = SceneData(
116
+ scene_dataset_config_file="",
117
+ scene=scene,
118
+ navmesh=navmesh,
119
+ output_dir=output_dir,
120
+ )
121
+ folder_scenes_data.append(data)
122
+
123
+ # Specific check for HM3D:
124
+ # When two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
125
+ basis_scenes = [
126
+ data.scene[: -len(".basis.glb")]
127
+ for data in folder_scenes_data
128
+ if data.scene.endswith(".basis.glb")
129
+ ]
130
+ if len(basis_scenes) != 0:
131
+ folder_scenes_data = [
132
+ data
133
+ for data in folder_scenes_data
134
+ if not (data.scene[: -len(".glb")] in basis_scenes)
135
+ ]
136
+
137
+ scenes_data.extend(folder_scenes_data)
138
+ return scenes_data
139
+
140
+
141
+ def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
142
+ scenes_data = []
143
+
144
+ # HM3D
145
+ for split in ("minival", "train", "val", "examples"):
146
+ scenes_data += list_scenes(
147
+ base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
148
+ base_path=f"{scenes_dataset_paths['hm3d']}/{split}",
149
+ )
150
+
151
+ # Gibson
152
+ scenes_data += list_scenes(
153
+ base_output_dir=os.path.join(base_output_dir, "gibson"),
154
+ base_path=scenes_dataset_paths["gibson"],
155
+ )
156
+
157
+ # Habitat test scenes (just a few)
158
+ scenes_data += list_scenes(
159
+ base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
160
+ base_path=scenes_dataset_paths["habitat-test-scenes"],
161
+ )
162
+
163
+ # ReplicaCAD (baked lighting)
164
+ scenes_data += list_replica_cad_baked_lighting_scenes(
165
+ base_output_dir=base_output_dir
166
+ )
167
+
168
+ # ScanNet
169
+ scenes_data += list_scenes(
170
+ base_output_dir=os.path.join(base_output_dir, "scannet"),
171
+ base_path=scenes_dataset_paths["scannet"],
172
+ )
173
+
174
+ # Replica
175
+ scenes_data += list_replica_scenes(
176
+ base_output_dir=os.path.join(base_output_dir, "replica"),
177
+ base_path=scenes_dataset_paths["replica"],
178
+ )
179
+ return scenes_data
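A short sketch of how these helpers are typically consumed, assuming the scene datasets are actually installed under the hardcoded roots above; the output directory is a placeholder.

from datasets.habitat_sim.paths import list_scenes_available

scenes_data = list_scenes_available(base_output_dir="./output/multiview_habitat")  # placeholder output root
print(len(scenes_data), "scenes found")
for s in scenes_data[:3]:
    print(s.scene, "->", s.output_dir)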
croco/datasets/pairs_dataset.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ from datasets.transforms import get_pair_transforms
9
+
10
+
11
+ def load_image(impath):
12
+ return Image.open(impath)
13
+
14
+
15
+ def load_pairs_from_cache_file(fname, root=""):
16
+ assert os.path.isfile(
17
+ fname
18
+ ), "cannot parse pairs from {:s}, file does not exist".format(fname)
19
+ with open(fname, "r") as fid:
20
+ lines = fid.read().strip().splitlines()
21
+ pairs = [
22
+ (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1]))
23
+ for l in lines
24
+ ]
25
+ return pairs
26
+
27
+
28
+ def load_pairs_from_list_file(fname, root=""):
29
+ assert os.path.isfile(
30
+ fname
31
+ ), "cannot parse pairs from {:s}, file does not exist".format(fname)
32
+ with open(fname, "r") as fid:
33
+ lines = fid.read().strip().splitlines()
34
+ pairs = [
35
+ (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg"))
36
+ for l in lines
37
+ if not l.startswith("#")
38
+ ]
39
+ return pairs
40
+
41
+
42
+ def write_cache_file(fname, pairs, root=""):
43
+ if len(root) > 0:
44
+ if not root.endswith("/"):
45
+ root += "/"
46
+ assert os.path.isdir(root)
47
+ s = ""
48
+ for im1, im2 in pairs:
49
+ if len(root) > 0:
50
+ assert im1.startswith(root), im1
51
+ assert im2.startswith(root), im2
52
+ s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
53
+ with open(fname, "w") as fid:
54
+ fid.write(s[:-1])
55
+
56
+
57
+ def parse_and_cache_all_pairs(dname, data_dir="./data/"):
58
+ if dname == "habitat_release":
59
+ dirname = os.path.join(data_dir, "habitat_release")
60
+ assert os.path.isdir(dirname), (
61
+ "cannot find folder for habitat_release pairs: " + dirname
62
+ )
63
+ cache_file = os.path.join(dirname, "pairs.txt")
64
+ assert not os.path.isfile(cache_file), (
65
+ "cache file already exists: " + cache_file
66
+ )
67
+
68
+ print("Parsing pairs for dataset: " + dname)
69
+ pairs = []
70
+ for root, dirs, files in os.walk(dirname):
71
+ if "val" in root:
72
+ continue
73
+ dirs.sort()
74
+ pairs += [
75
+ (
76
+ os.path.join(root, f),
77
+ os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
78
+ )
79
+ for f in sorted(files)
80
+ if f.endswith("_1.jpeg")
81
+ ]
82
+ print("Found {:,} pairs".format(len(pairs)))
83
+ print("Writing cache to: " + cache_file)
84
+ write_cache_file(cache_file, pairs, root=dirname)
85
+
86
+ else:
87
+ raise NotImplementedError("Unknown dataset: " + dname)
88
+
89
+
90
+ def dnames_to_image_pairs(dnames, data_dir="./data/"):
91
+ """
92
+ dnames: list of datasets with image pairs, separated by +
93
+ """
94
+ all_pairs = []
95
+ for dname in dnames.split("+"):
96
+ if dname == "habitat_release":
97
+ dirname = os.path.join(data_dir, "habitat_release")
98
+ assert os.path.isdir(dirname), (
99
+ "cannot find folder for habitat_release pairs: " + dirname
100
+ )
101
+ cache_file = os.path.join(dirname, "pairs.txt")
102
+ assert os.path.isfile(cache_file), (
103
+ "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "
104
+ + cache_file
105
+ )
106
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
107
+ elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
108
+ dirname = os.path.join(data_dir, dname + "_crops")
109
+ assert os.path.isdir(
110
+ dirname
111
+ ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
112
+ list_file = os.path.join(dirname, "listing.txt")
113
+ assert os.path.isfile(
114
+ list_file
115
+ ), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
116
+ dname, list_file
117
+ )
118
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
119
+ print(" {:s}: {:,} pairs".format(dname, len(pairs)))
120
+ all_pairs += pairs
121
+ if "+" in dnames:
122
+ print(" Total: {:,} pairs".format(len(all_pairs)))
123
+ return all_pairs
124
+
125
+
126
+ class PairsDataset(Dataset):
127
+
128
+ def __init__(
129
+ self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
130
+ ):
131
+ super().__init__()
132
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
133
+ self.transforms = get_pair_transforms(
134
+ transform_str=trfs, totensor=totensor, normalize=normalize
135
+ )
136
+
137
+ def __len__(self):
138
+ return len(self.image_pairs)
139
+
140
+ def __getitem__(self, index):
141
+ im1path, im2path = self.image_pairs[index]
142
+ im1 = load_image(im1path)
143
+ im2 = load_image(im2path)
144
+ if self.transforms is not None:
145
+ im1, im2 = self.transforms(im1, im2)
146
+ return im1, im2
147
+
148
+
149
+ if __name__ == "__main__":
150
+ import argparse
151
+
152
+ parser = argparse.ArgumentParser(
153
+ prog="Computing and caching list of pairs for a given dataset"
154
+ )
155
+ parser.add_argument(
156
+ "--data_dir", default="./data/", type=str, help="path where data are stored"
157
+ )
158
+ parser.add_argument(
159
+ "--dataset", default="habitat_release", type=str, help="name of the dataset"
160
+ )
161
+ args = parser.parse_args()
162
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
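A minimal sketch of loading pairs for training, assuming the habitat_release pairs cache has already been built with parse_and_cache_all_pairs; batch size and worker count are arbitrary.

from torch.utils.data import DataLoader
from datasets.pairs_dataset import PairsDataset

dataset = PairsDataset("habitat_release", trfs="crop224+acolor", data_dir="./data/")
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
im1, im2 = next(iter(loader))  # two batches of cropped, jittered, normalized views
print(im1.shape, im2.shape)    # torch.Size([8, 3, 224, 224]) each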
croco/datasets/transforms.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ import torchvision.transforms
6
+ import torchvision.transforms.functional as F
7
+
8
+ # "Pair": apply a transform on a pair
9
+ # "Both": apply the exact same transform to both images
10
+
11
+
12
+ class ComposePair(torchvision.transforms.Compose):
13
+ def __call__(self, img1, img2):
14
+ for t in self.transforms:
15
+ img1, img2 = t(img1, img2)
16
+ return img1, img2
17
+
18
+
19
+ class NormalizeBoth(torchvision.transforms.Normalize):
20
+ def forward(self, img1, img2):
21
+ img1 = super().forward(img1)
22
+ img2 = super().forward(img2)
23
+ return img1, img2
24
+
25
+
26
+ class ToTensorBoth(torchvision.transforms.ToTensor):
27
+ def __call__(self, img1, img2):
28
+ img1 = super().__call__(img1)
29
+ img2 = super().__call__(img2)
30
+ return img1, img2
31
+
32
+
33
+ class RandomCropPair(torchvision.transforms.RandomCrop):
34
+ # the crop will be intentionally different for the two images with this class
35
+ def forward(self, img1, img2):
36
+ img1 = super().forward(img1)
37
+ img2 = super().forward(img2)
38
+ return img1, img2
39
+
40
+
41
+ class ColorJitterPair(torchvision.transforms.ColorJitter):
42
+ # can be symmetric (same for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
43
+ def __init__(self, assymetric_prob, **kwargs):
44
+ super().__init__(**kwargs)
45
+ self.assymetric_prob = assymetric_prob
46
+
47
+ def jitter_one(
48
+ self,
49
+ img,
50
+ fn_idx,
51
+ brightness_factor,
52
+ contrast_factor,
53
+ saturation_factor,
54
+ hue_factor,
55
+ ):
56
+ for fn_id in fn_idx:
57
+ if fn_id == 0 and brightness_factor is not None:
58
+ img = F.adjust_brightness(img, brightness_factor)
59
+ elif fn_id == 1 and contrast_factor is not None:
60
+ img = F.adjust_contrast(img, contrast_factor)
61
+ elif fn_id == 2 and saturation_factor is not None:
62
+ img = F.adjust_saturation(img, saturation_factor)
63
+ elif fn_id == 3 and hue_factor is not None:
64
+ img = F.adjust_hue(img, hue_factor)
65
+ return img
66
+
67
+ def forward(self, img1, img2):
68
+
69
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = (
70
+ self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
71
+ )
72
+ img1 = self.jitter_one(
73
+ img1,
74
+ fn_idx,
75
+ brightness_factor,
76
+ contrast_factor,
77
+ saturation_factor,
78
+ hue_factor,
79
+ )
80
+ if torch.rand(1) < self.assymetric_prob:  # asymmetric: re-draw jitter parameters for the second image
81
+ (
82
+ fn_idx,
83
+ brightness_factor,
84
+ contrast_factor,
85
+ saturation_factor,
86
+ hue_factor,
87
+ ) = self.get_params(
88
+ self.brightness, self.contrast, self.saturation, self.hue
89
+ )
90
+ img2 = self.jitter_one(
91
+ img2,
92
+ fn_idx,
93
+ brightness_factor,
94
+ contrast_factor,
95
+ saturation_factor,
96
+ hue_factor,
97
+ )
98
+ return img1, img2
99
+
100
+
101
+ def get_pair_transforms(transform_str, totensor=True, normalize=True):
102
+ # transform_str is e.g. "crop224+acolor"
103
+ trfs = []
104
+ for s in transform_str.split("+"):
105
+ if s.startswith("crop"):
106
+ size = int(s[len("crop") :])
107
+ trfs.append(RandomCropPair(size))
108
+ elif s == "acolor":
109
+ trfs.append(
110
+ ColorJitterPair(
111
+ assymetric_prob=1.0,
112
+ brightness=(0.6, 1.4),
113
+ contrast=(0.6, 1.4),
114
+ saturation=(0.6, 1.4),
115
+ hue=0.0,
116
+ )
117
+ )
118
+ elif s == "": # if transform_str was ""
119
+ pass
120
+ else:
121
+ raise NotImplementedError("Unknown augmentation: " + s)
122
+
123
+ if totensor:
124
+ trfs.append(ToTensorBoth())
125
+ if normalize:
126
+ trfs.append(
127
+ NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
128
+ )
129
+
130
+ if len(trfs) == 0:
131
+ return None
132
+ elif len(trfs) == 1:
133
+ return trfs[0]
134
+ else:
135
+ return ComposePair(trfs)
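A small sketch applying the pair transforms directly to two PIL images; the image paths are placeholders and must point to images at least 224 pixels on each side for the crop to succeed.

from PIL import Image
from datasets.transforms import get_pair_transforms

trfs = get_pair_transforms("crop224+acolor", totensor=True, normalize=True)
img1 = Image.open("view_1.jpg")  # placeholder path
img2 = Image.open("view_2.jpg")  # placeholder path
t1, t2 = trfs(img1, img2)        # independently cropped, asymmetrically jittered, normalized
print(t1.shape, t2.shape)        # torch.Size([3, 224, 224]) each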
croco/interactive_demo.ipynb ADDED
@@ -0,0 +1,271 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Interactive demo of Cross-view Completion."
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
17
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "import torch\n",
27
+ "import numpy as np\n",
28
+ "from models.croco import CroCoNet\n",
29
+ "from ipywidgets import interact, interactive, fixed, interact_manual\n",
30
+ "import ipywidgets as widgets\n",
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import quaternion\n",
33
+ "import models.masking"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "### Load CroCo model"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n",
50
+ "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n",
51
+ "msg = model.load_state_dict(ckpt['model'], strict=True)\n",
52
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
53
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
54
+ "model = model.eval()\n",
55
+ "model = model.to(device=device)\n",
56
+ "print(msg)\n",
57
+ "\n",
58
+ "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n",
59
+ " \"\"\"\n",
60
+ " Perform Cross-View completion using two input images, specified using Numpy arrays.\n",
61
+ " \"\"\"\n",
62
+ " # Replace the mask generator\n",
63
+ " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n",
64
+ "\n",
65
+ " # ImageNet-1k color normalization\n",
66
+ " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n",
67
+ " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n",
68
+ "\n",
69
+ " normalize_input_colors = True\n",
70
+ " is_output_normalized = True\n",
71
+ " with torch.no_grad():\n",
72
+ " # Cast data to torch\n",
73
+ " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
74
+ " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
75
+ "\n",
76
+ " if normalize_input_colors:\n",
77
+ " ref_image = (ref_image - imagenet_mean) / imagenet_std\n",
78
+ " target_image = (target_image - imagenet_mean) / imagenet_std\n",
79
+ "\n",
80
+ " out, mask, _ = model(target_image, ref_image)\n",
81
+ " # # get target\n",
82
+ " if not is_output_normalized:\n",
83
+ " predicted_image = model.unpatchify(out)\n",
84
+ " else:\n",
85
+ " # The output only contains higher order information,\n",
86
+ " # we retrieve mean and standard deviation from the actual target image\n",
87
+ " patchified = model.patchify(target_image)\n",
88
+ " mean = patchified.mean(dim=-1, keepdim=True)\n",
89
+ " var = patchified.var(dim=-1, keepdim=True)\n",
90
+ " pred_renorm = out * (var + 1.e-6)**.5 + mean\n",
91
+ " predicted_image = model.unpatchify(pred_renorm)\n",
92
+ "\n",
93
+ " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n",
94
+ " masked_target_image = (1 - image_masks) * target_image\n",
95
+ " \n",
96
+ " if not reconstruct_unmasked_patches:\n",
97
+ " # Replace unmasked patches by their actual values\n",
98
+ " predicted_image = predicted_image * image_masks + masked_target_image\n",
99
+ "\n",
100
+ " # Unapply color normalization\n",
101
+ " if normalize_input_colors:\n",
102
+ " predicted_image = predicted_image * imagenet_std + imagenet_mean\n",
103
+ " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n",
104
+ " \n",
105
+ " # Cast to Numpy\n",
106
+ " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
107
+ " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
108
+ " return masked_target_image, predicted_image"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "markdown",
113
+ "metadata": {},
114
+ "source": [
115
+ "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "import os\n",
125
+ "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n",
126
+ "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n",
127
+ "import habitat_sim\n",
128
+ "\n",
129
+ "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n",
130
+ "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n",
131
+ "\n",
132
+ "sim_cfg = habitat_sim.SimulatorConfiguration()\n",
133
+ "if use_gpu: sim_cfg.gpu_device_id = 0\n",
134
+ "sim_cfg.scene_id = scene\n",
135
+ "sim_cfg.load_semantic_mesh = False\n",
136
+ "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n",
137
+ "rgb_sensor_spec.uuid = \"color\"\n",
138
+ "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n",
139
+ "rgb_sensor_spec.resolution = (224,224)\n",
140
+ "rgb_sensor_spec.hfov = 56.56\n",
141
+ "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n",
142
+ "rgb_sensor_spec.orientation = [0, 0, 0]\n",
143
+ "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n",
144
+ "\n",
145
+ "\n",
146
+ "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n",
147
+ "sim = habitat_sim.Simulator(cfg)\n",
148
+ "if navmesh is not None:\n",
149
+ " sim.pathfinder.load_nav_mesh(navmesh)\n",
150
+ "agent = sim.initialize_agent(agent_id=0)\n",
151
+ "\n",
152
+ "def sample_random_viewpoint():\n",
153
+ " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n",
154
+ " nav_point = sim.pathfinder.get_random_navigable_point()\n",
155
+ " # Sample a random viewpoint height\n",
156
+ " viewpoint_height = np.random.uniform(1.0, 1.6)\n",
157
+ " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n",
158
+ " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n",
159
+ " return viewpoint_position, viewpoint_orientation\n",
160
+ "\n",
161
+ "def render_viewpoint(position, orientation):\n",
162
+ " agent_state = habitat_sim.AgentState()\n",
163
+ " agent_state.position = position\n",
164
+ " agent_state.rotation = orientation\n",
165
+ " agent.set_state(agent_state)\n",
166
+ " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n",
167
+ " image = viewpoint_observations['color'][:,:,:3]\n",
168
+ " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n",
169
+ " return image"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Sample a random reference view"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "ref_position, ref_orientation = sample_random_viewpoint()\n",
186
+ "ref_image = render_viewpoint(ref_position, ref_orientation)\n",
187
+ "plt.clf()\n",
188
+ "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n",
189
+ "axes[0,0].imshow(ref_image)\n",
190
+ "for ax in axes.flatten():\n",
191
+ " ax.set_xticks([])\n",
192
+ " ax.set_yticks([])"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {},
198
+ "source": [
199
+ "### Interactive cross-view completion using CroCo"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "reconstruct_unmasked_patches = False\n",
209
+ "\n",
210
+ "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n",
211
+ " R = quaternion.as_rotation_matrix(ref_orientation)\n",
212
+ " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n",
213
+ " target_orientation = (ref_orientation\n",
214
+ " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n",
215
+ " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n",
216
+ " \n",
217
+ " ref_image = render_viewpoint(ref_position, ref_orientation)\n",
218
+ " target_image = render_viewpoint(target_position, target_orientation)\n",
219
+ "\n",
220
+ " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n",
221
+ "\n",
222
+ " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n",
223
+ " axes[0].imshow(ref_image)\n",
224
+ " axes[0].set_xlabel(\"Reference\")\n",
225
+ " axes[1].imshow(masked_target_image)\n",
226
+ " axes[1].set_xlabel(\"Masked target\")\n",
227
+ " axes[2].imshow(predicted_image)\n",
228
+ " axes[2].set_xlabel(\"Reconstruction\") \n",
229
+ " axes[3].imshow(target_image)\n",
230
+ " axes[3].set_xlabel(\"Target\")\n",
231
+ " for ax in axes.flatten():\n",
232
+ " ax.set_xticks([])\n",
233
+ " ax.set_yticks([])\n",
234
+ "\n",
235
+ "interact(show_demo,\n",
236
+ " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n",
237
+ " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
238
+ " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
239
+ " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
240
+ " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n",
241
+ " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));"
242
+ ]
243
+ }
244
+ ],
245
+ "metadata": {
246
+ "kernelspec": {
247
+ "display_name": "Python 3 (ipykernel)",
248
+ "language": "python",
249
+ "name": "python3"
250
+ },
251
+ "language_info": {
252
+ "codemirror_mode": {
253
+ "name": "ipython",
254
+ "version": 3
255
+ },
256
+ "file_extension": ".py",
257
+ "mimetype": "text/x-python",
258
+ "name": "python",
259
+ "nbconvert_exporter": "python",
260
+ "pygments_lexer": "ipython3",
261
+ "version": "3.7.13"
262
+ },
263
+ "vscode": {
264
+ "interpreter": {
265
+ "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67"
266
+ }
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 2
271
+ }
croco/models/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (9.73 kB).
 
croco/models/__pycache__/blocks.cpython-311.pyc ADDED
Binary file (19.8 kB).
 
croco/models/__pycache__/blocks.cpython-312.pyc ADDED
Binary file (17.5 kB).
 
croco/models/__pycache__/croco.cpython-310.pyc ADDED
Binary file (7.66 kB).
 
croco/models/__pycache__/croco.cpython-311.pyc ADDED
Binary file (14.3 kB).
 
croco/models/__pycache__/croco.cpython-312.pyc ADDED
Binary file (13.2 kB).
 
croco/models/__pycache__/dpt_block.cpython-310.pyc ADDED
Binary file (10.2 kB).
 
croco/models/__pycache__/dpt_block.cpython-311.pyc ADDED
Binary file (18.7 kB).
 
croco/models/__pycache__/dpt_block.cpython-312.pyc ADDED
Binary file (16.8 kB).
 
croco/models/__pycache__/masking.cpython-310.pyc ADDED
Binary file (892 Bytes).
 
croco/models/__pycache__/masking.cpython-311.pyc ADDED
Binary file (1.42 kB).
 
croco/models/__pycache__/masking.cpython-312.pyc ADDED
Binary file (1.28 kB).
 
croco/models/__pycache__/pos_embed.cpython-310.pyc ADDED
Binary file (4.89 kB).
 
croco/models/__pycache__/pos_embed.cpython-311.pyc ADDED
Binary file (8.89 kB).
 
croco/models/__pycache__/pos_embed.cpython-312.pyc ADDED
Binary file (8.33 kB).
 
croco/models/blocks.py ADDED
@@ -0,0 +1,385 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Main encoder/decoder blocks
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # timm
10
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from itertools import repeat
21
+ import collections.abc
22
+ from torch.nn.functional import scaled_dot_product_attention
23
+
24
+
25
+ def _ntuple(n):
26
+ def parse(x):
27
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
28
+ return x
29
+ return tuple(repeat(x, n))
30
+
31
+ return parse
32
+
33
+
34
+ to_2tuple = _ntuple(2)
35
+
36
+
37
+ def drop_path(
38
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
39
+ ):
40
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
41
+ if drop_prob == 0.0 or not training:
42
+ return x
43
+ keep_prob = 1 - drop_prob
44
+ shape = (x.shape[0],) + (1,) * (
45
+ x.ndim - 1
46
+ ) # work with diff dim tensors, not just 2D ConvNets
47
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
48
+ if keep_prob > 0.0 and scale_by_keep:
49
+ random_tensor.div_(keep_prob)
50
+ return x * random_tensor
51
+
52
+
53
+ class DropPath(nn.Module):
54
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
55
+
56
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
57
+ super(DropPath, self).__init__()
58
+ self.drop_prob = drop_prob
59
+ self.scale_by_keep = scale_by_keep
60
+
61
+ def forward(self, x):
62
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
63
+
64
+ def extra_repr(self):
65
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
66
+
67
+
68
+ class Mlp(nn.Module):
69
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
70
+
71
+ def __init__(
72
+ self,
73
+ in_features,
74
+ hidden_features=None,
75
+ out_features=None,
76
+ act_layer=nn.GELU,
77
+ bias=True,
78
+ drop=0.0,
79
+ ):
80
+ super().__init__()
81
+ out_features = out_features or in_features
82
+ hidden_features = hidden_features or in_features
83
+ bias = to_2tuple(bias)
84
+ drop_probs = to_2tuple(drop)
85
+
86
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
87
+ self.act = act_layer()
88
+ self.drop1 = nn.Dropout(drop_probs[0])
89
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
90
+ self.drop2 = nn.Dropout(drop_probs[1])
91
+
92
+ def forward(self, x):
93
+ return self.drop2(self.fc2(self.drop1(self.act(self.fc1(x)))))
94
+
95
+
96
+ class Attention(nn.Module):
97
+
98
+ def __init__(
99
+ self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
100
+ ):
101
+ super().__init__()
102
+ self.num_heads = num_heads
103
+ head_dim = dim // num_heads
104
+ self.scale = head_dim**-0.5
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+ self.rope = rope.float() if rope is not None else None
110
+
111
+ def forward(self, x, xpos):
112
+ B, N, C = x.shape
113
+
114
+ qkv = (
115
+ self.qkv(x)
116
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
117
+ .transpose(1, 3)
118
+ )
119
+ q, k, v = [qkv[:, :, i] for i in range(3)]
120
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
121
+
122
+ q_type = q.dtype
123
+ k_type = k.dtype
124
+ if self.rope is not None:
125
+ q = q.to(torch.float16)
126
+ k = k.to(torch.float16)
127
+ with torch.autocast(device_type="cuda", enabled=False):
128
+ q = self.rope(q, xpos)
129
+ k = self.rope(k, xpos)
130
+ q = q.to(q_type)
131
+ k = k.to(k_type)
132
+
133
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
134
+ # attn = attn.softmax(dim=-1)
135
+ # attn = self.attn_drop(attn)
136
+
137
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
138
+ # x = memory_efficient_attention(query=q.permute(0, 2, 1, 3), key=k.permute(0, 2, 1, 3), value=v.permute(0, 2, 1, 3), p=self.attn_drop.p, scale=self.scale).reshape(B, N, C)
139
+ x = (
140
+ scaled_dot_product_attention(
141
+ query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
142
+ )
143
+ .transpose(1, 2)
144
+ .reshape(B, N, C)
145
+ )
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(
154
+ self,
155
+ dim,
156
+ num_heads,
157
+ mlp_ratio=4.0,
158
+ qkv_bias=False,
159
+ drop=0.0,
160
+ attn_drop=0.0,
161
+ drop_path=0.0,
162
+ act_layer=nn.GELU,
163
+ norm_layer=nn.LayerNorm,
164
+ rope=None,
165
+ ):
166
+ super().__init__()
167
+ self.norm1 = norm_layer(dim)
168
+ self.attn = Attention(
169
+ dim,
170
+ rope=rope,
171
+ num_heads=num_heads,
172
+ qkv_bias=qkv_bias,
173
+ attn_drop=attn_drop,
174
+ proj_drop=drop,
175
+ )
176
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
177
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
178
+ self.norm2 = norm_layer(dim)
179
+ mlp_hidden_dim = int(dim * mlp_ratio)
180
+ self.mlp = Mlp(
181
+ in_features=dim,
182
+ hidden_features=mlp_hidden_dim,
183
+ act_layer=act_layer,
184
+ drop=drop,
185
+ )
186
+
187
+ def forward(self, x, xpos):
188
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
189
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
190
+ return x
191
+
192
+
193
+ class CrossAttention(nn.Module):
194
+
195
+ def __init__(
196
+ self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
197
+ ):
198
+ super().__init__()
199
+ self.num_heads = num_heads
200
+ head_dim = dim // num_heads
201
+ self.scale = head_dim**-0.5
202
+
203
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
204
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
205
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
206
+ self.attn_drop = nn.Dropout(attn_drop)
207
+ self.proj = nn.Linear(dim, dim)
208
+ self.proj_drop = nn.Dropout(proj_drop)
209
+
210
+ self.rope = rope.float() if rope is not None else None
211
+
212
+ def forward(self, query, key, value, qpos, kpos):
213
+ B, Nq, C = query.shape
214
+ Nk = key.shape[1]
215
+ Nv = value.shape[1]
216
+
217
+ q = (
218
+ self.projq(query)
219
+ .reshape(B, Nq, self.num_heads, C // self.num_heads)
220
+ .permute(0, 2, 1, 3)
221
+ )
222
+ k = (
223
+ self.projk(key)
224
+ .reshape(B, Nk, self.num_heads, C // self.num_heads)
225
+ .permute(0, 2, 1, 3)
226
+ )
227
+ v = (
228
+ self.projv(value)
229
+ .reshape(B, Nv, self.num_heads, C // self.num_heads)
230
+ .permute(0, 2, 1, 3)
231
+ )
232
+
233
+ q_type = q.dtype
234
+ k_type = k.dtype
235
+ if self.rope is not None:
236
+ if qpos is not None:
237
+ q = q.to(torch.float16)
238
+ with torch.autocast(device_type="cuda", enabled=False):
239
+ q = self.rope(q, qpos)
240
+ q = q.to(q_type)
241
+
242
+ if kpos is not None:
243
+ k = k.to(torch.float16)
244
+ with torch.autocast(device_type="cuda", enabled=False):
245
+ k = self.rope(k, kpos)
246
+ k = k.to(k_type)
247
+
248
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
249
+ # attn = attn.softmax(dim=-1)
250
+ # attn = self.attn_drop(attn)
251
+
252
+ # x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
253
+
254
+ # x = memory_efficient_attention(query=q.permute(0, 2, 1, 3), key=k.permute(0, 2, 1, 3), value=v.permute(0, 2, 1, 3), p=self.attn_drop.p, scale=self.scale).reshape(B, Nq, C)
255
+ x = (
256
+ scaled_dot_product_attention(
257
+ query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
258
+ )
259
+ .transpose(1, 2)
260
+ .reshape(B, Nq, C)
261
+ )
262
+
263
+ x = self.proj(x)
264
+ x = self.proj_drop(x)
265
+ return x
266
+
267
+
268
+ class DecoderBlock(nn.Module):
269
+
270
+ def __init__(
271
+ self,
272
+ dim,
273
+ num_heads,
274
+ mlp_ratio=4.0,
275
+ qkv_bias=False,
276
+ drop=0.0,
277
+ attn_drop=0.0,
278
+ drop_path=0.0,
279
+ act_layer=nn.GELU,
280
+ norm_layer=nn.LayerNorm,
281
+ norm_mem=True,
282
+ rope=None,
283
+ ):
284
+ super().__init__()
285
+ self.norm1 = norm_layer(dim)
286
+ self.attn = Attention(
287
+ dim,
288
+ rope=rope,
289
+ num_heads=num_heads,
290
+ qkv_bias=qkv_bias,
291
+ attn_drop=attn_drop,
292
+ proj_drop=drop,
293
+ )
294
+ self.cross_attn = CrossAttention(
295
+ dim,
296
+ rope=rope,
297
+ num_heads=num_heads,
298
+ qkv_bias=qkv_bias,
299
+ attn_drop=attn_drop,
300
+ proj_drop=drop,
301
+ )
302
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
303
+ self.norm2 = norm_layer(dim)
304
+ self.norm3 = norm_layer(dim)
305
+ mlp_hidden_dim = int(dim * mlp_ratio)
306
+ self.mlp = Mlp(
307
+ in_features=dim,
308
+ hidden_features=mlp_hidden_dim,
309
+ act_layer=act_layer,
310
+ drop=drop,
311
+ )
312
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
313
+
314
+ def forward(self, x, y, xpos, ypos):
315
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
316
+ y_ = self.norm_y(y)
317
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
318
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
319
+ return x, y
320
+
321
+
322
+ # patch embedding
323
+ class PositionGetter(object):
324
+ """return positions of patches"""
325
+
326
+ def __init__(self):
327
+ self.cache_positions = {}
328
+
329
+ def __call__(self, b, h, w, device):
330
+ if (h, w) not in self.cache_positions:
331
+ x = torch.arange(w, device=device)
332
+ y = torch.arange(h, device=device)
333
+ self.cache_positions[h, w] = torch.cartesian_prod(y, x)  # (h*w, 2)
334
+ pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
335
+ return pos
336
+
337
+
338
+ class PatchEmbed(nn.Module):
339
+ """just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
340
+
341
+ def __init__(
342
+ self,
343
+ img_size=224,
344
+ patch_size=16,
345
+ in_chans=3,
346
+ embed_dim=768,
347
+ norm_layer=None,
348
+ flatten=True,
349
+ ):
350
+ super().__init__()
351
+ img_size = to_2tuple(img_size)
352
+ patch_size = to_2tuple(patch_size)
353
+ self.img_size = img_size
354
+ self.patch_size = patch_size
355
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
356
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
357
+ self.flatten = flatten
358
+
359
+ self.proj = nn.Conv2d(
360
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
361
+ )
362
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
363
+
364
+ self.position_getter = PositionGetter()
365
+
366
+ def forward(self, x):
367
+ B, C, H, W = x.shape
368
+ torch._assert(
369
+ H == self.img_size[0],
370
+ f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
371
+ )
372
+ torch._assert(
373
+ W == self.img_size[1],
374
+ f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
375
+ )
376
+ x = self.proj(x)
377
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
378
+ if self.flatten:
379
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
380
+ x = self.norm(x)
381
+ return x, pos
382
+
383
+ def _init_weights(self):
384
+ w = self.proj.weight.data
385
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
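A shape-only sketch of the building blocks above on dummy data (no pretrained weights, RoPE disabled); like the blocks themselves, it requires a PyTorch version whose scaled_dot_product_attention accepts the scale argument.

import torch
from models.blocks import Block, DecoderBlock, PatchEmbed

x = torch.randn(2, 3, 224, 224)
patch_embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
tokens, pos = patch_embed(x)                     # tokens: (2, 196, 768), pos: (2, 196, 2)

enc_block = Block(dim=768, num_heads=12, rope=None)
tokens = enc_block(tokens, pos)                  # self-attention block, same shape out

dec_block = DecoderBlock(dim=768, num_heads=12, rope=None)
out, mem = dec_block(tokens, tokens, pos, pos)   # cross-attends to a 'memory' sequence
print(out.shape)                                 # torch.Size([2, 196, 768])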
croco/models/criterion.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Criterion to train CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
+ class MaskedMSE(torch.nn.Module):
15
+
16
+ def __init__(self, norm_pix_loss=False, masked=True):
17
+ """
18
+ norm_pix_loss: normalize each patch by their pixel mean and variance
19
+ masked: compute loss over the masked patches only
20
+ """
21
+ super().__init__()
22
+ self.norm_pix_loss = norm_pix_loss
23
+ self.masked = masked
24
+
25
+ def forward(self, pred, mask, target):
26
+
27
+ if self.norm_pix_loss:
28
+ mean = target.mean(dim=-1, keepdim=True)
29
+ var = target.var(dim=-1, keepdim=True)
30
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
31
+
32
+ loss = (pred - target) ** 2
33
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
34
+ if self.masked:
35
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
36
+ else:
37
+ loss = loss.mean() # mean loss
38
+ return loss
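A minimal sketch of the loss on random tensors, with shapes matching a 14x14 grid of 16x16x3 patches; the data is purely illustrative.

import torch
from models.criterion import MaskedMSE

criterion = MaskedMSE(norm_pix_loss=True, masked=True)
pred   = torch.randn(2, 196, 768)              # (batch, patches, values per patch)
target = torch.randn(2, 196, 768)
mask   = (torch.rand(2, 196) > 0.1).float()    # 1 = masked patch, contributes to the loss
loss = criterion(pred, mask, target)
print(loss.item())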
croco/models/croco.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # CroCo model during pretraining
7
+ # --------------------------------------------------------
8
+
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
+ from functools import partial
15
+
16
+ from models.blocks import Block, DecoderBlock, PatchEmbed
17
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
18
+ from models.masking import RandomMask
19
+
20
+ from transformers import PretrainedConfig
21
+ from transformers import PreTrainedModel
22
+
23
+
24
+ class CrocoConfig(PretrainedConfig):
25
+ model_type = "croco"
26
+
27
+ def __init__(
28
+ self,
29
+ img_size=224, # input image size
30
+ patch_size=16, # patch_size
31
+ mask_ratio=0.9, # ratios of masked tokens
32
+ enc_embed_dim=768, # encoder feature dimension
33
+ enc_depth=12, # encoder depth
34
+ enc_num_heads=12, # encoder number of heads in the transformer block
35
+ dec_embed_dim=512, # decoder feature dimension
36
+ dec_depth=8, # decoder depth
37
+ dec_num_heads=16, # decoder number of heads in the transformer block
38
+ mlp_ratio=4,
39
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
40
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
41
+ pos_embed="cosine", # positional embedding (either cosine or RoPE100)
42
+ ):
43
+ super().__init__()
44
+ self.img_size = img_size
45
+ self.patch_size = patch_size
46
+ self.mask_ratio = mask_ratio
47
+ self.enc_embed_dim = enc_embed_dim
48
+ self.enc_depth = enc_depth
49
+ self.enc_num_heads = enc_num_heads
50
+ self.dec_embed_dim = dec_embed_dim
51
+ self.dec_depth = dec_depth
52
+ self.dec_num_heads = dec_num_heads
53
+ self.mlp_ratio = mlp_ratio
54
+ self.norm_layer = norm_layer
55
+ self.norm_im2_in_dec = norm_im2_in_dec
56
+ self.pos_embed = pos_embed
57
+
58
+
59
+ class CroCoNet(PreTrainedModel):
60
+
61
+ config_class = CrocoConfig
62
+ base_model_prefix = "croco"
63
+
64
+ def __init__(self, config: CrocoConfig):
65
+
66
+ super().__init__(config)
67
+
68
+ # patch embeddings (with initialization done as in MAE)
69
+ self._set_patch_embed(config.img_size, config.patch_size, config.enc_embed_dim)
70
+
71
+ # mask generations
72
+ self._set_mask_generator(self.patch_embed.num_patches, config.mask_ratio)
73
+
74
+ self.pos_embed = config.pos_embed
75
+ if config.pos_embed == "cosine":
76
+ # positional embedding of the encoder
77
+ enc_pos_embed = get_2d_sincos_pos_embed(
78
+ config.enc_embed_dim,
79
+ int(self.patch_embed.num_patches**0.5),
80
+ n_cls_token=0,
81
+ )
82
+ self.register_buffer(
83
+ "enc_pos_embed", torch.from_numpy(enc_pos_embed).float()
84
+ )
85
+ # positional embedding of the decoder
86
+ dec_pos_embed = get_2d_sincos_pos_embed(
87
+ config.dec_embed_dim,
88
+ int(self.patch_embed.num_patches**0.5),
89
+ n_cls_token=0,
90
+ )
91
+ self.register_buffer(
92
+ "dec_pos_embed", torch.from_numpy(dec_pos_embed).float()
93
+ )
94
+ # pos embedding in each block
95
+ self.rope = None # nothing for cosine
96
+ elif config.pos_embed.startswith("RoPE"): # eg RoPE100
97
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
98
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
99
+ if RoPE2D is None:
100
+ raise ImportError(
101
+ "Cannot find cuRoPE2D, please install it following the README instructions"
102
+ )
103
+ freq = float(config.pos_embed[len("RoPE") :])
104
+ self.rope = RoPE2D(freq=freq)
105
+ else:
106
+ raise NotImplementedError("Unknown pos_embed " + config.pos_embed)
107
+
108
+ # transformer for the encoder
109
+ self.enc_depth = config.enc_depth
110
+ self.enc_embed_dim = config.enc_embed_dim
111
+ self.enc_blocks = nn.ModuleList(
112
+ [
113
+ Block(
114
+ config.enc_embed_dim,
115
+ config.enc_num_heads,
116
+ config.mlp_ratio,
117
+ qkv_bias=True,
118
+ norm_layer=config.norm_layer,
119
+ rope=self.rope,
120
+ )
121
+ for i in range(config.enc_depth)
122
+ ]
123
+ )
124
+ self.enc_norm = config.norm_layer(config.enc_embed_dim)
125
+
126
+ # masked tokens
127
+ # self._set_mask_token(config.dec_embed_dim)
128
+ self.mask_token = None
129
+
130
+ # decoder
131
+ self._set_decoder(
132
+ config.enc_embed_dim,
133
+ config.dec_embed_dim,
134
+ config.dec_num_heads,
135
+ config.dec_depth,
136
+ config.mlp_ratio,
137
+ config.norm_layer,
138
+ config.norm_im2_in_dec,
139
+ )
140
+
141
+ # prediction head
142
+ self._set_prediction_head(config.dec_embed_dim, config.patch_size)
143
+
144
+ # initialize weights
145
+ self.initialize_weights()
146
+
147
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
148
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
149
+
150
+ def _set_mask_generator(self, num_patches, mask_ratio):
151
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
152
+
153
+ def _set_mask_token(self, dec_embed_dim):
154
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
155
+
156
+ def _set_decoder(
157
+ self,
158
+ enc_embed_dim,
159
+ dec_embed_dim,
160
+ dec_num_heads,
161
+ dec_depth,
162
+ mlp_ratio,
163
+ norm_layer,
164
+ norm_im2_in_dec,
165
+ ):
166
+ self.dec_depth = dec_depth
167
+ self.dec_embed_dim = dec_embed_dim
168
+ # transfer from encoder to decoder
169
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
170
+ # transformer for the decoder
171
+ self.dec_blocks = nn.ModuleList(
172
+ [
173
+ DecoderBlock(
174
+ dec_embed_dim,
175
+ dec_num_heads,
176
+ mlp_ratio=mlp_ratio,
177
+ qkv_bias=True,
178
+ norm_layer=norm_layer,
179
+ norm_mem=norm_im2_in_dec,
180
+ rope=self.rope,
181
+ )
182
+ for i in range(dec_depth)
183
+ ]
184
+ )
185
+ # final norm layer
186
+ self.dec_norm = norm_layer(dec_embed_dim)
187
+
188
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
189
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
190
+
191
+ def initialize_weights(self):
192
+ # patch embed
193
+ self.patch_embed._init_weights()
194
+ # mask tokens
195
+ if self.mask_token is not None:
196
+ torch.nn.init.normal_(self.mask_token, std=0.02)
197
+ # linears and layer norms
198
+ self.apply(self._init_weights)
199
+
200
+ def _init_weights(self, m):
201
+ if isinstance(m, nn.Linear):
202
+ # we use xavier_uniform following official JAX ViT:
203
+ torch.nn.init.xavier_uniform_(m.weight)
204
+ if isinstance(m, nn.Linear) and m.bias is not None:
205
+ nn.init.constant_(m.bias, 0)
206
+ elif isinstance(m, nn.LayerNorm):
207
+ nn.init.constant_(m.bias, 0)
208
+ nn.init.constant_(m.weight, 1.0)
209
+
210
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
211
+ """
212
+ image has shape B x 3 x img_size x img_size
213
+ do_mask: whether to perform masking or not
214
+ return_all_blocks: if True, return the features at the end of every block
215
+ instead of just the features from the last block (eg for some prediction heads)
216
+ """
217
+ # embed the image into patches (x has size B x Npatches x C)
218
+ # and get the position of each returned patch (pos has size B x Npatches x 2)
219
+ x, pos = self.patch_embed(image)
220
+ # add positional embedding without cls token
221
+ if self.enc_pos_embed is not None:
222
+ x = x + self.enc_pos_embed[None, ...]
223
+ # apply masking
224
+ B, N, C = x.size()
225
+ if do_mask:
226
+ masks = self.mask_generator(x)
227
+ x = x[~masks].view(B, -1, C)
228
+ posvis = pos[~masks].view(B, -1, 2)
229
+ else:
230
+ B, N, C = x.size()
231
+ masks = torch.zeros((B, N), dtype=bool)
232
+ posvis = pos
233
+ # now apply the transformer encoder and normalization
234
+ if return_all_blocks:
235
+ out = []
236
+ for blk in self.enc_blocks:
237
+ x = blk(x, posvis)
238
+ out.append(x)
239
+ out[-1] = self.enc_norm(out[-1])
240
+ return out, pos, masks
241
+ else:
242
+ for blk in self.enc_blocks:
243
+ x = blk(x, posvis)
244
+ x = self.enc_norm(x)
245
+ return x, pos, masks
246
+
247
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
248
+ """
249
+ return_all_blocks: if True, return the features at the end of every block
250
+ instead of just the features from the last block (eg for some prediction heads)
251
+
252
+ masks1 can be None => assume image1 fully visible
253
+ """
254
+ # encoder to decoder layer
255
+ visf1 = self.decoder_embed(feat1)
256
+ f2 = self.decoder_embed(feat2)
257
+ # append masked tokens to the sequence
258
+ B, Nenc, C = visf1.size()
259
+ if masks1 is None: # downstreams
260
+ f1_ = visf1
261
+ else: # pretraining
262
+ Ntotal = masks1.size(1)
263
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
264
+ f1_[~masks1] = visf1.view(B * Nenc, C)
265
+ # add positional embedding
266
+ if self.dec_pos_embed is not None:
267
+ f1_ = f1_ + self.dec_pos_embed
268
+ f2 = f2 + self.dec_pos_embed
269
+ # apply Transformer blocks
270
+ out = f1_
271
+ out2 = f2
272
+ if return_all_blocks:
273
+ _out, out = out, []
274
+ for blk in self.dec_blocks:
275
+ _out, out2 = blk(_out, out2, pos1, pos2)
276
+ out.append(_out)
277
+ out[-1] = self.dec_norm(out[-1])
278
+ else:
279
+ for blk in self.dec_blocks:
280
+ out, out2 = blk(out, out2, pos1, pos2)
281
+ out = self.dec_norm(out)
282
+ return out
283
+
284
+ def patchify(self, imgs):
285
+ """
286
+ imgs: (B, 3, H, W)
287
+ x: (B, L, patch_size**2 *3)
288
+ """
289
+ p = self.patch_embed.patch_size[0]
290
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
291
+
292
+ h = w = imgs.shape[2] // p
293
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
294
+ x = torch.einsum("nchpwq->nhwpqc", x)
295
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
296
+
297
+ return x
298
+
299
+ def unpatchify(self, x, channels=3):
300
+ """
301
+ x: (N, L, patch_size**2 *channels)
302
+ imgs: (N, 3, H, W)
303
+ """
304
+ patch_size = self.patch_embed.patch_size[0]
305
+ h = w = int(x.shape[1] ** 0.5)
306
+ assert h * w == x.shape[1]
307
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
308
+ x = torch.einsum("nhwpqc->nchpwq", x)
309
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
310
+ return imgs
311
+
312
+ # def forward(self, img1, img2):
313
+ # """
314
+ # img1: tensor of size B x 3 x img_size x img_size
315
+ # img2: tensor of size B x 3 x img_size x img_size
316
+
317
+ # out will be B x N x (3*patch_size*patch_size)
318
+ # masks are also returned as B x N just in case
319
+ # """
320
+ # # encoder of the masked first image
321
+ # feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
322
+ # # encoder of the second image
323
+ # feat2, pos2, _ = self._encode_image(img2, do_mask=False)
324
+ # # decoder
325
+ # decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
326
+ # # prediction head
327
+ # out = self.prediction_head(decfeat)
328
+ # # get target
329
+ # target = self.patchify(img1)
330
+ # return out, mask1, target
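
The forward pass is left commented out above; the sketch below chains the same methods for one pretraining step. It is illustrative only: it assumes the cosine positional embedding, and it calls _set_mask_token explicitly because the constructor above leaves mask_token=None.

    import torch
    from models.criterion import MaskedMSE
    from models.croco import CrocoConfig, CroCoNet

    config = CrocoConfig(img_size=224, patch_size=16, mask_ratio=0.9, pos_embed="cosine")
    model = CroCoNet(config)
    model._set_mask_token(config.dec_embed_dim)   # re-create the mask token disabled in __init__
    criterion = MaskedMSE(norm_pix_loss=True)

    img1 = torch.randn(1, 3, 224, 224)            # first view, will be masked
    img2 = torch.randn(1, 3, 224, 224)            # second view, fully visible

    feat1, pos1, mask1 = model._encode_image(img1, do_mask=True)
    feat2, pos2, _ = model._encode_image(img2, do_mask=False)
    decfeat = model._decoder(feat1, pos1, mask1, feat2, pos2)
    out = model.prediction_head(decfeat)          # B x N x (3 * patch_size**2)
    target = model.patchify(img1)
    loss = criterion(out, mask1, target)
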
croco/models/croco_downstream.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # CroCo model for downstream tasks
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from .croco import CroCoNet
11
+
12
+
13
+ def croco_args_from_ckpt(ckpt):
14
+ if "croco_kwargs" in ckpt: # CroCo v2 released models
15
+ return ckpt["croco_kwargs"]
16
+ elif "args" in ckpt and hasattr(
17
+ ckpt["args"], "model"
18
+ ): # pretrained using the official code release
19
+ s = ckpt[
20
+ "args"
21
+ ].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
22
+ assert s.startswith("CroCoNet(")
23
+ return eval(
24
+ "dict" + s[len("CroCoNet") :]
25
+ ) # transform it into the string of a dictionary and evaluate it
26
+ else: # CroCo v1 released models
27
+ return dict()
28
+
29
+
30
+ class CroCoDownstreamMonocularEncoder(CroCoNet):
31
+
32
+ def __init__(self, head, **kwargs):
33
+ """Build network for monocular downstream task, only using the encoder.
34
+ It takes an extra argument head, that is called with the features
35
+ and a dictionary img_info containing 'width' and 'height' keys
36
+ The head is setup with the croconet arguments in this init function
37
+ NOTE: It works by calling super().__init__() with the _set_* setters redefined below
38
+
39
+ """
40
+ super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
41
+ head.setup(self)
42
+ self.head = head
43
+
44
+ def _set_mask_generator(self, *args, **kwargs):
45
+ """No mask generator"""
46
+ return
47
+
48
+ def _set_mask_token(self, *args, **kwargs):
49
+ """No mask token"""
50
+ self.mask_token = None
51
+ return
52
+
53
+ def _set_decoder(self, *args, **kwargs):
54
+ """No decoder"""
55
+ return
56
+
57
+ def _set_prediction_head(self, *args, **kwargs):
58
+ """No 'prediction head' for downstream tasks."""
59
+ return
60
+
61
+ def forward(self, img):
62
+ """
63
+ img is of size batch_size x 3 x h x w
64
+ """
65
+ B, C, H, W = img.size()
66
+ img_info = {"height": H, "width": W}
67
+ need_all_layers = (
68
+ hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
69
+ )
70
+ out, _, _ = self._encode_image(
71
+ img, do_mask=False, return_all_blocks=need_all_layers
72
+ )
73
+ return self.head(out, img_info)
74
+
75
+
76
+ class CroCoDownstreamBinocular(CroCoNet):
77
+
78
+ def __init__(self, head, **kwargs):
79
+ """Build network for binocular downstream task
80
+ It takes an extra argument head, that is called with the features
81
+ and a dictionary img_info containing 'width' and 'height' keys
82
+ The head is setup with the croconet arguments in this init function
83
+ """
84
+ super(CroCoDownstreamBinocular, self).__init__(**kwargs)
85
+ head.setup(self)
86
+ self.head = head
87
+
88
+ def _set_mask_generator(self, *args, **kwargs):
89
+ """No mask generator"""
90
+ return
91
+
92
+ def _set_mask_token(self, *args, **kwargs):
93
+ """No mask token"""
94
+ self.mask_token = None
95
+ return
96
+
97
+ def _set_prediction_head(self, *args, **kwargs):
98
+ """No prediction head for downstream tasks, define your own head"""
99
+ return
100
+
101
+ def encode_image_pairs(self, img1, img2, return_all_blocks=False):
102
+ """run encoder for a pair of images
103
+ it is actually ~5% faster to concatenate the images along the batch dimension
104
+ than to encode them separately
105
+ """
106
+ ## the two commented lines below are the naive version with separate encoding
107
+ # out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks)
108
+ # out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False)
109
+ ## and now the faster version
110
+ out, pos, _ = self._encode_image(
111
+ torch.cat((img1, img2), dim=0),
112
+ do_mask=False,
113
+ return_all_blocks=return_all_blocks,
114
+ )
115
+ if return_all_blocks:
116
+ out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
117
+ out2 = out2[-1]
118
+ else:
119
+ out, out2 = out.chunk(2, dim=0)
120
+ pos, pos2 = pos.chunk(2, dim=0)
121
+ return out, out2, pos, pos2
122
+
123
+ def forward(self, img1, img2):
124
+ B, C, H, W = img1.size()
125
+ img_info = {"height": H, "width": W}
126
+ return_all_blocks = (
127
+ hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
128
+ )
129
+ out, out2, pos, pos2 = self.encode_image_pairs(
130
+ img1, img2, return_all_blocks=return_all_blocks
131
+ )
132
+ if return_all_blocks:
133
+ decout = self._decoder(
134
+ out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks
135
+ )
136
+ decout = out + decout
137
+ else:
138
+ decout = self._decoder(
139
+ out, pos, None, out2, pos2, return_all_blocks=return_all_blocks
140
+ )
141
+ return self.head(decout, img_info)
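
A quick sketch of croco_args_from_ckpt above, using made-up checkpoint dictionaries to show the two supported formats (the CroCo v1 fallback simply returns an empty dict):

    from models.croco_downstream import croco_args_from_ckpt

    # CroCo v2 checkpoints store the constructor kwargs directly
    ckpt_v2 = {"croco_kwargs": {"enc_embed_dim": 1024, "enc_num_heads": 16, "enc_depth": 24}}
    print(croco_args_from_ckpt(ckpt_v2))
    # -> {'enc_embed_dim': 1024, 'enc_num_heads': 16, 'enc_depth': 24}

    # checkpoints from the official training code store a string in args.model
    class FakeArgs:  # hypothetical stand-in for the argparse namespace found in real checkpoints
        model = "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"

    print(croco_args_from_ckpt({"args": FakeArgs()}))
    # -> the same dict, recovered by parsing the string
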
croco/models/curope/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ from .curope2d import cuRoPE2D
croco/models/curope/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (193 Bytes)
croco/models/curope/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (219 Bytes)
croco/models/curope/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (203 Bytes)
croco/models/curope/__pycache__/curope2d.cpython-310.pyc ADDED
Binary file (1.62 kB)
croco/models/curope/__pycache__/curope2d.cpython-311.pyc ADDED
Binary file (2.63 kB)
croco/models/curope/__pycache__/curope2d.cpython-312.pyc ADDED
Binary file (2.36 kB)
croco/models/curope/curope.cpp ADDED
@@ -0,0 +1,69 @@
+ /*
+   Copyright (C) 2022-present Naver Corporation. All rights reserved.
+   Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ */
+
+ #include <torch/extension.h>
+
+ // forward declaration
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
+
+ void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
+ {
+     const int B = tokens.size(0);
+     const int N = tokens.size(1);
+     const int H = tokens.size(2);
+     const int D = tokens.size(3) / 4;
+
+     auto tok = tokens.accessor<float, 4>();
+     auto pos = positions.accessor<int64_t, 3>();
+
+     for (int b = 0; b < B; b++) {
+       for (int x = 0; x < 2; x++) { // y and then x (2d)
+         for (int n = 0; n < N; n++) {
+
+           // grab the token position
+           const int p = pos[b][n][x];
+
+           for (int h = 0; h < H; h++) {
+             for (int d = 0; d < D; d++) {
+               // grab the two values
+               float u = tok[b][n][h][d+0+x*2*D];
+               float v = tok[b][n][h][d+D+x*2*D];
+
+               // grab the cos,sin
+               const float inv_freq = fwd * p / powf(base, d/float(D));
+               float c = cosf(inv_freq);
+               float s = sinf(inv_freq);
+
+               // write the result
+               tok[b][n][h][d+0+x*2*D] = u*c - v*s;
+               tok[b][n][h][d+D+x*2*D] = v*c + u*s;
+             }
+           }
+         }
+       }
+     }
+ }
+
+ void rope_2d( torch::Tensor tokens,           // B,N,H,D
+               const torch::Tensor positions,  // B,N,2
+               const float base,
+               const float fwd )
+ {
+     TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
+     TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
+     TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
+     TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
+     TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
+     TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
+
+     if (tokens.is_cuda())
+         rope_2d_cuda( tokens, positions, base, fwd );
+     else
+         rope_2d_cpu( tokens, positions, base, fwd );
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
+ }
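
For reference, the rotation applied by rope_2d above can be written in a few lines of PyTorch; this is a sketch for checking semantics on small tensors, not part of the extension. It assumes the layout documented in kernels.cu below: each head dimension of size 4*D is split into [u_Y | v_Y | u_X | v_X].

    import torch

    def rope_2d_reference(tokens, positions, base=100.0, fwd=1.0):
        # tokens: (B, N, H, 4*D) float, positions: (B, N, 2) integer (y, x) coordinates
        out = tokens.clone()
        D = tokens.size(3) // 4
        inv_freq = fwd / base ** (torch.arange(D, dtype=torch.float32) / D)  # (D,)
        for axis in range(2):                                                # 0: y, 1: x
            p = positions[..., axis].float().unsqueeze(-1).unsqueeze(-1)     # (B, N, 1, 1)
            angle = p * inv_freq                                             # (B, N, 1, D)
            cos, sin = angle.cos(), angle.sin()
            u = tokens[..., axis * 2 * D : axis * 2 * D + D]
            v = tokens[..., axis * 2 * D + D : axis * 2 * D + 2 * D]
            out[..., axis * 2 * D : axis * 2 * D + D] = u * cos - v * sin
            out[..., axis * 2 * D + D : axis * 2 * D + 2 * D] = v * cos + u * sin
        return out
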
croco/models/curope/curope2d.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ import torch
+
+ try:
+     import curope as _kernels  # run `python setup.py install`
+ except ModuleNotFoundError:
+     from . import curope as _kernels  # run `python setup.py build_ext --inplace`
+
+
+ class cuRoPE2D_func(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, tokens, positions, base, F0=1):
+         ctx.save_for_backward(positions)
+         ctx.saved_base = base
+         ctx.saved_F0 = F0
+         # tokens = tokens.clone() # uncomment this if inplace doesn't work
+         _kernels.rope_2d(tokens, positions, base, F0)
+         ctx.mark_dirty(tokens)
+         return tokens
+
+     @staticmethod
+     def backward(ctx, grad_res):
+         positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
+         _kernels.rope_2d(grad_res, positions, base, -F0)
+         ctx.mark_dirty(grad_res)
+         return grad_res, None, None, None
+
+
+ class cuRoPE2D(torch.nn.Module):
+     def __init__(self, freq=100.0, F0=1.0):
+         super().__init__()
+         self.base = freq
+         self.F0 = F0
+
+     def forward(self, tokens, positions):
+         cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0)
+         return tokens
croco/models/curope/kernels.cu ADDED
@@ -0,0 +1,108 @@
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+ #include <cuda.h>
8
+ #include <cuda_runtime.h>
9
+ #include <vector>
10
+
11
+ #define CHECK_CUDA(tensor) {\
12
+ TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13
+ TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14
+ void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15
+
16
+
17
+ template < typename scalar_t >
18
+ __global__ void rope_2d_cuda_kernel(
19
+ //scalar_t* __restrict__ tokens,
20
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21
+ const int64_t* __restrict__ pos,
22
+ const float base,
23
+ const float fwd )
24
+ // const int N, const int H, const int D )
25
+ {
26
+ // tokens shape = (B, N, H, D)
27
+ const int N = tokens.size(1);
28
+ const int H = tokens.size(2);
29
+ const int D = tokens.size(3);
30
+
31
+ // each block update a single token, for all heads
32
+ // each thread takes care of a single output
33
+ extern __shared__ float shared[];
34
+ float* shared_inv_freq = shared + D;
35
+
36
+ const int b = blockIdx.x / N;
37
+ const int n = blockIdx.x % N;
38
+
39
+ const int Q = D / 4;
40
+ // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41
+ // u_Y v_Y u_X v_X
42
+
43
+ // shared memory: first, compute inv_freq
44
+ if (threadIdx.x < Q)
45
+ shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46
+ __syncthreads();
47
+
48
+ // start of X or Y part
49
+ const int X = threadIdx.x < D/2 ? 0 : 1;
50
+ const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51
+
52
+ // grab the cos,sin appropriate for me
53
+ const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54
+ const float cos = cosf(freq);
55
+ const float sin = sinf(freq);
56
+ /*
57
+ float* shared_cos_sin = shared + D + D/4;
58
+ if ((threadIdx.x % (D/2)) < Q)
59
+ shared_cos_sin[m+0] = cosf(freq);
60
+ else
61
+ shared_cos_sin[m+Q] = sinf(freq);
62
+ __syncthreads();
63
+ const float cos = shared_cos_sin[m+0];
64
+ const float sin = shared_cos_sin[m+Q];
65
+ */
66
+
67
+ for (int h = 0; h < H; h++)
68
+ {
69
+ // then, load all the token for this head in shared memory
70
+ shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71
+ __syncthreads();
72
+
73
+ const float u = shared[m];
74
+ const float v = shared[m+Q];
75
+
76
+ // write output
77
+ if ((threadIdx.x % (D/2)) < Q)
78
+ tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79
+ else
80
+ tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81
+ }
82
+ }
83
+
84
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85
+ {
86
+ const int B = tokens.size(0); // batch size
87
+ const int N = tokens.size(1); // sequence length
88
+ const int H = tokens.size(2); // number of heads
89
+ const int D = tokens.size(3); // dimension per head
90
+
91
+ TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92
+ TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93
+ TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94
+ TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95
+
96
+ // one block for each layer, one thread per local-max
97
+ const int THREADS_PER_BLOCK = D;
98
+ const int N_BLOCKS = B * N; // each block takes care of H*D values
99
+ const int SHARED_MEM = sizeof(float) * (D + D/4);
100
+
101
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102
+ rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103
+ //tokens.data_ptr<scalar_t>(),
104
+ tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105
+ pos.data_ptr<int64_t>(),
106
+ base, fwd); //, N, H, D );
107
+ }));
108
+ }
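
For intuition on the launch configuration computed at the end of rope_2d_cuda above, a back-of-the-envelope check with a typical CroCo decoder layer (dec_embed_dim=512, 16 heads, 224/16 = 14 x 14 tokens); the numbers here are illustrative, not from the repo:

    B, N, H = 8, 14 * 14, 16
    D = 512 // 16                        # dimension per head, must be a multiple of 4
    threads_per_block = D                # one thread per channel of one token
    n_blocks = B * N                     # one block per (batch, token) pair, looping over heads
    shared_mem_bytes = 4 * (D + D // 4)  # float buffer for the token + the D/4 cached inv_freq values
    print(threads_per_block, n_blocks, shared_mem_bytes)   # 32 1568 160
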
croco/models/curope/setup.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ from setuptools import setup
+ from torch import cuda
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ # compile for all possible CUDA architectures
+ all_cuda_archs = cuda.get_gencode_flags().replace("compute=", "arch=").split()
+ # alternatively, you can list cuda archs that you want, eg:
+ # all_cuda_archs = [
+ #     '-gencode', 'arch=compute_70,code=sm_70',
+ #     '-gencode', 'arch=compute_75,code=sm_75',
+ #     '-gencode', 'arch=compute_80,code=sm_80',
+ #     '-gencode', 'arch=compute_86,code=sm_86'
+ # ]
+
+ setup(
+     name="curope",
+     ext_modules=[
+         CUDAExtension(
+             name="curope",
+             sources=[
+                 "curope.cpp",
+                 "kernels.cu",
+             ],
+             extra_compile_args=dict(
+                 nvcc=["-O3", "--ptxas-options=-v", "--use_fast_math"] + all_cuda_archs,
+                 cxx=["-O3"],
+             ),
+         )
+     ],
+     cmdclass={"build_ext": BuildExtension},
+ )
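
Once built (with `python setup.py install` or `python setup.py build_ext --inplace`, as noted in curope2d.py above), the extension exposes the single in-place op bound in curope.cpp. A minimal smoke test, assuming a CUDA device (the CPU path in curope.cpp accepts the same call on float CPU tensors):

    import torch
    import curope                                  # the extension built by this setup.py

    B, N, H, D = 2, 196, 16, 32                    # D must be a multiple of 4
    tokens = torch.randn(B, N, H, D, device="cuda")
    positions = torch.randint(0, 14, (B, N, 2), device="cuda")   # (y, x) patch coordinates
    curope.rope_2d(tokens, positions, 100.0, 1.0)  # rotates tokens in place (base=100, forward)
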
croco/models/dpt_block.py ADDED
@@ -0,0 +1,513 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # DPT head for ViTs
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # https://github.com/isl-org/DPT
9
+ # https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from typing import Union, Tuple, Iterable, List, Optional, Dict
16
+
17
+
18
+ def pair(t):
19
+ return t if isinstance(t, tuple) else (t, t)
20
+
21
+
22
+ def make_scratch(in_shape, out_shape, groups=1, expand=False):
23
+ scratch = nn.Module()
24
+
25
+ out_shape1 = out_shape
26
+ out_shape2 = out_shape
27
+ out_shape3 = out_shape
28
+ out_shape4 = out_shape
29
+ if expand == True:
30
+ out_shape1 = out_shape
31
+ out_shape2 = out_shape * 2
32
+ out_shape3 = out_shape * 4
33
+ out_shape4 = out_shape * 8
34
+
35
+ scratch.layer1_rn = nn.Conv2d(
36
+ in_shape[0],
37
+ out_shape1,
38
+ kernel_size=3,
39
+ stride=1,
40
+ padding=1,
41
+ bias=False,
42
+ groups=groups,
43
+ )
44
+ scratch.layer2_rn = nn.Conv2d(
45
+ in_shape[1],
46
+ out_shape2,
47
+ kernel_size=3,
48
+ stride=1,
49
+ padding=1,
50
+ bias=False,
51
+ groups=groups,
52
+ )
53
+ scratch.layer3_rn = nn.Conv2d(
54
+ in_shape[2],
55
+ out_shape3,
56
+ kernel_size=3,
57
+ stride=1,
58
+ padding=1,
59
+ bias=False,
60
+ groups=groups,
61
+ )
62
+ scratch.layer4_rn = nn.Conv2d(
63
+ in_shape[3],
64
+ out_shape4,
65
+ kernel_size=3,
66
+ stride=1,
67
+ padding=1,
68
+ bias=False,
69
+ groups=groups,
70
+ )
71
+
72
+ scratch.layer_rn = nn.ModuleList(
73
+ [
74
+ scratch.layer1_rn,
75
+ scratch.layer2_rn,
76
+ scratch.layer3_rn,
77
+ scratch.layer4_rn,
78
+ ]
79
+ )
80
+
81
+ return scratch
82
+
83
+
84
+ class ResidualConvUnit_custom(nn.Module):
85
+ """Residual convolution module."""
86
+
87
+ def __init__(self, features, activation, bn):
88
+ """Init.
89
+ Args:
90
+ features (int): number of features
91
+ """
92
+ super().__init__()
93
+
94
+ self.bn = bn
95
+
96
+ self.groups = 1
97
+
98
+ self.conv1 = nn.Conv2d(
99
+ features,
100
+ features,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=not self.bn,
105
+ groups=self.groups,
106
+ )
107
+
108
+ self.conv2 = nn.Conv2d(
109
+ features,
110
+ features,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1,
114
+ bias=not self.bn,
115
+ groups=self.groups,
116
+ )
117
+
118
+ if self.bn == True:
119
+ self.bn1 = nn.BatchNorm2d(features)
120
+ self.bn2 = nn.BatchNorm2d(features)
121
+
122
+ self.activation = activation
123
+
124
+ self.skip_add = nn.quantized.FloatFunctional()
125
+
126
+ def forward(self, x):
127
+ """Forward pass.
128
+ Args:
129
+ x (tensor): input
130
+ Returns:
131
+ tensor: output
132
+ """
133
+
134
+ out = self.activation(x)
135
+ out = self.conv1(out)
136
+ if self.bn == True:
137
+ out = self.bn1(out)
138
+
139
+ out = self.activation(out)
140
+ out = self.conv2(out)
141
+ if self.bn == True:
142
+ out = self.bn2(out)
143
+
144
+ if self.groups > 1:
145
+ out = self.conv_merge(out)
146
+
147
+ return self.skip_add.add(out, x)
148
+
149
+
150
+ class FeatureFusionBlock_custom(nn.Module):
151
+ """Feature fusion block."""
152
+
153
+ def __init__(
154
+ self,
155
+ features,
156
+ activation,
157
+ deconv=False,
158
+ bn=False,
159
+ expand=False,
160
+ align_corners=True,
161
+ width_ratio=1,
162
+ ):
163
+ """Init.
164
+ Args:
165
+ features (int): number of features
166
+ """
167
+ super(FeatureFusionBlock_custom, self).__init__()
168
+ self.width_ratio = width_ratio
169
+
170
+ self.deconv = deconv
171
+ self.align_corners = align_corners
172
+
173
+ self.groups = 1
174
+
175
+ self.expand = expand
176
+ out_features = features
177
+ if self.expand == True:
178
+ out_features = features // 2
179
+
180
+ self.out_conv = nn.Conv2d(
181
+ features,
182
+ out_features,
183
+ kernel_size=1,
184
+ stride=1,
185
+ padding=0,
186
+ bias=True,
187
+ groups=1,
188
+ )
189
+
190
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
191
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
192
+
193
+ self.skip_add = nn.quantized.FloatFunctional()
194
+
195
+ def forward(self, *xs):
196
+ """Forward pass.
197
+ Returns:
198
+ tensor: output
199
+ """
200
+ output = xs[0]
201
+
202
+ if len(xs) == 2:
203
+ res = self.resConfUnit1(xs[1])
204
+ if self.width_ratio != 1:
205
+ res = F.interpolate(
206
+ res, size=(output.shape[2], output.shape[3]), mode="bilinear"
207
+ )
208
+
209
+ output = self.skip_add.add(output, res)
210
+ # output += res
211
+
212
+ output = self.resConfUnit2(output)
213
+
214
+ if self.width_ratio != 1:
215
+ # and output.shape[3] < self.width_ratio * output.shape[2]
216
+ # size=(image.shape[])
217
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
218
+ shape = 3 * output.shape[3]
219
+ else:
220
+ shape = int(self.width_ratio * 2 * output.shape[2])
221
+ output = F.interpolate(
222
+ output, size=(2 * output.shape[2], shape), mode="bilinear"
223
+ )
224
+ else:
225
+ output = nn.functional.interpolate(
226
+ output,
227
+ scale_factor=2,
228
+ mode="bilinear",
229
+ align_corners=self.align_corners,
230
+ )
231
+ output = self.out_conv(output)
232
+ return output
233
+
234
+
235
+ def make_fusion_block(features, use_bn, width_ratio=1):
236
+ return FeatureFusionBlock_custom(
237
+ features,
238
+ nn.ReLU(False),
239
+ deconv=False,
240
+ bn=use_bn,
241
+ expand=False,
242
+ align_corners=True,
243
+ width_ratio=width_ratio,
244
+ )
245
+
246
+
247
+ class Interpolate(nn.Module):
248
+ """Interpolation module."""
249
+
250
+ def __init__(self, scale_factor, mode, align_corners=False):
251
+ """Init.
252
+ Args:
253
+ scale_factor (float): scaling
254
+ mode (str): interpolation mode
255
+ """
256
+ super(Interpolate, self).__init__()
257
+
258
+ self.interp = nn.functional.interpolate
259
+ self.scale_factor = scale_factor
260
+ self.mode = mode
261
+ self.align_corners = align_corners
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+ Args:
266
+ x (tensor): input
267
+ Returns:
268
+ tensor: interpolated data
269
+ """
270
+
271
+ x = self.interp(
272
+ x,
273
+ scale_factor=self.scale_factor,
274
+ mode=self.mode,
275
+ align_corners=self.align_corners,
276
+ )
277
+
278
+ return x
279
+
280
+
281
+ class DPTOutputAdapter(nn.Module):
282
+ """DPT output adapter.
283
+
284
+ :param num_channels: Number of output channels
285
+ :param stride_level: Stride level compared to the full-sized image.
286
+ E.g. 4 for 1/4th the size of the image.
287
+ :param patch_size: Int or tuple of the patch size over the full image size.
288
+ Patch size for smaller inputs will be computed accordingly.
289
+ :param hooks: Index of intermediate layers
290
+ :param layer_dims: Dimension of intermediate layers
291
+ :param feature_dim: Feature dimension
292
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
293
+ :param use_bn: If set to True, activates batch norm
294
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ num_channels: int = 1,
300
+ stride_level: int = 1,
301
+ patch_size: Union[int, Tuple[int, int]] = 16,
302
+ main_tasks: Iterable[str] = ("rgb",),
303
+ hooks: List[int] = [2, 5, 8, 11],
304
+ layer_dims: List[int] = [96, 192, 384, 768],
305
+ feature_dim: int = 256,
306
+ last_dim: int = 32,
307
+ use_bn: bool = False,
308
+ dim_tokens_enc: Optional[int] = None,
309
+ head_type: str = "regression",
310
+ output_width_ratio=1,
311
+ **kwargs
312
+ ):
313
+ super().__init__()
314
+ self.num_channels = num_channels
315
+ self.stride_level = stride_level
316
+ self.patch_size = pair(patch_size)
317
+ self.main_tasks = main_tasks
318
+ self.hooks = hooks
319
+ self.layer_dims = layer_dims
320
+ self.feature_dim = feature_dim
321
+ self.dim_tokens_enc = (
322
+ dim_tokens_enc * len(self.main_tasks)
323
+ if dim_tokens_enc is not None
324
+ else None
325
+ )
326
+ self.head_type = head_type
327
+
328
+ # Actual patch height and width, taking into account stride of input
329
+ self.P_H = max(1, self.patch_size[0] // stride_level)
330
+ self.P_W = max(1, self.patch_size[1] // stride_level)
331
+
332
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
333
+
334
+ self.scratch.refinenet1 = make_fusion_block(
335
+ feature_dim, use_bn, output_width_ratio
336
+ )
337
+ self.scratch.refinenet2 = make_fusion_block(
338
+ feature_dim, use_bn, output_width_ratio
339
+ )
340
+ self.scratch.refinenet3 = make_fusion_block(
341
+ feature_dim, use_bn, output_width_ratio
342
+ )
343
+ self.scratch.refinenet4 = make_fusion_block(
344
+ feature_dim, use_bn, output_width_ratio
345
+ )
346
+
347
+ if self.head_type == "regression":
348
+ # The "DPTDepthModel" head
349
+ self.head = nn.Sequential(
350
+ nn.Conv2d(
351
+ feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1
352
+ ),
353
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
354
+ nn.Conv2d(
355
+ feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1
356
+ ),
357
+ nn.ReLU(True),
358
+ nn.Conv2d(
359
+ last_dim, self.num_channels, kernel_size=1, stride=1, padding=0
360
+ ),
361
+ )
362
+ elif self.head_type == "semseg":
363
+ # The "DPTSegmentationModel" head
364
+ self.head = nn.Sequential(
365
+ nn.Conv2d(
366
+ feature_dim, feature_dim, kernel_size=3, padding=1, bias=False
367
+ ),
368
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
369
+ nn.ReLU(True),
370
+ nn.Dropout(0.1, False),
371
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
372
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
373
+ )
374
+ else:
375
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
376
+
377
+ if self.dim_tokens_enc is not None:
378
+ self.init(dim_tokens_enc=dim_tokens_enc)
379
+
380
+ def init(self, dim_tokens_enc=768):
381
+ """
382
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
383
+ Should be called when setting up MultiMAE.
384
+
385
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
386
+ """
387
+ # print(dim_tokens_enc)
388
+
389
+ # Set up activation postprocessing layers
390
+ if isinstance(dim_tokens_enc, int):
391
+ dim_tokens_enc = 4 * [dim_tokens_enc]
392
+
393
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
394
+
395
+ self.act_1_postprocess = nn.Sequential(
396
+ nn.Conv2d(
397
+ in_channels=self.dim_tokens_enc[0],
398
+ out_channels=self.layer_dims[0],
399
+ kernel_size=1,
400
+ stride=1,
401
+ padding=0,
402
+ ),
403
+ nn.ConvTranspose2d(
404
+ in_channels=self.layer_dims[0],
405
+ out_channels=self.layer_dims[0],
406
+ kernel_size=4,
407
+ stride=4,
408
+ padding=0,
409
+ bias=True,
410
+ dilation=1,
411
+ groups=1,
412
+ ),
413
+ )
414
+
415
+ self.act_2_postprocess = nn.Sequential(
416
+ nn.Conv2d(
417
+ in_channels=self.dim_tokens_enc[1],
418
+ out_channels=self.layer_dims[1],
419
+ kernel_size=1,
420
+ stride=1,
421
+ padding=0,
422
+ ),
423
+ nn.ConvTranspose2d(
424
+ in_channels=self.layer_dims[1],
425
+ out_channels=self.layer_dims[1],
426
+ kernel_size=2,
427
+ stride=2,
428
+ padding=0,
429
+ bias=True,
430
+ dilation=1,
431
+ groups=1,
432
+ ),
433
+ )
434
+
435
+ self.act_3_postprocess = nn.Sequential(
436
+ nn.Conv2d(
437
+ in_channels=self.dim_tokens_enc[2],
438
+ out_channels=self.layer_dims[2],
439
+ kernel_size=1,
440
+ stride=1,
441
+ padding=0,
442
+ )
443
+ )
444
+
445
+ self.act_4_postprocess = nn.Sequential(
446
+ nn.Conv2d(
447
+ in_channels=self.dim_tokens_enc[3],
448
+ out_channels=self.layer_dims[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=self.layer_dims[3],
455
+ out_channels=self.layer_dims[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ self.act_postprocess = nn.ModuleList(
463
+ [
464
+ self.act_1_postprocess,
465
+ self.act_2_postprocess,
466
+ self.act_3_postprocess,
467
+ self.act_4_postprocess,
468
+ ]
469
+ )
470
+
471
+ def adapt_tokens(self, encoder_tokens):
472
+ # Adapt tokens
473
+ x = []
474
+ x.append(encoder_tokens[:, :])
475
+ x = torch.cat(x, dim=-1)
476
+ return x
477
+
478
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
479
+ # input_info: Dict):
480
+ assert (
481
+ self.dim_tokens_enc is not None
482
+ ), "Need to call init(dim_tokens_enc) function first"
483
+ H, W = image_size
484
+
485
+ # Number of patches in height and width
486
+ N_H = H // (self.stride_level * self.P_H)
487
+ N_W = W // (self.stride_level * self.P_W)
488
+
489
+ # Hook decoder onto 4 layers from specified ViT layers
490
+ layers = [encoder_tokens[hook] for hook in self.hooks]
491
+
492
+ # Extract only task-relevant tokens and ignore global tokens.
493
+ layers = [self.adapt_tokens(l) for l in layers]
494
+
495
+ # Reshape tokens to spatial representation
496
+ layers = [
497
+ rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers
498
+ ]
499
+
500
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
501
+ # Project layers to chosen feature dim
502
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
503
+
504
+ # Fuse layers using refinement stages
505
+ path_4 = self.scratch.refinenet4(layers[3])
506
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
507
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
508
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
509
+
510
+ # Output head
511
+ out = self.head(path_1)
512
+
513
+ return out
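
A minimal usage sketch of DPTOutputAdapter for a ViT-Base-like encoder (768-dim tokens, 224 x 224 input, 16 x 16 patches); the random token list stands in for the per-block encoder outputs:

    import torch
    from models.dpt_block import DPTOutputAdapter

    adapter = DPTOutputAdapter(
        num_channels=1,                  # e.g. a one-channel regression map
        hooks=[2, 5, 8, 11],             # encoder blocks to tap
        layer_dims=[96, 192, 384, 768],
        feature_dim=256,
        head_type="regression",
    )
    adapter.init(dim_tokens_enc=768)     # must be called before the first forward

    N_H = N_W = 224 // 16
    encoder_tokens = [torch.randn(2, N_H * N_W, 768) for _ in range(12)]  # one tensor per block
    out = adapter(encoder_tokens, image_size=(224, 224))
    print(out.shape)                     # torch.Size([2, 1, 224, 224])
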