Spaces:
Running
on
Zero
Running
on
Zero
zixinz
commited on
Commit
·
5a0778e
1
Parent(s):
f1483c5
Add application file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +56 -0
- code_depth/LICENSE +201 -0
- code_depth/README.md +120 -0
- code_depth/app.py +152 -0
- code_depth/assets/example_videos/Tokyo-Walk_rgb.mp4 +3 -0
- code_depth/assets/example_videos/davis_rollercoaster.mp4 +3 -0
- code_depth/assets/teaser_video_v2.png +3 -0
- code_depth/benchmark/README.md +34 -0
- code_depth/benchmark/__init__.py +0 -0
- code_depth/benchmark/dataset_extract/dataset_extract_bonn.py +86 -0
- code_depth/benchmark/dataset_extract/dataset_extract_kitti.py +84 -0
- code_depth/benchmark/dataset_extract/dataset_extract_nyuv2.py +76 -0
- code_depth/benchmark/dataset_extract/dataset_extract_scannet.py +124 -0
- code_depth/benchmark/dataset_extract/dataset_extract_sintel.py +110 -0
- code_depth/benchmark/dataset_extract/eval_utils.py +140 -0
- code_depth/benchmark/eval/eval.py +265 -0
- code_depth/benchmark/eval/eval.sh +30 -0
- code_depth/benchmark/eval/eval_500.sh +30 -0
- code_depth/benchmark/eval/eval_tae.py +295 -0
- code_depth/benchmark/eval/eval_tae.sh +18 -0
- code_depth/benchmark/eval/metric.py +117 -0
- code_depth/benchmark/infer/infer.py +65 -0
- code_depth/get_weights.sh +6 -0
- code_depth/large_files.txt +2 -0
- code_depth/requirements.txt +14 -0
- code_depth/run.py +81 -0
- code_depth/run_images_rord.py +112 -0
- code_depth/run_single_image.py +69 -0
- code_depth/utils/dc_utils.py +86 -0
- code_depth/utils/util.py +74 -0
- code_depth/video_depth_anything/dinov2.py +415 -0
- code_depth/video_depth_anything/dinov2_layers/__init__.py +11 -0
- code_depth/video_depth_anything/dinov2_layers/attention.py +83 -0
- code_depth/video_depth_anything/dinov2_layers/block.py +252 -0
- code_depth/video_depth_anything/dinov2_layers/drop_path.py +35 -0
- code_depth/video_depth_anything/dinov2_layers/layer_scale.py +28 -0
- code_depth/video_depth_anything/dinov2_layers/mlp.py +41 -0
- code_depth/video_depth_anything/dinov2_layers/patch_embed.py +89 -0
- code_depth/video_depth_anything/dinov2_layers/swiglu_ffn.py +63 -0
- code_depth/video_depth_anything/dpt.py +160 -0
- code_depth/video_depth_anything/dpt_temporal.py +114 -0
- code_depth/video_depth_anything/motion_module/attention.py +429 -0
- code_depth/video_depth_anything/motion_module/motion_module.py +297 -0
- code_depth/video_depth_anything/util/blocks.py +162 -0
- code_depth/video_depth_anything/util/transform.py +158 -0
- code_depth/video_depth_anything/video_depth.py +156 -0
- code_edit/.gradio/certificate.pem +31 -0
- code_edit/Flux_fill_d2i.py +53 -0
- code_edit/Flux_fill_infer_depth.py +64 -0
- code_edit/README.md +93 -0
app.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import subprocess
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import spaces
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
# ---------- 权重下载:强制在 code_depth 下执行你的脚本 ----------
|
| 9 |
+
BASE_DIR = pathlib.Path(__file__).resolve().parent
|
| 10 |
+
SCRIPT_DIR = BASE_DIR / "code_depth"
|
| 11 |
+
GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"
|
| 12 |
+
|
| 13 |
+
def ensure_executable(path: pathlib.Path):
|
| 14 |
+
if not path.exists():
|
| 15 |
+
raise FileNotFoundError(f"Download script not found: {path}")
|
| 16 |
+
os.chmod(path, os.stat(path).st_mode | 0o111)
|
| 17 |
+
|
| 18 |
+
def ensure_weights() -> str:
|
| 19 |
+
"""
|
| 20 |
+
在 code_depth 目录下运行 get_weights.sh。
|
| 21 |
+
该脚本会在 code_depth/ 下创建 checkpoints/ 并下载权重。
|
| 22 |
+
返回绝对路径:<repo_root>/code_depth/checkpoints
|
| 23 |
+
"""
|
| 24 |
+
ensure_executable(GET_WEIGHTS_SH)
|
| 25 |
+
# 你脚本的工作目录需要是 code_depth
|
| 26 |
+
subprocess.run(
|
| 27 |
+
["bash", str(GET_WEIGHTS_SH)],
|
| 28 |
+
check=True,
|
| 29 |
+
cwd=str(SCRIPT_DIR),
|
| 30 |
+
env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
|
| 31 |
+
)
|
| 32 |
+
ckpt_dir = SCRIPT_DIR / "checkpoints"
|
| 33 |
+
return str(ckpt_dir)
|
| 34 |
+
|
| 35 |
+
# 启动时先拉权重(不开 Persistent Storage 时,重建环境会清空;重启后会自动再拉一次)
|
| 36 |
+
try:
|
| 37 |
+
CKPT_DIR = ensure_weights()
|
| 38 |
+
print(f"✅ Weights ready in: {CKPT_DIR}")
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"⚠️ Failed to prepare weights: {e}")
|
| 41 |
+
CKPT_DIR = str(SCRIPT_DIR / "checkpoints") # 仍然给个路径,后续可检查是否存在
|
| 42 |
+
|
| 43 |
+
# ---------- Gradio 推理函数 ----------
|
| 44 |
+
@spaces.GPU
|
| 45 |
+
def greet(n: float):
|
| 46 |
+
# 在 GPU worker 里拿 device
|
| 47 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 48 |
+
zero = torch.tensor([0.0], device=device)
|
| 49 |
+
# 仅示例输出,你可以在这里用 CKPT_DIR 加载你的模型
|
| 50 |
+
print(f"Device in greet(): {device}")
|
| 51 |
+
print(f"Using checkpoints from: {CKPT_DIR}")
|
| 52 |
+
return f"Hello {(zero + n).item()} Tensor (device={device})"
|
| 53 |
+
|
| 54 |
+
demo = gr.Interface(fn=greet, inputs=gr.Number(label="n"), outputs=gr.Text())
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
code_depth/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
code_depth/README.md
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>Video Depth Anything</h1>
|
| 3 |
+
|
| 4 |
+
[**Sili Chen**](https://github.com/SiliChen321) · [**Hengkai Guo**](https://guohengkai.github.io/)<sup>†</sup> · [**Shengnan Zhu**](https://github.com/Shengnan-Zhu) · [**Feihu Zhang**](https://github.com/zhizunhu)
|
| 5 |
+
<br>
|
| 6 |
+
[**Zilong Huang**](http://speedinghzl.github.io/) · [**Jiashi Feng**](https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en) · [**Bingyi Kang**](https://bingykang.github.io/)<sup>†</sup>
|
| 7 |
+
<br>
|
| 8 |
+
ByteDance
|
| 9 |
+
<br>
|
| 10 |
+
†Corresponding author
|
| 11 |
+
|
| 12 |
+
<a href="https://arxiv.org/abs/2501.12375"><img src='https://img.shields.io/badge/arXiv-Video Depth Anything-red' alt='Paper PDF'></a>
|
| 13 |
+
<a href='https://videodepthanything.github.io'><img src='https://img.shields.io/badge/Project_Page-Video Depth Anything-green' alt='Project Page'></a>
|
| 14 |
+
<a href='https://huggingface.co/spaces/depth-anything/Video-Depth-Anything'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
This work presents **Video Depth Anything** based on [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), which can be applied to arbitrarily long videos without compromising quality, consistency, or generalization ability. Compared with other diffusion-based models, it enjoys faster inference speed, fewer parameters, and higher consistent depth accuracy.
|
| 20 |
+
|
| 21 |
+

|
| 22 |
+
|
| 23 |
+
## News
|
| 24 |
+
- **2025-03-11:** Add full dataset inference and evaluation scripts.
|
| 25 |
+
- **2025-02-08:** Enable autocast inference. Support grayscale video, NPZ and EXR output formats.
|
| 26 |
+
- **2025-01-21:** Paper, project page, code, models, and demo are all released.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## Release Notes
|
| 30 |
+
- **2025-02-08:** 🚀🚀🚀 Inference speed and memory usage improvement
|
| 31 |
+
<table>
|
| 32 |
+
<thead>
|
| 33 |
+
<tr>
|
| 34 |
+
<th rowspan="2" style="text-align: center;">Model</th>
|
| 35 |
+
<th colspan="2">Latency (ms)</th>
|
| 36 |
+
<th colspan="2">GPU VRAM (GB)</th>
|
| 37 |
+
</tr>
|
| 38 |
+
<tr>
|
| 39 |
+
<th>FP32</th>
|
| 40 |
+
<th>FP16</th>
|
| 41 |
+
<th>FP32</th>
|
| 42 |
+
<th>FP16</th>
|
| 43 |
+
</tr>
|
| 44 |
+
</thead>
|
| 45 |
+
<tbody>
|
| 46 |
+
<tr>
|
| 47 |
+
<td>Video-Depth-Anything-V2-Small</td>
|
| 48 |
+
<td>9.1</td>
|
| 49 |
+
<td><strong>7.5</strong></td>
|
| 50 |
+
<td>7.3</td>
|
| 51 |
+
<td><strong>6.8</strong></td>
|
| 52 |
+
</tr>
|
| 53 |
+
<tr>
|
| 54 |
+
<td>Video-Depth-Anything-V2-Large</td>
|
| 55 |
+
<td>67</td>
|
| 56 |
+
<td><strong>14</strong></td>
|
| 57 |
+
<td>26.7</td>
|
| 58 |
+
<td><strong>23.6</strong></td>
|
| 59 |
+
</tbody>
|
| 60 |
+
</table>
|
| 61 |
+
|
| 62 |
+
The Latency and GPU VRAM results are obtained on a single A100 GPU with input of shape 1 x 32 x 518 × 518.
|
| 63 |
+
|
| 64 |
+
## Pre-trained Models
|
| 65 |
+
We provide **two models** of varying scales for robust and consistent video depth estimation:
|
| 66 |
+
|
| 67 |
+
| Model | Params | Checkpoint |
|
| 68 |
+
|:-|-:|:-:|
|
| 69 |
+
| Video-Depth-Anything-V2-Small | 28.4M | [Download](https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth?download=true) |
|
| 70 |
+
| Video-Depth-Anything-V2-Large | 381.8M | [Download](https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth?download=true) |
|
| 71 |
+
|
| 72 |
+
## Usage
|
| 73 |
+
|
| 74 |
+
### Preparation
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
git clone https://github.com/DepthAnything/Video-Depth-Anything
|
| 78 |
+
cd Video-Depth-Anything
|
| 79 |
+
pip install -r requirements.txt
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
Download the checkpoints listed [here](#pre-trained-models) and put them under the `checkpoints` directory.
|
| 83 |
+
```bash
|
| 84 |
+
bash get_weights.sh
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Inference a video
|
| 88 |
+
```bash
|
| 89 |
+
python3 run.py --input_video ./assets/example_videos/davis_rollercoaster.mp4 --output_dir ./outputs --encoder vitl
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Options:
|
| 93 |
+
- `--input_video`: path of input video
|
| 94 |
+
- `--output_dir`: path to save the output results
|
| 95 |
+
- `--input_size` (optional): By default, we use input size `518` for model inference.
|
| 96 |
+
- `--max_res` (optional): By default, we use maximum resolution `1280` for model inference.
|
| 97 |
+
- `--encoder` (optional): `vits` for Video-Depth-Anything-V2-Small, `vitl` for Video-Depth-Anything-V2-Large.
|
| 98 |
+
- `--max_len` (optional): maximum length of the input video, `-1` means no limit
|
| 99 |
+
- `--target_fps` (optional): target fps of the input video, `-1` means the original fps
|
| 100 |
+
- `--fp32` (optional): Use `fp32` precision for inference. By default, we use `fp16`.
|
| 101 |
+
- `--grayscale` (optional): Save the grayscale depth map, without applying color palette.
|
| 102 |
+
- `--save_npz` (optional): Save the depth map in `npz` format.
|
| 103 |
+
- `--save_exr` (optional): Save the depth map in `exr` format.
|
| 104 |
+
|
| 105 |
+
## Citation
|
| 106 |
+
|
| 107 |
+
If you find this project useful, please consider citing:
|
| 108 |
+
|
| 109 |
+
```bibtex
|
| 110 |
+
@article{video_depth_anything,
|
| 111 |
+
title={Video Depth Anything: Consistent Depth Estimation for Super-Long Videos},
|
| 112 |
+
author={Chen, Sili and Guo, Hengkai and Zhu, Shengnan and Zhang, Feihu and Huang, Zilong and Feng, Jiashi and Kang, Bingyi}
|
| 113 |
+
journal={arXiv:2501.12375},
|
| 114 |
+
year={2025}
|
| 115 |
+
}
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
## LICENSE
|
| 120 |
+
Video-Depth-Anything-Small model is under the Apache-2.0 license. Video-Depth-Anything-Large model is under the CC-BY-NC-4.0 license. For business cooperation, please send an email to Hengkai Guo at [email protected].
|
code_depth/app.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import gradio as gr
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import os
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from video_depth_anything.video_depth import VideoDepthAnything
|
| 21 |
+
from utils.dc_utils import read_video_frames, save_video
|
| 22 |
+
|
| 23 |
+
examples = [
|
| 24 |
+
['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280],
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
model_configs = {
|
| 28 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
| 29 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
encoder='vitl'
|
| 33 |
+
|
| 34 |
+
video_depth_anything = VideoDepthAnything(**model_configs[encoder])
|
| 35 |
+
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{encoder}.pth', map_location='cpu'), strict=True)
|
| 36 |
+
video_depth_anything = video_depth_anything.to('cuda').eval()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def infer_video_depth(
|
| 40 |
+
input_video: str,
|
| 41 |
+
max_len: int = -1,
|
| 42 |
+
target_fps: int = -1,
|
| 43 |
+
max_res: int = 1280,
|
| 44 |
+
output_dir: str = './outputs',
|
| 45 |
+
input_size: int = 518,
|
| 46 |
+
):
|
| 47 |
+
frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
|
| 48 |
+
depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')
|
| 49 |
+
|
| 50 |
+
video_name = os.path.basename(input_video)
|
| 51 |
+
if not os.path.exists(output_dir):
|
| 52 |
+
os.makedirs(output_dir)
|
| 53 |
+
|
| 54 |
+
processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
|
| 55 |
+
depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
|
| 56 |
+
save_video(frames, processed_video_path, fps=fps)
|
| 57 |
+
save_video(depths, depth_vis_path, fps=fps, is_depths=True)
|
| 58 |
+
|
| 59 |
+
return [processed_video_path, depth_vis_path]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def construct_demo():
|
| 63 |
+
with gr.Blocks(analytics_enabled=False) as demo:
|
| 64 |
+
gr.Markdown(
|
| 65 |
+
f"""
|
| 66 |
+
blablabla
|
| 67 |
+
"""
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
with gr.Row(equal_height=True):
|
| 71 |
+
with gr.Column(scale=1):
|
| 72 |
+
input_video = gr.Video(label="Input Video")
|
| 73 |
+
|
| 74 |
+
# with gr.Tab(label="Output"):
|
| 75 |
+
with gr.Column(scale=2):
|
| 76 |
+
with gr.Row(equal_height=True):
|
| 77 |
+
processed_video = gr.Video(
|
| 78 |
+
label="Preprocessed video",
|
| 79 |
+
interactive=False,
|
| 80 |
+
autoplay=True,
|
| 81 |
+
loop=True,
|
| 82 |
+
show_share_button=True,
|
| 83 |
+
scale=5,
|
| 84 |
+
)
|
| 85 |
+
depth_vis_video = gr.Video(
|
| 86 |
+
label="Generated Depth Video",
|
| 87 |
+
interactive=False,
|
| 88 |
+
autoplay=True,
|
| 89 |
+
loop=True,
|
| 90 |
+
show_share_button=True,
|
| 91 |
+
scale=5,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
with gr.Row(equal_height=True):
|
| 95 |
+
with gr.Column(scale=1):
|
| 96 |
+
with gr.Row(equal_height=False):
|
| 97 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 98 |
+
max_len = gr.Slider(
|
| 99 |
+
label="max process length",
|
| 100 |
+
minimum=-1,
|
| 101 |
+
maximum=1000,
|
| 102 |
+
value=-1,
|
| 103 |
+
step=1,
|
| 104 |
+
)
|
| 105 |
+
target_fps = gr.Slider(
|
| 106 |
+
label="target FPS",
|
| 107 |
+
minimum=-1,
|
| 108 |
+
maximum=30,
|
| 109 |
+
value=15,
|
| 110 |
+
step=1,
|
| 111 |
+
)
|
| 112 |
+
max_res = gr.Slider(
|
| 113 |
+
label="max side resolution",
|
| 114 |
+
minimum=480,
|
| 115 |
+
maximum=1920,
|
| 116 |
+
value=1280,
|
| 117 |
+
step=1,
|
| 118 |
+
)
|
| 119 |
+
generate_btn = gr.Button("Generate")
|
| 120 |
+
with gr.Column(scale=2):
|
| 121 |
+
pass
|
| 122 |
+
|
| 123 |
+
gr.Examples(
|
| 124 |
+
examples=examples,
|
| 125 |
+
inputs=[
|
| 126 |
+
input_video,
|
| 127 |
+
max_len,
|
| 128 |
+
target_fps,
|
| 129 |
+
max_res
|
| 130 |
+
],
|
| 131 |
+
outputs=[processed_video, depth_vis_video],
|
| 132 |
+
fn=infer_video_depth,
|
| 133 |
+
cache_examples="lazy",
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
generate_btn.click(
|
| 137 |
+
fn=infer_video_depth,
|
| 138 |
+
inputs=[
|
| 139 |
+
input_video,
|
| 140 |
+
max_len,
|
| 141 |
+
target_fps,
|
| 142 |
+
max_res
|
| 143 |
+
],
|
| 144 |
+
outputs=[processed_video, depth_vis_video],
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
return demo
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
demo = construct_demo()
|
| 151 |
+
demo.queue()
|
| 152 |
+
demo.launch(server_name="0.0.0.0")
|
code_depth/assets/example_videos/Tokyo-Walk_rgb.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:097f16c33dd8c8d1d2a24d9ea31a90b76bd0ee324b958a47385183e3547a63a8
|
| 3 |
+
size 2251450
|
code_depth/assets/example_videos/davis_rollercoaster.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7268cbecd9806a1e90a416de50dc02e50b4ae01428d5971837cf679dd0c87cb8
|
| 3 |
+
size 1809560
|
code_depth/assets/teaser_video_v2.png
ADDED
|
Git LFS Details
|
code_depth/benchmark/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BENCHMARK
|
| 2 |
+
|
| 3 |
+
## Prepare Dataset
|
| 4 |
+
Download datasets from the following links:
|
| 5 |
+
[sintel](http://sintel.is.tue.mpg.de/) [kitti](https://www.cvlibs.net/datasets/kitti/) [bonn](https://www.ipb.uni-bonn.de/data/rgbd-dynamic-dataset/index.html) [scannet](http://www.scan-net.org/) [nyuv2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html)
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
pip3 install natsort
|
| 9 |
+
cd benchmark/dataset_extract
|
| 10 |
+
python3 dataset_extrtact${dataset}.py
|
| 11 |
+
```
|
| 12 |
+
This script will extract the dataset to the `benchmark/dataset_extract/dataset` folder. It will also generate the json file for the dataset.
|
| 13 |
+
|
| 14 |
+
## Run inference
|
| 15 |
+
```bash
|
| 16 |
+
python3 benchmark/infer/infer.py \
|
| 17 |
+
--infer_path ${out_path} \
|
| 18 |
+
--json_file ${json_path} \
|
| 19 |
+
--datasets ${dataset}
|
| 20 |
+
```
|
| 21 |
+
Options:
|
| 22 |
+
- `--infer_path`: path to save the output results
|
| 23 |
+
- `--json_file`: path to the json file for the dataset
|
| 24 |
+
- `--datasets`: dataset name, choose from `sintel`, `kitti`, `bonn`, `scannet`, `nyuv2`
|
| 25 |
+
|
| 26 |
+
## Run evaluation
|
| 27 |
+
```bash
|
| 28 |
+
## tae
|
| 29 |
+
bash benchmark/eval/eval_tae.sh ${out_path} benchmark/dataset_extract/dataset
|
| 30 |
+
## ~110frame like DepthCrafter
|
| 31 |
+
bash benchmark/eval/eval.sh ${out_path} benchmark/dataset_extract/dataset
|
| 32 |
+
## ~500frame
|
| 33 |
+
bash benchmark/eval/eval_500.sh ${out_path} benchmark/dataset_extract/dataset
|
| 34 |
+
```
|
code_depth/benchmark/__init__.py
ADDED
|
File without changes
|
code_depth/benchmark/dataset_extract/dataset_extract_bonn.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import cv2
|
| 7 |
+
import csv
|
| 8 |
+
import json
|
| 9 |
+
import glob
|
| 10 |
+
import shutil
|
| 11 |
+
from natsort import natsorted
|
| 12 |
+
|
| 13 |
+
from eval_utils import gen_json, get_sorted_files, even_or_odd, copy_crop_files
|
| 14 |
+
|
| 15 |
+
def extract_bonn(
|
| 16 |
+
root,
|
| 17 |
+
depth_root,
|
| 18 |
+
saved_dir,
|
| 19 |
+
sample_len,
|
| 20 |
+
datatset_name,
|
| 21 |
+
):
|
| 22 |
+
scenes_names = os.listdir(depth_root)
|
| 23 |
+
all_samples = []
|
| 24 |
+
for i, seq_name in enumerate(tqdm(scenes_names)):
|
| 25 |
+
# load all images
|
| 26 |
+
all_img_names = get_sorted_files(
|
| 27 |
+
root=osp.join(depth_root, seq_name, "rgb"), suffix=".png"
|
| 28 |
+
)
|
| 29 |
+
all_depth_names = get_sorted_files(
|
| 30 |
+
root=osp.join(depth_root, seq_name, "depth"), suffix=".png"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
seq_len = len(all_img_names)
|
| 34 |
+
step = sample_len if sample_len > 0 else seq_len
|
| 35 |
+
|
| 36 |
+
for ref_idx in range(0, seq_len, step):
|
| 37 |
+
print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
|
| 38 |
+
|
| 39 |
+
if (ref_idx + step) <= seq_len:
|
| 40 |
+
ref_e = ref_idx + step
|
| 41 |
+
else:
|
| 42 |
+
continue
|
| 43 |
+
|
| 44 |
+
for idx in range(ref_idx, ref_e):
|
| 45 |
+
im_path = osp.join(
|
| 46 |
+
root, seq_name, "rgb", all_img_names[idx]
|
| 47 |
+
)
|
| 48 |
+
depth_path = osp.join(
|
| 49 |
+
depth_root, seq_name, "depth", all_depth_names[idx]
|
| 50 |
+
)
|
| 51 |
+
out_img_path = osp.join(
|
| 52 |
+
saved_dir, datatset_name,seq_name, "rgb", all_img_names[idx]
|
| 53 |
+
)
|
| 54 |
+
out_depth_path = osp.join(
|
| 55 |
+
saved_dir, datatset_name,seq_name, "depth", all_depth_names[idx]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
copy_crop_files(
|
| 59 |
+
im_path=im_path,
|
| 60 |
+
depth_path=depth_path,
|
| 61 |
+
out_img_path=out_img_path,
|
| 62 |
+
out_depth_path=out_depth_path,
|
| 63 |
+
dataset=datatset_name,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# 110 frames like DepthCraft
|
| 67 |
+
out_json_path = osp.join(saved_dir, datatset_name, "bonn_video.json")
|
| 68 |
+
gen_json(
|
| 69 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 70 |
+
start_id=30, end_id=140, step=1, save_path=out_json_path)
|
| 71 |
+
|
| 72 |
+
#~500 frames in paper
|
| 73 |
+
out_json_path = osp.join(saved_dir, datatset_name, "bonn_video_500.json")
|
| 74 |
+
gen_json(
|
| 75 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 76 |
+
start_id=0, end_id=500, step=1, save_path=out_json_path)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == "__main__":
|
| 80 |
+
extract_bonn(
|
| 81 |
+
root="path/to/Bonn-RGBD",
|
| 82 |
+
depth_root="path/to/Bonn-RGBD",
|
| 83 |
+
saved_dir="./benchmark/datasets/",
|
| 84 |
+
sample_len=-1,
|
| 85 |
+
datatset_name="bonn",
|
| 86 |
+
)
|
code_depth/benchmark/dataset_extract/dataset_extract_kitti.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import csv
|
| 7 |
+
import cv2
|
| 8 |
+
import json
|
| 9 |
+
import glob
|
| 10 |
+
import shutil
|
| 11 |
+
from natsort import natsorted
|
| 12 |
+
|
| 13 |
+
from eval_utils import even_or_odd
|
| 14 |
+
from eval_utils import gen_json, get_sorted_files, copy_crop_files
|
| 15 |
+
|
| 16 |
+
def extract_kitti(
|
| 17 |
+
root,
|
| 18 |
+
depth_root,
|
| 19 |
+
sample_len=-1,
|
| 20 |
+
saved_dir="",
|
| 21 |
+
datatset_name="",
|
| 22 |
+
):
|
| 23 |
+
scenes_names = os.listdir(depth_root)
|
| 24 |
+
all_samples = []
|
| 25 |
+
for i, seq_name in enumerate(tqdm(scenes_names)):
|
| 26 |
+
|
| 27 |
+
all_img_names = get_sorted_files(
|
| 28 |
+
osp.join(depth_root, seq_name, "proj_depth/groundtruth/image_02"), suffix=".png"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
seq_len = len(all_img_names)
|
| 32 |
+
step = sample_len if sample_len > 0 else seq_len
|
| 33 |
+
|
| 34 |
+
for ref_idx in range(0, seq_len, step):
|
| 35 |
+
print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
|
| 36 |
+
video_imgs = []
|
| 37 |
+
video_depths = []
|
| 38 |
+
|
| 39 |
+
if (ref_idx + step) <= seq_len:
|
| 40 |
+
ref_e = ref_idx + step
|
| 41 |
+
else:
|
| 42 |
+
continue
|
| 43 |
+
|
| 44 |
+
for idx in range(ref_idx, ref_e):
|
| 45 |
+
im_path = osp.join(
|
| 46 |
+
root, seq_name[0:10], seq_name, "image_02/data", all_img_names[idx]
|
| 47 |
+
)
|
| 48 |
+
depth_path = osp.join(
|
| 49 |
+
depth_root, seq_name, "proj_depth/groundtruth/image_02", all_img_names[idx],
|
| 50 |
+
)
|
| 51 |
+
out_img_path = osp.join(
|
| 52 |
+
saved_dir, datatset_name,seq_name, "rgb", all_img_names[idx]
|
| 53 |
+
)
|
| 54 |
+
out_depth_path = osp.join(
|
| 55 |
+
saved_dir, datatset_name,seq_name, "depth", all_img_names[idx]
|
| 56 |
+
)
|
| 57 |
+
copy_crop_files(
|
| 58 |
+
im_path=im_path,
|
| 59 |
+
depth_path=depth_path,
|
| 60 |
+
out_img_path=out_img_path,
|
| 61 |
+
out_depth_path=out_depth_path,
|
| 62 |
+
dataset=datatset_name,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# 110 frames like DepthCraft
|
| 66 |
+
out_json_path = osp.join(saved_dir, datatset_name, "kitti_video.json")
|
| 67 |
+
gen_json(
|
| 68 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 69 |
+
start_id=0, end_id=110, step=1, save_path=out_json_path)
|
| 70 |
+
|
| 71 |
+
#~500 frames in paper
|
| 72 |
+
out_json_path = osp.join(saved_dir, datatset_name, "kitti_video_500.json")
|
| 73 |
+
gen_json(
|
| 74 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 75 |
+
start_id=0, end_id=500, step=1, save_path=out_json_path)
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
extract_kitti(
|
| 79 |
+
root="path/to/kitti",
|
| 80 |
+
depth_root="path/to/kitti/val",
|
| 81 |
+
saved_dir="./benchmark/datasets/",
|
| 82 |
+
sample_len=-1,
|
| 83 |
+
datatset_name="kitti",
|
| 84 |
+
)
|
code_depth/benchmark/dataset_extract/dataset_extract_nyuv2.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import csv
|
| 7 |
+
import cv2
|
| 8 |
+
import json
|
| 9 |
+
import glob
|
| 10 |
+
from natsort import natsorted
|
| 11 |
+
import shutil
|
| 12 |
+
|
| 13 |
+
from eval_utils import gen_json, get_sorted_files, copy_crop_files
|
| 14 |
+
|
| 15 |
+
def extract_nyuv2(
|
| 16 |
+
root,
|
| 17 |
+
sample_len=-1,
|
| 18 |
+
datatset_name="",
|
| 19 |
+
saved_dir="",
|
| 20 |
+
):
|
| 21 |
+
scenes_names = os.listdir(root)
|
| 22 |
+
scenes_names = sorted(scenes_names)
|
| 23 |
+
all_samples = []
|
| 24 |
+
for i, seq_name in enumerate(tqdm(scenes_names)):
|
| 25 |
+
all_img_names = get_sorted_files(
|
| 26 |
+
osp.join(root, seq_name, "rgb"), suffix=".jpg")
|
| 27 |
+
|
| 28 |
+
seq_len = len(all_img_names)
|
| 29 |
+
step = sample_len if sample_len > 0 else seq_len
|
| 30 |
+
|
| 31 |
+
for ref_idx in range(0, seq_len, step):
|
| 32 |
+
print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
|
| 33 |
+
|
| 34 |
+
if (ref_idx + step) <= seq_len:
|
| 35 |
+
ref_e = ref_idx + step
|
| 36 |
+
else:
|
| 37 |
+
continue
|
| 38 |
+
|
| 39 |
+
for idx in range(ref_idx, ref_e):
|
| 40 |
+
im_path = osp.join(
|
| 41 |
+
root, seq_name, "rgb", all_img_names[idx]
|
| 42 |
+
)
|
| 43 |
+
depth_path = osp.join(
|
| 44 |
+
root, seq_name, "depth", all_img_names[idx][:-3] + "png"
|
| 45 |
+
)
|
| 46 |
+
out_img_path = osp.join(
|
| 47 |
+
saved_dir, datatset_name, seq_name, "rgb", all_img_names[idx]
|
| 48 |
+
)
|
| 49 |
+
out_depth_path = osp.join(
|
| 50 |
+
saved_dir, datatset_name, seq_name, "depth", all_img_names[idx][:-3] + "png"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
copy_crop_files(
|
| 54 |
+
im_path=im_path,
|
| 55 |
+
depth_path=depth_path,
|
| 56 |
+
out_img_path=out_img_path,
|
| 57 |
+
out_depth_path=out_depth_path,
|
| 58 |
+
dataset=dataset_name,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
#~500 frames in paper
|
| 62 |
+
out_json_path = osp.join(saved_dir, datatset_name, "nyuv2_video_500.json")
|
| 63 |
+
gen_json(
|
| 64 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 65 |
+
start_id=0,end_id=500,step=1,
|
| 66 |
+
save_path=out_json_path)
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
# we use matlab to extract 8 scenes from NYUv2
|
| 70 |
+
#--basement_0001a, bookstore_0001a, cafe_0001a, classroom_0001a, kitchen_0003, office_0004, playroom_0002, study_0002
|
| 71 |
+
extract_scannet(
|
| 72 |
+
root="path/to/nyuv2",
|
| 73 |
+
saved_dir="./benchmark/datasets/",
|
| 74 |
+
sample_len=-1,
|
| 75 |
+
datatset_name="nyuv2",
|
| 76 |
+
)
|
code_depth/benchmark/dataset_extract/dataset_extract_scannet.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import csv
|
| 7 |
+
import cv2
|
| 8 |
+
import json
|
| 9 |
+
import glob
|
| 10 |
+
from natsort import natsorted
|
| 11 |
+
import shutil
|
| 12 |
+
|
| 13 |
+
from eval_utils import gen_json, gen_json_scannet_tae, get_sorted_files, copy_crop_files
|
| 14 |
+
|
| 15 |
+
def extract_scannet(
|
| 16 |
+
root,
|
| 17 |
+
sample_len=-1,
|
| 18 |
+
datatset_name="",
|
| 19 |
+
saved_dir="",
|
| 20 |
+
):
|
| 21 |
+
scenes_names = os.listdir(root)
|
| 22 |
+
scenes_names = sorted(scenes_names)[:100]
|
| 23 |
+
all_samples = []
|
| 24 |
+
for i, seq_name in enumerate(tqdm(scenes_names)):
|
| 25 |
+
all_img_names = get_sorted_files(
|
| 26 |
+
osp.join(root, seq_name, "color"), suffix=".jpg")
|
| 27 |
+
all_img_names = all_img_names[:510]
|
| 28 |
+
|
| 29 |
+
seq_len = len(all_img_names)
|
| 30 |
+
step = sample_len if sample_len > 0 else seq_len
|
| 31 |
+
|
| 32 |
+
for ref_idx in range(0, seq_len, step):
|
| 33 |
+
print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
|
| 34 |
+
|
| 35 |
+
video_imgs = []
|
| 36 |
+
video_depths = []
|
| 37 |
+
|
| 38 |
+
if (ref_idx + step) <= seq_len:
|
| 39 |
+
ref_e = ref_idx + step
|
| 40 |
+
else:
|
| 41 |
+
continue
|
| 42 |
+
|
| 43 |
+
for idx in range(ref_idx, ref_e):
|
| 44 |
+
im_path = osp.join(
|
| 45 |
+
root, seq_name, "color", all_img_names[idx]
|
| 46 |
+
)
|
| 47 |
+
depth_path = osp.join(
|
| 48 |
+
root, seq_name, "depth", all_img_names[idx][:-3] + "png"
|
| 49 |
+
)
|
| 50 |
+
pose_path = osp.join(
|
| 51 |
+
root, seq_name, "pose", all_img_names[idx][:-3] + "txt"
|
| 52 |
+
)
|
| 53 |
+
out_img_path = osp.join(
|
| 54 |
+
saved_dir, datatset_name, seq_name, "color", all_img_names[idx]
|
| 55 |
+
)
|
| 56 |
+
out_depth_path = osp.join(
|
| 57 |
+
saved_dir, datatset_name, seq_name, "depth", all_img_names[idx][:-3] + "png"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
copy_crop_files(
|
| 61 |
+
im_path=im_path,
|
| 62 |
+
depth_path=depth_path,
|
| 63 |
+
out_img_path=out_img_path,
|
| 64 |
+
out_depth_path=out_depth_path,
|
| 65 |
+
dataset=datatset_name,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
origin_img = np.array(Image.open(im_path))
|
| 69 |
+
out_img_origin_path = osp.join(
|
| 70 |
+
saved_dir, datatset_name, seq_name, "color_origin", all_img_names[idx]
|
| 71 |
+
)
|
| 72 |
+
out_pose_path = osp.join(
|
| 73 |
+
saved_dir, datatset_name, seq_name, "pose", all_img_names[idx][:-3] + "txt"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
os.makedirs(osp.dirname(out_img_origin_path), exist_ok=True)
|
| 77 |
+
os.makedirs(osp.dirname(out_pose_path), exist_ok=True)
|
| 78 |
+
|
| 79 |
+
cv2.imwrite(
|
| 80 |
+
out_img_origin_path,
|
| 81 |
+
origin_img,
|
| 82 |
+
)
|
| 83 |
+
shutil.copyfile(pose_path, out_pose_path)
|
| 84 |
+
|
| 85 |
+
intrinsic_path = osp.join(
|
| 86 |
+
root, seq_name, "intrinsic", "intrinsic_depth.txt"
|
| 87 |
+
)
|
| 88 |
+
out_intrinsic_path = osp.join(
|
| 89 |
+
saved_dir, datatset_name, seq_name, "intrinsic", "intrinsic_depth.txt"
|
| 90 |
+
)
|
| 91 |
+
os.makedirs(osp.dirname(out_intrinsic_path), exist_ok=True)
|
| 92 |
+
shutil.copyfile(intrinsic_path, out_intrinsic_path)
|
| 93 |
+
|
| 94 |
+
# 90 frames like DepthCraft
|
| 95 |
+
out_json_path = osp.join(saved_dir, datatset_name, "scannet_video.json")
|
| 96 |
+
gen_json(
|
| 97 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 98 |
+
start_id=0,end_id=90*3,step=3,
|
| 99 |
+
save_path=out_json_path,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
#~500 frames in paper
|
| 103 |
+
out_json_path = osp.join(saved_dir, datatset_name, "scannet_video_500.json")
|
| 104 |
+
gen_json(
|
| 105 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 106 |
+
start_id=0,end_id=500,step=1,
|
| 107 |
+
save_path=out_json_path,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# tae
|
| 111 |
+
out_json_path = osp.join(saved_dir, datatset_name, "scannet_video_tae.json")
|
| 112 |
+
gen_json_scannet_tae(
|
| 113 |
+
root_path=osp.join(saved_dir, datatset_name),
|
| 114 |
+
start_id=0,end_id=192,step=1,
|
| 115 |
+
save_path=out_json_path,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
extract_scannet(
|
| 120 |
+
root="path/to/scannet",
|
| 121 |
+
saved_dir="./benchmark/datasets/",
|
| 122 |
+
sample_len=-1,
|
| 123 |
+
datatset_name="scannet",
|
| 124 |
+
)
|
code_depth/benchmark/dataset_extract/dataset_extract_sintel.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os.path as osp
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import csv
|
| 15 |
+
import imageio
|
| 16 |
+
import cv2
|
| 17 |
+
import json
|
| 18 |
+
import glob
|
| 19 |
+
import shutil
|
| 20 |
+
|
| 21 |
+
from eval_utils import gen_json, get_sorted_files
|
| 22 |
+
|
| 23 |
+
TAG_FLOAT = 202021.25
|
| 24 |
+
TAG_CHAR = "PIEH"
|
| 25 |
+
|
| 26 |
+
def depth_read(filename):
|
| 27 |
+
"""Read depth data from file, return as numpy array."""
|
| 28 |
+
f = open(filename, "rb")
|
| 29 |
+
check = np.fromfile(f, dtype=np.float32, count=1)[0]
|
| 30 |
+
assert (
|
| 31 |
+
check == TAG_FLOAT
|
| 32 |
+
), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
|
| 33 |
+
TAG_FLOAT, check
|
| 34 |
+
)
|
| 35 |
+
width = np.fromfile(f, dtype=np.int32, count=1)[0]
|
| 36 |
+
height = np.fromfile(f, dtype=np.int32, count=1)[0]
|
| 37 |
+
size = width * height
|
| 38 |
+
assert (
|
| 39 |
+
width > 0 and height > 0 and size > 1 and size < 100000000
|
| 40 |
+
), " depth_read:: Wrong input size (width = {0}, height = {1}).".format(
|
| 41 |
+
width, height
|
| 42 |
+
)
|
| 43 |
+
depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
|
| 44 |
+
return depth
|
| 45 |
+
|
| 46 |
+
def extract_sintel(
|
| 47 |
+
root,
|
| 48 |
+
depth_root,
|
| 49 |
+
sample_len=-1,
|
| 50 |
+
datatset_name="",
|
| 51 |
+
saved_dir="",
|
| 52 |
+
):
|
| 53 |
+
scenes_names = os.listdir(root)
|
| 54 |
+
all_samples = []
|
| 55 |
+
for i, seq_name in enumerate(tqdm(scenes_names)):
|
| 56 |
+
all_img_names = get_sorted_files(
|
| 57 |
+
os.path.join(root, seq_name), suffix=".png")
|
| 58 |
+
|
| 59 |
+
seq_len = len(all_img_names)
|
| 60 |
+
step = sample_len if sample_len > 0 else seq_len
|
| 61 |
+
|
| 62 |
+
for ref_idx in range(0, seq_len, step):
|
| 63 |
+
print(f"Progress: {seq_name}, {ref_idx // step} / {seq_len // step}")
|
| 64 |
+
|
| 65 |
+
if (ref_idx + step) <= seq_len:
|
| 66 |
+
ref_e = ref_idx + step
|
| 67 |
+
else:
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
for idx in range(ref_idx, ref_e):
|
| 71 |
+
im_path = osp.join(
|
| 72 |
+
root, seq_name, all_img_names[idx]
|
| 73 |
+
)
|
| 74 |
+
depth_path = osp.join(
|
| 75 |
+
depth_root, seq_name, all_img_names[idx][:-3] + "dpt"
|
| 76 |
+
)
|
| 77 |
+
out_img_path = osp.join(
|
| 78 |
+
saved_dir, datatset_name,'clean', seq_name, all_img_names[idx]
|
| 79 |
+
)
|
| 80 |
+
out_depth_path = osp.join(
|
| 81 |
+
saved_dir, datatset_name,'depth', seq_name, all_img_names[idx][:-3] + "png"
|
| 82 |
+
)
|
| 83 |
+
depth = depth_read(depth_path)
|
| 84 |
+
img = np.array(Image.open(im_path))
|
| 85 |
+
|
| 86 |
+
os.makedirs(osp.dirname(out_img_path), exist_ok=True)
|
| 87 |
+
os.makedirs(osp.dirname(out_depth_path), exist_ok=True)
|
| 88 |
+
|
| 89 |
+
cv2.imwrite(
|
| 90 |
+
out_img_path,
|
| 91 |
+
img,
|
| 92 |
+
)
|
| 93 |
+
cv2.imwrite(
|
| 94 |
+
out_depth_path,
|
| 95 |
+
depth.astype(np.uint16)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
gen_json(
|
| 99 |
+
root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
|
| 100 |
+
start_id=0,end_id=100,step=1,
|
| 101 |
+
save_path=osp.join(saved_dir, datatset_name, "sintel_video.json"),)
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
extract_sintel(
|
| 105 |
+
root="path/to/training/clean",
|
| 106 |
+
depth_root="path/to/depth",
|
| 107 |
+
saved_dir="./benchmark/datasets/",
|
| 108 |
+
sample_len=-1,
|
| 109 |
+
datatset_name="sintel",
|
| 110 |
+
)
|
code_depth/benchmark/dataset_extract/eval_utils.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os.path as osp
|
| 4 |
+
import json
|
| 5 |
+
import glob
|
| 6 |
+
import cv2
|
| 7 |
+
import shutil
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from natsort import natsorted
|
| 10 |
+
|
| 11 |
+
def even_or_odd(num):
|
| 12 |
+
if num % 2 == 0:
|
| 13 |
+
return num
|
| 14 |
+
else:
|
| 15 |
+
return num - 1
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def gen_json(root_path, dataset, start_id, end_id, step, save_path=None):
|
| 19 |
+
rgb_name = "rgb"
|
| 20 |
+
if dataset == "kitti":
|
| 21 |
+
factor = 256.0
|
| 22 |
+
elif dataset == "nyuv2":
|
| 23 |
+
factor = 6000.0
|
| 24 |
+
elif dataset == "bonn":
|
| 25 |
+
factor = 5000.0
|
| 26 |
+
elif dataset == 'sintel':
|
| 27 |
+
factor = 65535 / 650
|
| 28 |
+
rgb_name = "clean"
|
| 29 |
+
elif dataset == 'scannet':
|
| 30 |
+
factor = 1000.0
|
| 31 |
+
rgb_name = "color"
|
| 32 |
+
else:
|
| 33 |
+
raise NotImplementedError
|
| 34 |
+
|
| 35 |
+
data = {}
|
| 36 |
+
data[dataset] = []
|
| 37 |
+
pieces = glob.glob(osp.join(root_path, "*"))
|
| 38 |
+
count = 0
|
| 39 |
+
for piece in pieces:
|
| 40 |
+
if not osp.isdir(piece):
|
| 41 |
+
continue
|
| 42 |
+
name = piece.split('/')[-1]
|
| 43 |
+
name_dict = {name:[]}
|
| 44 |
+
images = glob.glob(osp.join(piece, rgb_name, "*.png")) + glob.glob(osp.join(piece, rgb_name, "*.jpg"))
|
| 45 |
+
images = natsorted(images)
|
| 46 |
+
depths = glob.glob(osp.join(piece, "depth/*.png"))
|
| 47 |
+
depths = natsorted(depths)
|
| 48 |
+
images = images[start_id:end_id:step]
|
| 49 |
+
depths = depths[start_id:end_id:step]
|
| 50 |
+
|
| 51 |
+
for i in range(len(images)):
|
| 52 |
+
image = images[i]
|
| 53 |
+
xx = image[len(root_path)+1:]
|
| 54 |
+
depth = depths[i][len(root_path)+1:]
|
| 55 |
+
tmp = {}
|
| 56 |
+
tmp["image"] = xx
|
| 57 |
+
tmp["gt_depth"] = depth
|
| 58 |
+
tmp["factor"] = factor
|
| 59 |
+
name_dict[name].append(tmp)
|
| 60 |
+
data[dataset].append(name_dict)
|
| 61 |
+
with open(save_path, "w") as f:
|
| 62 |
+
json.dump(data, f, indent= 4)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def gen_json_scannet_tae(root_path, start_id, end_id, step, save_path=None):
|
| 66 |
+
data = {}
|
| 67 |
+
data["scannet"] = []
|
| 68 |
+
pieces = glob.glob(osp.join(root_path, "*"))
|
| 69 |
+
|
| 70 |
+
color = 'color_origin'
|
| 71 |
+
|
| 72 |
+
for piece in pieces:
|
| 73 |
+
if not osp.isdir(piece):
|
| 74 |
+
continue
|
| 75 |
+
name = piece.split('/')[-1]
|
| 76 |
+
name_dict = {name:[]}
|
| 77 |
+
images = glob.glob(osp.join(piece,color, "*.jpg"))
|
| 78 |
+
images = natsorted(images)
|
| 79 |
+
depths = glob.glob(osp.join(piece, "depth/*.png"))
|
| 80 |
+
depths = natsorted(depths)
|
| 81 |
+
images = images[start_id:end_id:step]
|
| 82 |
+
depths = depths[start_id:end_id:step]
|
| 83 |
+
print(f"sequence frame number: {piece}")
|
| 84 |
+
count = 0
|
| 85 |
+
for i in range(len(images)):
|
| 86 |
+
image = images[i]
|
| 87 |
+
xx = image[len(root_path)+1:]
|
| 88 |
+
depth = depths[i][len(root_path)+1:]
|
| 89 |
+
|
| 90 |
+
base_path = osp.dirname(image)
|
| 91 |
+
base_path = base_path.replace(color, 'intrinsic')
|
| 92 |
+
K = np.loadtxt(base_path + '/intrinsic_depth.txt')
|
| 93 |
+
|
| 94 |
+
pose_path = image.replace(color, 'pose').replace('.jpg', '.txt')
|
| 95 |
+
pose = np.loadtxt(pose_path)
|
| 96 |
+
|
| 97 |
+
tmp = {}
|
| 98 |
+
tmp["image"] = xx
|
| 99 |
+
tmp["gt_depth"] = depth
|
| 100 |
+
tmp["factor"] = 1000.0
|
| 101 |
+
tmp["K"] = K.tolist()
|
| 102 |
+
tmp["pose"] = pose.tolist()
|
| 103 |
+
name_dict[name].append(tmp)
|
| 104 |
+
data["scannet"].append(name_dict)
|
| 105 |
+
|
| 106 |
+
with open(save_path, "w") as f:
|
| 107 |
+
json.dump(data, f, indent= 4)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_sorted_files(root_path, suffix):
|
| 111 |
+
all_img_names = os.listdir(root_path)
|
| 112 |
+
all_img_names = [x for x in all_img_names if x.endswith(suffix)]
|
| 113 |
+
print(f"sequence frame number: {len(all_img_names)}")
|
| 114 |
+
|
| 115 |
+
all_img_names.sort()
|
| 116 |
+
all_img_names = sorted(all_img_names, key=lambda x: int(x.split(".")[0][-4:]))
|
| 117 |
+
|
| 118 |
+
return all_img_names
|
| 119 |
+
|
| 120 |
+
def copy_crop_files(im_path, depth_path, out_img_path, out_depth_path, dataset):
|
| 121 |
+
img = np.array(Image.open(im_path))
|
| 122 |
+
|
| 123 |
+
if dataset == "kitti" or dataset == "bonn":
|
| 124 |
+
height, width = img.shape[:2]
|
| 125 |
+
height = even_or_odd(height)
|
| 126 |
+
width = even_or_odd(width)
|
| 127 |
+
img = img[:height, :width]
|
| 128 |
+
elif dataset == "nyuv2":
|
| 129 |
+
img = img[45:471, 41:601, :]
|
| 130 |
+
elif dataset == "scannet":
|
| 131 |
+
img = img[8:-8, 11:-11, :]
|
| 132 |
+
|
| 133 |
+
os.makedirs(osp.dirname(out_img_path), exist_ok=True)
|
| 134 |
+
os.makedirs(osp.dirname(out_depth_path), exist_ok=True)
|
| 135 |
+
cv2.imwrite(
|
| 136 |
+
out_img_path,
|
| 137 |
+
img,
|
| 138 |
+
)
|
| 139 |
+
shutil.copyfile(depth_path, out_depth_path)
|
| 140 |
+
|
code_depth/benchmark/eval/eval.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
from scipy.ndimage import map_coordinates
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import os
|
| 11 |
+
import gc
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from metric import *
|
| 15 |
+
import metric
|
| 16 |
+
|
| 17 |
+
device = 'cuda'
|
| 18 |
+
eval_metrics = [
|
| 19 |
+
"abs_relative_difference",
|
| 20 |
+
"rmse_linear",
|
| 21 |
+
"delta1_acc",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
def get_infer(infer_path,args, target_size = None):
|
| 25 |
+
if infer_path.split('.')[-1] == 'npy':
|
| 26 |
+
img_gray = np.load(infer_path)
|
| 27 |
+
img_gray = img_gray.astype(np.float32)
|
| 28 |
+
infer_factor = 1.0
|
| 29 |
+
else:
|
| 30 |
+
img = cv2.imread(infer_path)
|
| 31 |
+
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 32 |
+
img_gray = img_gray.astype(np.float32)
|
| 33 |
+
infer_factor = 1.0 / 255.0
|
| 34 |
+
|
| 35 |
+
infer = img_gray / infer_factor
|
| 36 |
+
|
| 37 |
+
if target_size is not None:
|
| 38 |
+
if infer.shape[0] != target_size[0] or infer.shape[1] != target_size[1]:
|
| 39 |
+
infer = cv2.resize(infer, (target_size[1], target_size[0]))
|
| 40 |
+
return infer
|
| 41 |
+
|
| 42 |
+
def get_gt(depth_gt_path, gt_factor, args):
|
| 43 |
+
if depth_gt_path.split('.')[-1] == 'npy':
|
| 44 |
+
depth_gt = np.load(depth_gt_path)
|
| 45 |
+
else:
|
| 46 |
+
depth_gt = cv2.imread(depth_gt_path, -1)
|
| 47 |
+
depth_gt = np.array(depth_gt)
|
| 48 |
+
depth_gt = depth_gt / gt_factor
|
| 49 |
+
depth_gt[depth_gt==0] = -1
|
| 50 |
+
return depth_gt
|
| 51 |
+
|
| 52 |
+
def get_flow(flow_path):
|
| 53 |
+
assert os.path.exists(flow_path)
|
| 54 |
+
flow = np.load(flow_path, allow_pickle=True)
|
| 55 |
+
return flow
|
| 56 |
+
def depth2disparity(depth, return_mask=False):
|
| 57 |
+
if isinstance(depth, np.ndarray):
|
| 58 |
+
disparity = np.zeros_like(depth)
|
| 59 |
+
non_negtive_mask = depth > 0
|
| 60 |
+
disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
|
| 61 |
+
if return_mask:
|
| 62 |
+
return disparity, non_negtive_mask
|
| 63 |
+
else:
|
| 64 |
+
return disparity
|
| 65 |
+
|
| 66 |
+
def eval_depthcrafter(infer_paths, depth_gt_paths, factors, args):
|
| 67 |
+
depth_errors = []
|
| 68 |
+
gts = []
|
| 69 |
+
infs = []
|
| 70 |
+
seq_length = args.max_eval_len
|
| 71 |
+
dataset_max_depth = args.max_depth_eval
|
| 72 |
+
for i in range(len(infer_paths)):
|
| 73 |
+
if not os.path.exists(infer_paths[i]):
|
| 74 |
+
continue
|
| 75 |
+
depth_gt = get_gt(depth_gt_paths[i], factors[i], args)
|
| 76 |
+
depth_gt = depth_gt[args.a:args.b, args.c:args.d]
|
| 77 |
+
|
| 78 |
+
infer = get_infer(infer_paths[i], args, target_size=depth_gt.shape)
|
| 79 |
+
gts.append(depth_gt)
|
| 80 |
+
infs.append(infer)
|
| 81 |
+
gts = np.stack(gts, axis=0)
|
| 82 |
+
|
| 83 |
+
infs = np.stack(infs, axis=0)
|
| 84 |
+
infs = infs[:seq_length]
|
| 85 |
+
gts = gts[:seq_length]
|
| 86 |
+
valid_mask = np.logical_and((gts>1e-3), (gts<dataset_max_depth))
|
| 87 |
+
|
| 88 |
+
gt_disp_masked = 1. / (gts[valid_mask].reshape((-1,1)).astype(np.float64) + 1e-8)
|
| 89 |
+
infs = np.clip(infs, a_min=1e-3, a_max=None)
|
| 90 |
+
pred_disp_masked = infs[valid_mask].reshape((-1,1)).astype(np.float64)
|
| 91 |
+
|
| 92 |
+
_ones = np.ones_like(pred_disp_masked)
|
| 93 |
+
A = np.concatenate([pred_disp_masked, _ones], axis=-1)
|
| 94 |
+
X = np.linalg.lstsq(A, gt_disp_masked, rcond=None)[0]
|
| 95 |
+
scale, shift = X
|
| 96 |
+
aligned_pred = scale * infs + shift
|
| 97 |
+
aligned_pred = np.clip(aligned_pred, a_min=1e-3, a_max=None)
|
| 98 |
+
|
| 99 |
+
pred_depth = depth2disparity(aligned_pred)
|
| 100 |
+
gt_depth = gts
|
| 101 |
+
pred_depth = np.clip(
|
| 102 |
+
pred_depth, a_min=1e-3, a_max=dataset_max_depth
|
| 103 |
+
)
|
| 104 |
+
sample_metric = []
|
| 105 |
+
metric_funcs = [getattr(metric, _met) for _met in eval_metrics]
|
| 106 |
+
|
| 107 |
+
pred_depth_ts = torch.from_numpy(pred_depth).to(device)
|
| 108 |
+
gt_depth_ts = torch.from_numpy(gt_depth).to(device)
|
| 109 |
+
valid_mask_ts = torch.from_numpy(valid_mask).to(device)
|
| 110 |
+
|
| 111 |
+
n = valid_mask.sum((-1, -2))
|
| 112 |
+
valid_frame = (n > 0)
|
| 113 |
+
pred_depth_ts = pred_depth_ts[valid_frame]
|
| 114 |
+
gt_depth_ts = gt_depth_ts[valid_frame]
|
| 115 |
+
valid_mask_ts = valid_mask_ts[valid_frame]
|
| 116 |
+
|
| 117 |
+
for met_func in metric_funcs:
|
| 118 |
+
_metric_name = met_func.__name__
|
| 119 |
+
_metric = met_func(pred_depth_ts, gt_depth_ts, valid_mask_ts).item()
|
| 120 |
+
sample_metric.append(_metric)
|
| 121 |
+
return sample_metric
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main():
|
| 125 |
+
|
| 126 |
+
parser = argparse.ArgumentParser()
|
| 127 |
+
parser.add_argument('--infer_path', type=str, default='')
|
| 128 |
+
parser.add_argument('--infer_type', type=str, default='npy')
|
| 129 |
+
parser.add_argument('--benchmark_path', type=str, default='')
|
| 130 |
+
parser.add_argument('--datasets', type=str, nargs='+', default=['vkitti', 'kitti', 'sintel', 'nyu_v2', 'tartanair', 'bonn', 'ip_lidar'])
|
| 131 |
+
|
| 132 |
+
args = parser.parse_args()
|
| 133 |
+
|
| 134 |
+
results_save_path = os.path.join(args.infer_path, 'results.txt')
|
| 135 |
+
|
| 136 |
+
for dataset in args.datasets:
|
| 137 |
+
|
| 138 |
+
file = open(results_save_path, 'a')
|
| 139 |
+
|
| 140 |
+
if dataset == 'kitti':
|
| 141 |
+
args.json_file = os.path.join(args.benchmark_path,'kitti/kitti_video.json')
|
| 142 |
+
args.root_path = os.path.join(args.benchmark_path,'kitti')
|
| 143 |
+
args.max_depth_eval = 80.0
|
| 144 |
+
args.min_depth_eval = 0.1
|
| 145 |
+
args.max_eval_len = 110
|
| 146 |
+
args.a = 0
|
| 147 |
+
args.b = 374
|
| 148 |
+
args.c = 0
|
| 149 |
+
args.d = 1242
|
| 150 |
+
if dataset == 'kitti_500':
|
| 151 |
+
dataset = 'kitti'
|
| 152 |
+
args.json_file = os.path.join(args.benchmark_path,'kitti/kitti_video_500.json')
|
| 153 |
+
args.root_path = os.path.join(args.benchmark_path,'kitti')
|
| 154 |
+
args.max_depth_eval = 80.0
|
| 155 |
+
args.min_depth_eval = 0.1
|
| 156 |
+
args.max_eval_len = 500
|
| 157 |
+
args.a = 0
|
| 158 |
+
args.b = 374
|
| 159 |
+
args.c = 0
|
| 160 |
+
args.d = 1242
|
| 161 |
+
elif dataset == 'sintel':
|
| 162 |
+
args.json_file = os.path.join(args.benchmark_path,'sintel/sintel_video.json')
|
| 163 |
+
args.root_path = os.path.join(args.benchmark_path,'sintel')
|
| 164 |
+
args.max_depth_eval = 70
|
| 165 |
+
args.min_depth_eval = 0.1
|
| 166 |
+
args.max_eval_len = 100
|
| 167 |
+
args.a = 0
|
| 168 |
+
args.b = 436
|
| 169 |
+
args.c = 0
|
| 170 |
+
args.d = 1024
|
| 171 |
+
elif dataset == 'nyuv2_500':
|
| 172 |
+
dataset = 'nyuv2'
|
| 173 |
+
args.json_file = os.path.join(args.benchmark_path,'nyuv2/nyuv2_video_500.json')
|
| 174 |
+
args.root_path = os.path.join(args.benchmark_path,'nyuv2')
|
| 175 |
+
args.max_depth_eval = 10.0
|
| 176 |
+
args.min_depth_eval = 0.1
|
| 177 |
+
args.max_eval_len = 500
|
| 178 |
+
args.a = 45
|
| 179 |
+
args.b = 471
|
| 180 |
+
args.c = 41
|
| 181 |
+
args.d = 601
|
| 182 |
+
elif dataset == 'bonn':
|
| 183 |
+
args.json_file = os.path.join(args.benchmark_path,'bonn/bonn_video.json')
|
| 184 |
+
args.root_path = os.path.join(args.benchmark_path,'bonn')
|
| 185 |
+
args.max_depth_eval = 10.0
|
| 186 |
+
args.min_depth_eval = 0.1
|
| 187 |
+
args.max_eval_len = 110
|
| 188 |
+
args.a = 0
|
| 189 |
+
args.b = 480
|
| 190 |
+
args.c = 0
|
| 191 |
+
args.d = 640
|
| 192 |
+
elif dataset == 'bonn_500':
|
| 193 |
+
dataset = 'bonn'
|
| 194 |
+
args.json_file = os.path.join(args.benchmark_path,'bonn/bonn_video_500.json')
|
| 195 |
+
args.root_path = os.path.join(args.benchmark_path,'bonn')
|
| 196 |
+
args.max_depth_eval = 10.0
|
| 197 |
+
args.min_depth_eval = 0.1
|
| 198 |
+
args.max_eval_len = 500
|
| 199 |
+
args.a = 0
|
| 200 |
+
args.b = 480
|
| 201 |
+
args.c = 0
|
| 202 |
+
args.d = 640
|
| 203 |
+
elif dataset == 'scannet':
|
| 204 |
+
args.json_file = os.path.join(args.benchmark_path,'scannet/scannet_video.json')
|
| 205 |
+
args.root_path = os.path.join(args.benchmark_path,'scannet')
|
| 206 |
+
args.max_depth_eval = 10.0
|
| 207 |
+
args.min_depth_eval = 0.1
|
| 208 |
+
args.max_eval_len = 90
|
| 209 |
+
args.a = 8
|
| 210 |
+
args.b = -8
|
| 211 |
+
args.c = 11
|
| 212 |
+
args.d = -11
|
| 213 |
+
elif dataset == 'scannet_500':
|
| 214 |
+
dataset = 'scannet'
|
| 215 |
+
args.json_file = os.path.join(args.benchmark_path,'scannet/scannet_video_500.json')
|
| 216 |
+
args.root_path = os.path.join(args.benchmark_path,'scannet')
|
| 217 |
+
args.max_depth_eval = 10.0
|
| 218 |
+
args.min_depth_eval = 0.1
|
| 219 |
+
args.max_eval_len = 500
|
| 220 |
+
args.a = 8
|
| 221 |
+
args.b = -8
|
| 222 |
+
args.c = 11
|
| 223 |
+
args.d = -11
|
| 224 |
+
|
| 225 |
+
with open(args.json_file, 'r') as fs:
|
| 226 |
+
path_json = json.load(fs)
|
| 227 |
+
|
| 228 |
+
json_data = path_json[dataset]
|
| 229 |
+
scale_stds = shift_stds = stable_result_fulls = stable_result_wins = 0
|
| 230 |
+
depth_result_fulls = np.zeros(5)
|
| 231 |
+
depth_result_wins = np.zeros(5)
|
| 232 |
+
depth_result_onlys = np.zeros(5)
|
| 233 |
+
count = 0
|
| 234 |
+
line = '-' * 50
|
| 235 |
+
print(f'<{line} {dataset} start {line}>')
|
| 236 |
+
file.write(f'<{line} {dataset} start {line}>\n')
|
| 237 |
+
results_all = []
|
| 238 |
+
for data in tqdm(json_data):
|
| 239 |
+
for key in data.keys():
|
| 240 |
+
value = data[key]
|
| 241 |
+
infer_paths = []
|
| 242 |
+
depth_gt_paths = []
|
| 243 |
+
flow_paths = []
|
| 244 |
+
factors = []
|
| 245 |
+
for images in value:
|
| 246 |
+
infer_path = (args.infer_path + '/'+ dataset + '/' + images['image']).replace('.jpg', '.npy').replace('.png', '.npy')
|
| 247 |
+
|
| 248 |
+
infer_paths.append(infer_path)
|
| 249 |
+
depth_gt_paths.append(args.root_path + '/' + images['gt_depth'])
|
| 250 |
+
factors.append(images['factor'])
|
| 251 |
+
infer_paths = infer_paths[:args.max_eval_len]
|
| 252 |
+
depth_gt_paths = depth_gt_paths[:args.max_eval_len]
|
| 253 |
+
factors = factors[:args.max_eval_len]
|
| 254 |
+
results_single = eval_depthcrafter(infer_paths, depth_gt_paths, factors, args)
|
| 255 |
+
results_all.append(results_single)
|
| 256 |
+
final_results = np.array(results_all)
|
| 257 |
+
final_results_mean = np.mean(final_results, axis=0)
|
| 258 |
+
result_dict = { 'name': dataset }
|
| 259 |
+
for i, metric in enumerate(eval_metrics):
|
| 260 |
+
result_dict[metric] = final_results_mean[i]
|
| 261 |
+
print(f"{metric}: {final_results_mean[i]:04f}")
|
| 262 |
+
file.write(f"{metric}: {final_results_mean[i]:04f}\n")
|
| 263 |
+
file.write(f'<{line} {dataset} finish {line}>\n')
|
| 264 |
+
if __name__ == '__main__':
|
| 265 |
+
main()
|
code_depth/benchmark/eval/eval.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
set -x
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
pred_disp_root=$1 # The parent directory that contaning [sintel, scannet, KITTI, bonn, NYUv2] prediction
|
| 6 |
+
benchmark_root=$2 # The parent directory that contaning [sintel, scannet, KITTI, bonn, NYUv2] ground truth
|
| 7 |
+
|
| 8 |
+
#eval sintel
|
| 9 |
+
python3 benchmark/eval/eval.py \
|
| 10 |
+
--infer_path $pred_disp_root \
|
| 11 |
+
--benchmark_path $benchmark_root \
|
| 12 |
+
--datasets sintel
|
| 13 |
+
|
| 14 |
+
#eval scannet
|
| 15 |
+
python3 benchmark/eval/eval.py \
|
| 16 |
+
--infer_path $pred_disp_root \
|
| 17 |
+
--benchmark_path $benchmark_root \
|
| 18 |
+
--datasets scannet
|
| 19 |
+
|
| 20 |
+
#eval kitti
|
| 21 |
+
python3 benchmark/eval/eval.py \
|
| 22 |
+
--infer_path $pred_disp_root \
|
| 23 |
+
--benchmark_path $benchmark_root \
|
| 24 |
+
--datasets kitti
|
| 25 |
+
|
| 26 |
+
#eval bonn
|
| 27 |
+
python3 benchmark/eval/eval.py \
|
| 28 |
+
--infer_path $pred_disp_root \
|
| 29 |
+
--benchmark_path $benchmark_root \
|
| 30 |
+
--datasets bonn
|
code_depth/benchmark/eval/eval_500.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
set -x
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
pred_disp_root=$1 # The parent directory that contaning [sintel, scannet, KITTI, bonn, NYUv2] prediction
|
| 6 |
+
benchmark_root=$2 # The parent directory that contaning [sintel, scannet, KITTI, bonn, NYUv2] ground truth
|
| 7 |
+
|
| 8 |
+
#eval scannet
|
| 9 |
+
python3 benchmark/eval/eval.py \
|
| 10 |
+
--infer_path $pred_disp_root \
|
| 11 |
+
--benchmark_path $benchmark_root \
|
| 12 |
+
--datasets scannet_500
|
| 13 |
+
|
| 14 |
+
#eval kitti
|
| 15 |
+
python3 benchmark/eval/eval.py \
|
| 16 |
+
--infer_path $pred_disp_root \
|
| 17 |
+
--benchmark_path $benchmark_root \
|
| 18 |
+
--datasets kitti_500
|
| 19 |
+
|
| 20 |
+
#eval bonn
|
| 21 |
+
python3 benchmark/eval/eval.py \
|
| 22 |
+
--infer_path $pred_disp_root \
|
| 23 |
+
--benchmark_path $benchmark_root \
|
| 24 |
+
--datasets bonn_500
|
| 25 |
+
|
| 26 |
+
#eval nyu
|
| 27 |
+
python3 benchmark/eval/eval.py \
|
| 28 |
+
--infer_path $pred_disp_root \
|
| 29 |
+
--benchmark_path $benchmark_root \
|
| 30 |
+
--datasets nyuv2_500
|
code_depth/benchmark/eval/eval_tae.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import json
|
| 5 |
+
import argparse
|
| 6 |
+
from scipy.ndimage import map_coordinates
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import os
|
| 9 |
+
import gc
|
| 10 |
+
import time
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 14 |
+
|
| 15 |
+
def compute_errors_torch(gt, pred):
|
| 16 |
+
abs_rel = torch.mean(torch.abs(gt - pred) / gt)
|
| 17 |
+
return abs_rel
|
| 18 |
+
|
| 19 |
+
def get_infer(infer_path,args, target_size = None):
|
| 20 |
+
if infer_path.split('.')[-1] == 'npy':
|
| 21 |
+
img_gray = np.load(infer_path)
|
| 22 |
+
img_gray = img_gray.astype(np.float32)
|
| 23 |
+
infer_factor = 1.0
|
| 24 |
+
else:
|
| 25 |
+
img = cv2.imread(infer_path)
|
| 26 |
+
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 27 |
+
img_gray = img_gray.astype(np.float32)
|
| 28 |
+
infer_factor = 1.0 / 255.0
|
| 29 |
+
|
| 30 |
+
infer = img_gray / infer_factor
|
| 31 |
+
if args.hard_crop:
|
| 32 |
+
infer = infer[args.a:args.b, args.c:args.d]
|
| 33 |
+
|
| 34 |
+
if target_size is not None:
|
| 35 |
+
if infer.shape[0] != target_size[0] or infer.shape[1] != target_size[1]:
|
| 36 |
+
infer = cv2.resize(infer, (target_size[1], target_size[0]))
|
| 37 |
+
return infer
|
| 38 |
+
|
| 39 |
+
def get_gt(depth_gt_path, gt_factor, args):
|
| 40 |
+
if depth_gt_path.split('.')[-1] == 'npy':
|
| 41 |
+
depth_gt = np.load(depth_gt_path)
|
| 42 |
+
else:
|
| 43 |
+
depth_gt = cv2.imread(depth_gt_path, -1)
|
| 44 |
+
depth_gt = np.array(depth_gt)
|
| 45 |
+
depth_gt = depth_gt / gt_factor
|
| 46 |
+
|
| 47 |
+
depth_gt[depth_gt==0] = 0
|
| 48 |
+
return depth_gt
|
| 49 |
+
|
| 50 |
+
def depth2disparity(depth, return_mask=False):
|
| 51 |
+
if isinstance(depth, np.ndarray):
|
| 52 |
+
disparity = np.zeros_like(depth)
|
| 53 |
+
non_negtive_mask = depth > 0
|
| 54 |
+
disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
|
| 55 |
+
if return_mask:
|
| 56 |
+
return disparity, non_negtive_mask
|
| 57 |
+
else:
|
| 58 |
+
return disparity
|
| 59 |
+
|
| 60 |
+
def tae_torch(depth1, depth2, R_2_1, T_2_1, K, mask):
|
| 61 |
+
H, W = depth1.shape
|
| 62 |
+
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
|
| 63 |
+
|
| 64 |
+
# Generate meshgrid
|
| 65 |
+
xx, yy = torch.meshgrid(torch.arange(W), torch.arange(H))
|
| 66 |
+
xx, yy = xx.t(), yy.t() # Transpose to match the shape (H, W)
|
| 67 |
+
|
| 68 |
+
# Convert meshgrid to tensor
|
| 69 |
+
xx = xx.to(dtype=depth1.dtype, device=depth1.device)
|
| 70 |
+
yy = yy.to(dtype=depth1.dtype, device=depth1.device)
|
| 71 |
+
# Calculate 3D points in frame 1
|
| 72 |
+
X = (xx - cx) * depth1 / fx
|
| 73 |
+
Y = (yy - cy) * depth1 / fy
|
| 74 |
+
Z = depth1
|
| 75 |
+
points3d = torch.stack((X.flatten(), Y.flatten(), Z.flatten()), dim=1) # Shape (H*W, 3)
|
| 76 |
+
T = torch.tensor(T_2_1, dtype=depth1.dtype, device=depth1.device)
|
| 77 |
+
|
| 78 |
+
# Transform 3D points to frame 2
|
| 79 |
+
points3d_transformed = torch.matmul(points3d, R_2_1.T) + T
|
| 80 |
+
X_world, Y_world, Z_world = points3d_transformed[:, 0], points3d_transformed[:, 1], points3d_transformed[:, 2]
|
| 81 |
+
# Project 3D points to 2D plane using intrinsic matrix
|
| 82 |
+
X_plane = (X_world * fx) / Z_world + cx
|
| 83 |
+
Y_plane = (Y_world * fy) / Z_world + cy
|
| 84 |
+
|
| 85 |
+
# Round and convert to integers
|
| 86 |
+
X_plane = torch.round(X_plane).to(dtype=torch.long)
|
| 87 |
+
Y_plane = torch.round(Y_plane).to(dtype=torch.long)
|
| 88 |
+
|
| 89 |
+
# Filter valid indices
|
| 90 |
+
valid_mask = (X_plane >= 0) & (X_plane < W) & (Y_plane >= 0) & (Y_plane < H)
|
| 91 |
+
if valid_mask.sum() == 0:
|
| 92 |
+
return 0
|
| 93 |
+
|
| 94 |
+
depth_proj = torch.zeros((H, W), dtype=depth1.dtype, device=depth1.device)
|
| 95 |
+
|
| 96 |
+
valid_X = X_plane[valid_mask]
|
| 97 |
+
valid_Y = Y_plane[valid_mask]
|
| 98 |
+
valid_Z = Z_world[valid_mask]
|
| 99 |
+
|
| 100 |
+
depth_proj[valid_Y, valid_X] = valid_Z
|
| 101 |
+
|
| 102 |
+
valid_mask = (depth_proj > 0) & (depth2 > 0) & (mask)
|
| 103 |
+
if valid_mask.sum() == 0:
|
| 104 |
+
return 0
|
| 105 |
+
abs_errors = compute_errors_torch(depth2[valid_mask], depth_proj[valid_mask])
|
| 106 |
+
|
| 107 |
+
return abs_errors
|
| 108 |
+
|
| 109 |
+
def eval_TAE(infer_paths, depth_gt_paths, factors, masks, Ks, poses, args):
|
| 110 |
+
gts = []
|
| 111 |
+
infs = []
|
| 112 |
+
dataset_max_depth = args.max_depth_eval
|
| 113 |
+
gt_paths_cur = []
|
| 114 |
+
Ks_cur = []
|
| 115 |
+
poses_cur = []
|
| 116 |
+
masks_cur = []
|
| 117 |
+
|
| 118 |
+
for i in range(len(infer_paths)):
|
| 119 |
+
# DAV missing some frames
|
| 120 |
+
if not os.path.exists(infer_paths[i]):
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
depth_gt = get_gt(depth_gt_paths[i], factors[i], args)
|
| 124 |
+
depth_gt = depth_gt[args.a:args.b, args.c:args.d]
|
| 125 |
+
|
| 126 |
+
gt_paths_cur.append(depth_gt_paths[i])
|
| 127 |
+
infer = get_infer(infer_paths[i], args, target_size=depth_gt.shape)
|
| 128 |
+
|
| 129 |
+
gts.append(depth_gt)
|
| 130 |
+
infs.append(infer)
|
| 131 |
+
Ks_cur.append(Ks[i])
|
| 132 |
+
poses_cur.append(poses[i])
|
| 133 |
+
if args.mask:
|
| 134 |
+
masks_cur.append(masks[i])
|
| 135 |
+
|
| 136 |
+
gts = np.stack(gts, axis=0)
|
| 137 |
+
infs = np.stack(infs, axis=0)
|
| 138 |
+
|
| 139 |
+
valid_mask = np.logical_and((gts>1e-3), (gts<dataset_max_depth))
|
| 140 |
+
|
| 141 |
+
gt_disp_masked = 1. / (gts[valid_mask].reshape((-1,1)).astype(np.float64) + 1e-8)
|
| 142 |
+
infs = np.clip(infs, a_min=1e-3, a_max=None)
|
| 143 |
+
pred_disp_masked = infs[valid_mask].reshape((-1,1)).astype(np.float64)
|
| 144 |
+
|
| 145 |
+
_ones = np.ones_like(pred_disp_masked)
|
| 146 |
+
A = np.concatenate([pred_disp_masked, _ones], axis=-1)
|
| 147 |
+
X = np.linalg.lstsq(A, gt_disp_masked, rcond=None)[0]
|
| 148 |
+
scale, shift = X
|
| 149 |
+
|
| 150 |
+
aligned_pred = scale * infs + shift
|
| 151 |
+
aligned_pred = np.clip(aligned_pred, a_min=1e-3, a_max=None)
|
| 152 |
+
|
| 153 |
+
pred_depth = depth2disparity(aligned_pred)
|
| 154 |
+
gt_depth = gts
|
| 155 |
+
pred_depth = np.clip(
|
| 156 |
+
pred_depth, a_min=1e-3, a_max=dataset_max_depth
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
error_sum = 0.
|
| 160 |
+
for i in range(len(gt_paths_cur) -1):
|
| 161 |
+
depth1 = pred_depth[i]
|
| 162 |
+
depth2 = pred_depth[i+1]
|
| 163 |
+
|
| 164 |
+
gt_depth1 = gt_paths_cur[i]
|
| 165 |
+
gt_depth2 = gt_paths_cur[i+1]
|
| 166 |
+
T_1 = poses_cur[i]
|
| 167 |
+
T_2 = poses_cur[i+1]
|
| 168 |
+
|
| 169 |
+
T_2_1 = np.linalg.inv(T_2) @ T_1
|
| 170 |
+
|
| 171 |
+
R_2_1 = T_2_1[:3,:3]
|
| 172 |
+
t_2_1 = T_2_1[:3, 3]
|
| 173 |
+
K = Ks_cur[i]
|
| 174 |
+
|
| 175 |
+
if args.mask:
|
| 176 |
+
mask_path1 = masks_cur[i]
|
| 177 |
+
mask_path2 = masks_cur[i+1]
|
| 178 |
+
mask1 = cv2.imread(mask_path1, -1)
|
| 179 |
+
mask2 = cv2.imread(mask_path2, -1)
|
| 180 |
+
mask1 = mask1[args.a:args.b, args.c:args.d]
|
| 181 |
+
if mask2 is None:
|
| 182 |
+
mask2 = np.ones_like(mask1)
|
| 183 |
+
else:
|
| 184 |
+
mask2 = mask2[args.a:args.b, args.c:args.d]
|
| 185 |
+
|
| 186 |
+
mask1 = mask1 > 0
|
| 187 |
+
mask2 = mask2 > 0
|
| 188 |
+
else:
|
| 189 |
+
mask1 = np.ones_like(depth1)
|
| 190 |
+
mask2 = np.ones_like(depth2)
|
| 191 |
+
|
| 192 |
+
mask1 = mask1 > 0
|
| 193 |
+
mask2 = mask2 > 0
|
| 194 |
+
|
| 195 |
+
depth1 = torch.from_numpy(depth1).to(device=device)
|
| 196 |
+
depth2 = torch.from_numpy(depth2).to(device=device)
|
| 197 |
+
R_2_1 = torch.from_numpy(R_2_1).to(device=device)
|
| 198 |
+
t_2_1 = torch.from_numpy(t_2_1).to(device=device)
|
| 199 |
+
mask1 = torch.from_numpy(mask1).to(device=device)
|
| 200 |
+
mask2 = torch.from_numpy(mask2).to(device=device)
|
| 201 |
+
|
| 202 |
+
error1 = tae_torch(depth1, depth2, R_2_1, t_2_1, K, mask2)
|
| 203 |
+
T_1_2 = np.linalg.inv(T_2_1)
|
| 204 |
+
R_1_2 = T_1_2[:3,:3]
|
| 205 |
+
t_1_2 = T_1_2[:3, 3]
|
| 206 |
+
|
| 207 |
+
R_1_2 = torch.from_numpy(R_1_2).to(device=device)
|
| 208 |
+
t_1_2 = torch.from_numpy(t_1_2).to(device=device)
|
| 209 |
+
|
| 210 |
+
error2 = tae_torch(depth2, depth1, R_1_2, t_1_2, K, mask1)
|
| 211 |
+
|
| 212 |
+
error_sum += error1
|
| 213 |
+
error_sum += error2
|
| 214 |
+
|
| 215 |
+
gc.collect()
|
| 216 |
+
result = error_sum / (2 * (len(gt_paths_cur) -1))
|
| 217 |
+
return result*100
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == '__main__':
|
| 221 |
+
parser = argparse.ArgumentParser()
|
| 222 |
+
parser.add_argument('--infer_path', type=str, default='')
|
| 223 |
+
parser.add_argument('--benchmark_path', type=str, default='')
|
| 224 |
+
|
| 225 |
+
parser.add_argument('--datasets', type=str, nargs='+', default=['scannet', 'sintel'])
|
| 226 |
+
parser.add_argument('--start_idx', type=int, default=0)
|
| 227 |
+
parser.add_argument('--end_idx', type=int, default=180)
|
| 228 |
+
parser.add_argument('--eval_scenes_num', type=int, default=20)
|
| 229 |
+
parser.add_argument('--hard_crop', action='store_true', default=False)
|
| 230 |
+
|
| 231 |
+
args = parser.parse_args()
|
| 232 |
+
|
| 233 |
+
results_save_path = os.path.join(args.infer_path, 'results.txt')
|
| 234 |
+
|
| 235 |
+
for dataset in args.datasets:
|
| 236 |
+
|
| 237 |
+
file = open(results_save_path, 'a')
|
| 238 |
+
if dataset == 'scannet':
|
| 239 |
+
args.json_file = os.path.join(args.benchmark_path,'scannet/scannet_video.json')
|
| 240 |
+
args.root_path = os.path.join(args.benchmark_path, 'scannet/')
|
| 241 |
+
args.max_depth_eval = 10.0
|
| 242 |
+
args.min_depth_eval = 0.1
|
| 243 |
+
args.max_eval_len = 200
|
| 244 |
+
args.mask = False
|
| 245 |
+
#DepthCrafer crop
|
| 246 |
+
args.a = 8
|
| 247 |
+
args.b = -8
|
| 248 |
+
args.c = 11
|
| 249 |
+
args.d = -11
|
| 250 |
+
|
| 251 |
+
with open(args.json_file, 'r') as fs:
|
| 252 |
+
path_json = json.load(fs)
|
| 253 |
+
|
| 254 |
+
json_data = path_json[dataset]
|
| 255 |
+
count = 0
|
| 256 |
+
line = '-' * 50
|
| 257 |
+
print(f'<{line} {dataset} start {line}>')
|
| 258 |
+
file.write(f'<{line} {dataset} start {line}>\n')
|
| 259 |
+
results_all = 0.
|
| 260 |
+
|
| 261 |
+
for data in tqdm(json_data[:args.eval_scenes_num]):
|
| 262 |
+
for scene_name in data.keys():
|
| 263 |
+
value = data[scene_name]
|
| 264 |
+
infer_paths = []
|
| 265 |
+
depth_gt_paths = []
|
| 266 |
+
factors = []
|
| 267 |
+
Ks = []
|
| 268 |
+
poses = []
|
| 269 |
+
masks = []
|
| 270 |
+
for images in value:
|
| 271 |
+
infer_path = (args.infer_path + '/'+ dataset + '/' + images['image']).replace('.jpg', '.npy').replace('.png', '.npy')
|
| 272 |
+
|
| 273 |
+
infer_paths.append(infer_path)
|
| 274 |
+
depth_gt_paths.append(args.root_path + '/' + images['gt_depth'])
|
| 275 |
+
factors.append(images['factor'])
|
| 276 |
+
Ks.append(np.array(images['K']))
|
| 277 |
+
poses.append(np.array(images['pose']))
|
| 278 |
+
|
| 279 |
+
if args.mask:
|
| 280 |
+
masks.append(args.root_path + '/' + images['mask'])
|
| 281 |
+
|
| 282 |
+
infer_paths = infer_paths[args.start_idx:args.end_idx]
|
| 283 |
+
depth_gt_paths = depth_gt_paths[args.start_idx:args.end_idx]
|
| 284 |
+
factors = factors[args.start_idx:args.end_idx]
|
| 285 |
+
poses = poses[args.start_idx:args.end_idx]
|
| 286 |
+
Ks = Ks[args.start_idx:args.end_idx]
|
| 287 |
+
error = eval_TAE(infer_paths, depth_gt_paths, factors,masks,Ks,poses,args)
|
| 288 |
+
results_all += error
|
| 289 |
+
count += 1
|
| 290 |
+
|
| 291 |
+
print(dataset,': ','tae ', results_all / count)
|
| 292 |
+
file.write(f'{dataset}: {results_all / count}\n')
|
| 293 |
+
file.write(f'<{line} {dataset} finish {line}>\n')
|
| 294 |
+
|
| 295 |
+
|
code_depth/benchmark/eval/eval_tae.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
set -x
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
pred_disp_root=$1 # The parent directory that contaning [sintel, scannet, KITTI, bonn, NYUv2] prediction
|
| 6 |
+
benchmark_root=$2 # The parent directory that contaning [sintel, scannet, KITTI, bonn, NYUv2] ground truth
|
| 7 |
+
|
| 8 |
+
#eval scannet
|
| 9 |
+
python3 benchmark/eval/eval_tae.py \
|
| 10 |
+
--infer_path $pred_disp_root \
|
| 11 |
+
--benchmark_path $benchmark_root \
|
| 12 |
+
--datasets scannet \
|
| 13 |
+
--start_idx 10 \
|
| 14 |
+
--end_idx 180 \
|
| 15 |
+
--eval_scenes_num 20 \
|
| 16 |
+
--hard_crop
|
| 17 |
+
|
| 18 |
+
|
code_depth/benchmark/eval/metric.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def abs_relative_difference(output, target, valid_mask=None):
|
| 4 |
+
actual_output = output
|
| 5 |
+
actual_target = target
|
| 6 |
+
abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target
|
| 7 |
+
if valid_mask is not None:
|
| 8 |
+
abs_relative_diff[~valid_mask] = 0
|
| 9 |
+
n = valid_mask.sum((-1, -2))
|
| 10 |
+
else:
|
| 11 |
+
n = output.shape[-1] * output.shape[-2]
|
| 12 |
+
abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n
|
| 13 |
+
return abs_relative_diff.mean()
|
| 14 |
+
|
| 15 |
+
def squared_relative_difference(output, target, valid_mask=None):
|
| 16 |
+
actual_output = output
|
| 17 |
+
actual_target = target
|
| 18 |
+
square_relative_diff = (
|
| 19 |
+
torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target
|
| 20 |
+
)
|
| 21 |
+
if valid_mask is not None:
|
| 22 |
+
square_relative_diff[~valid_mask] = 0
|
| 23 |
+
n = valid_mask.sum((-1, -2))
|
| 24 |
+
else:
|
| 25 |
+
n = output.shape[-1] * output.shape[-2]
|
| 26 |
+
square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n
|
| 27 |
+
return square_relative_diff.mean()
|
| 28 |
+
|
| 29 |
+
def rmse_linear(output, target, valid_mask=None):
|
| 30 |
+
actual_output = output
|
| 31 |
+
actual_target = target
|
| 32 |
+
diff = actual_output - actual_target
|
| 33 |
+
if valid_mask is not None:
|
| 34 |
+
diff[~valid_mask] = 0
|
| 35 |
+
n = valid_mask.sum((-1, -2))
|
| 36 |
+
else:
|
| 37 |
+
n = output.shape[-1] * output.shape[-2]
|
| 38 |
+
diff2 = torch.pow(diff, 2)
|
| 39 |
+
mse = torch.sum(diff2, (-1, -2)) / n
|
| 40 |
+
rmse = torch.sqrt(mse)
|
| 41 |
+
return rmse.mean()
|
| 42 |
+
|
| 43 |
+
def rmse_log(output, target, valid_mask=None):
|
| 44 |
+
diff = torch.log(output) - torch.log(target)
|
| 45 |
+
if valid_mask is not None:
|
| 46 |
+
diff[~valid_mask] = 0
|
| 47 |
+
n = valid_mask.sum((-1, -2))
|
| 48 |
+
else:
|
| 49 |
+
n = output.shape[-1] * output.shape[-2]
|
| 50 |
+
diff2 = torch.pow(diff, 2)
|
| 51 |
+
mse = torch.sum(diff2, (-1, -2)) / n # [B]
|
| 52 |
+
rmse = torch.sqrt(mse)
|
| 53 |
+
return rmse.mean()
|
| 54 |
+
|
| 55 |
+
def log10(output, target, valid_mask=None):
|
| 56 |
+
if valid_mask is not None:
|
| 57 |
+
diff = torch.abs(
|
| 58 |
+
torch.log10(output[valid_mask]) - torch.log10(target[valid_mask])
|
| 59 |
+
)
|
| 60 |
+
else:
|
| 61 |
+
diff = torch.abs(torch.log10(output) - torch.log10(target))
|
| 62 |
+
return diff.mean()
|
| 63 |
+
|
| 64 |
+
# adapt from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py
|
| 65 |
+
def threshold_percentage(output, target, threshold_val, valid_mask=None):
|
| 66 |
+
d1 = output / target
|
| 67 |
+
d2 = target / output
|
| 68 |
+
max_d1_d2 = torch.max(d1, d2)
|
| 69 |
+
zero = torch.zeros(*output.shape)
|
| 70 |
+
one = torch.ones(*output.shape)
|
| 71 |
+
bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero)
|
| 72 |
+
if valid_mask is not None:
|
| 73 |
+
bit_mat[~valid_mask] = 0
|
| 74 |
+
n = valid_mask.sum((-1, -2))
|
| 75 |
+
else:
|
| 76 |
+
n = output.shape[-1] * output.shape[-2]
|
| 77 |
+
count_mat = torch.sum(bit_mat, (-1, -2))
|
| 78 |
+
threshold_mat = count_mat / n.cpu()
|
| 79 |
+
return threshold_mat.mean()
|
| 80 |
+
|
| 81 |
+
def delta1_acc(pred, gt, valid_mask):
|
| 82 |
+
return threshold_percentage(pred, gt, 1.25, valid_mask)
|
| 83 |
+
|
| 84 |
+
def delta2_acc(pred, gt, valid_mask):
|
| 85 |
+
return threshold_percentage(pred, gt, 1.25**2, valid_mask)
|
| 86 |
+
|
| 87 |
+
def delta3_acc(pred, gt, valid_mask):
|
| 88 |
+
return threshold_percentage(pred, gt, 1.25**3, valid_mask)
|
| 89 |
+
|
| 90 |
+
def i_rmse(output, target, valid_mask=None):
|
| 91 |
+
output_inv = 1.0 / output
|
| 92 |
+
target_inv = 1.0 / target
|
| 93 |
+
diff = output_inv - target_inv
|
| 94 |
+
if valid_mask is not None:
|
| 95 |
+
diff[~valid_mask] = 0
|
| 96 |
+
n = valid_mask.sum((-1, -2))
|
| 97 |
+
else:
|
| 98 |
+
n = output.shape[-1] * output.shape[-2]
|
| 99 |
+
diff2 = torch.pow(diff, 2)
|
| 100 |
+
mse = torch.sum(diff2, (-1, -2)) / n # [B]
|
| 101 |
+
rmse = torch.sqrt(mse)
|
| 102 |
+
return rmse.mean()
|
| 103 |
+
|
| 104 |
+
def silog_rmse(depth_pred, depth_gt, valid_mask=None):
|
| 105 |
+
diff = torch.log(depth_pred) - torch.log(depth_gt)
|
| 106 |
+
if valid_mask is not None:
|
| 107 |
+
diff[~valid_mask] = 0
|
| 108 |
+
n = valid_mask.sum((-1, -2))
|
| 109 |
+
else:
|
| 110 |
+
n = depth_gt.shape[-2] * depth_gt.shape[-1]
|
| 111 |
+
|
| 112 |
+
diff2 = torch.pow(diff, 2)
|
| 113 |
+
|
| 114 |
+
first_term = torch.sum(diff2, (-1, -2)) / n
|
| 115 |
+
second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
|
| 116 |
+
loss = torch.sqrt(torch.mean(first_term - second_term)) * 100
|
| 117 |
+
return loss
|
code_depth/benchmark/infer/infer.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import cv2
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from video_depth_anything.video_depth import VideoDepthAnything
|
| 10 |
+
from utils.dc_utils import read_video_frames
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument('--infer_path', type=str, default='')
|
| 15 |
+
|
| 16 |
+
parser.add_argument('--json_file', type=str, default='')
|
| 17 |
+
parser.add_argument('--datasets', type=str, nargs='+', default=['scannet', 'nyuv2'])
|
| 18 |
+
|
| 19 |
+
parser.add_argument('--input_size', type=int, default=518)
|
| 20 |
+
parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
|
| 21 |
+
|
| 22 |
+
args = parser.parse_args()
|
| 23 |
+
|
| 24 |
+
for dataset in args.datasets:
|
| 25 |
+
|
| 26 |
+
with open(args.json_file, 'r') as fs:
|
| 27 |
+
path_json = json.load(fs)
|
| 28 |
+
|
| 29 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 30 |
+
|
| 31 |
+
model_configs = {
|
| 32 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
| 33 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
|
| 37 |
+
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
|
| 38 |
+
video_depth_anything = video_depth_anything.to(DEVICE).eval()
|
| 39 |
+
|
| 40 |
+
json_data = path_json[dataset]
|
| 41 |
+
root_path = os.path.dirname(args.json_file)
|
| 42 |
+
for data in tqdm(json_data):
|
| 43 |
+
for key in data.keys():
|
| 44 |
+
value = data[key]
|
| 45 |
+
infer_paths = []
|
| 46 |
+
|
| 47 |
+
videos = []
|
| 48 |
+
for images in value:
|
| 49 |
+
|
| 50 |
+
image_path = os.path.join(root_path, images['image'])
|
| 51 |
+
infer_path = (args.infer_path + '/'+ dataset + '/' + images['image']).replace('.jpg', '.npy').replace('.png', '.npy')
|
| 52 |
+
infer_paths.append(infer_path)
|
| 53 |
+
|
| 54 |
+
img = cv2.imread(image_path)
|
| 55 |
+
videos.append(img)
|
| 56 |
+
videos = np.stack(videos, axis=0)
|
| 57 |
+
target_fps=1
|
| 58 |
+
depths, fps = video_depth_anything.infer_video_depth(videos, target_fps, input_size=args.input_size, device=DEVICE, fp32=True)
|
| 59 |
+
|
| 60 |
+
for i in range(len(infer_paths)):
|
| 61 |
+
infer_path = infer_paths[i]
|
| 62 |
+
os.makedirs(os.path.dirname(infer_path), exist_ok=True)
|
| 63 |
+
depth = depths[i]
|
| 64 |
+
np.save(infer_path, depth)
|
| 65 |
+
|
code_depth/get_weights.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
mkdir checkpoints
|
| 4 |
+
cd checkpoints
|
| 5 |
+
wget https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth
|
| 6 |
+
wget https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth
|
code_depth/large_files.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
./checkpoints/video_depth_anything_vitl.pth
|
| 2 |
+
./checkpoints/video_depth_anything_vits.pth
|
code_depth/requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==1.23.1
|
| 2 |
+
torch==2.1.1
|
| 3 |
+
torchvision==0.16.1
|
| 4 |
+
opencv-python
|
| 5 |
+
matplotlib
|
| 6 |
+
pillow
|
| 7 |
+
imageio==2.19.3
|
| 8 |
+
imageio-ffmpeg==0.4.7
|
| 9 |
+
decord
|
| 10 |
+
xformers==0.0.23
|
| 11 |
+
einops==0.4.1
|
| 12 |
+
easydict
|
| 13 |
+
tqdm
|
| 14 |
+
OpenEXR==3.3.1
|
code_depth/run.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from video_depth_anything.video_depth import VideoDepthAnything
|
| 20 |
+
from utils.dc_utils import read_video_frames, save_video
|
| 21 |
+
|
| 22 |
+
if __name__ == '__main__':
|
| 23 |
+
parser = argparse.ArgumentParser(description='Video Depth Anything')
|
| 24 |
+
parser.add_argument('--input_video', type=str, default='./assets/example_videos/davis_rollercoaster.mp4')
|
| 25 |
+
parser.add_argument('--output_dir', type=str, default='./outputs')
|
| 26 |
+
parser.add_argument('--input_size', type=int, default=518)
|
| 27 |
+
parser.add_argument('--max_res', type=int, default=1280)
|
| 28 |
+
parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
|
| 29 |
+
parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
|
| 30 |
+
parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')
|
| 31 |
+
parser.add_argument('--fp32', action='store_true', help='model infer with torch.float32, default is torch.float16')
|
| 32 |
+
parser.add_argument('--grayscale', action='store_true', help='do not apply colorful palette')
|
| 33 |
+
parser.add_argument('--save_npz', action='store_true', help='save depths as npz')
|
| 34 |
+
parser.add_argument('--save_exr', action='store_true', help='save depths as exr')
|
| 35 |
+
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
|
| 38 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 39 |
+
|
| 40 |
+
model_configs = {
|
| 41 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
| 42 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
|
| 46 |
+
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
|
| 47 |
+
video_depth_anything = video_depth_anything.to(DEVICE).eval()
|
| 48 |
+
|
| 49 |
+
frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)
|
| 50 |
+
depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE, fp32=args.fp32)
|
| 51 |
+
|
| 52 |
+
video_name = os.path.basename(args.input_video)
|
| 53 |
+
if not os.path.exists(args.output_dir):
|
| 54 |
+
os.makedirs(args.output_dir)
|
| 55 |
+
|
| 56 |
+
processed_video_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
|
| 57 |
+
depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
|
| 58 |
+
save_video(frames, processed_video_path, fps=fps)
|
| 59 |
+
save_video(depths, depth_vis_path, fps=fps, is_depths=True, grayscale=args.grayscale)
|
| 60 |
+
|
| 61 |
+
if args.save_npz:
|
| 62 |
+
depth_npz_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_depths.npz')
|
| 63 |
+
np.savez_compressed(depth_npz_path, depths=depths)
|
| 64 |
+
if args.save_exr:
|
| 65 |
+
depth_exr_dir = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_depths_exr')
|
| 66 |
+
os.makedirs(depth_exr_dir, exist_ok=True)
|
| 67 |
+
import OpenEXR
|
| 68 |
+
import Imath
|
| 69 |
+
for i, depth in enumerate(depths):
|
| 70 |
+
output_exr = f"{depth_exr_dir}/frame_{i:05d}.exr"
|
| 71 |
+
header = OpenEXR.Header(depth.shape[1], depth.shape[0])
|
| 72 |
+
header["channels"] = {
|
| 73 |
+
"Z": Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
|
| 74 |
+
}
|
| 75 |
+
exr_file = OpenEXR.OutputFile(output_exr, header)
|
| 76 |
+
exr_file.writePixels({"Z": depth.tobytes()})
|
| 77 |
+
exr_file.close()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
code_depth/run_images_rord.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
import os
|
| 19 |
+
import cv2
|
| 20 |
+
import numpy as np
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import matplotlib.cm as cm
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from video_depth_anything.video_depth import VideoDepthAnything
|
| 25 |
+
from utils.dc_utils import read_video_frames, save_video
|
| 26 |
+
import tqdm
|
| 27 |
+
|
| 28 |
+
if __name__ == '__main__':
|
| 29 |
+
parser = argparse.ArgumentParser(description='Video Depth Anything')
|
| 30 |
+
parser.add_argument('--input_size', type=int, default=518)
|
| 31 |
+
parser.add_argument('--max_res', type=int, default=1280)
|
| 32 |
+
parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
|
| 33 |
+
parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
|
| 34 |
+
parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')
|
| 35 |
+
parser.add_argument('--fp32', action='store_true', help='model infer with torch.float32, default is torch.float16')
|
| 36 |
+
parser.add_argument('--grayscale', action='store_true', help='do not apply colorful palette')
|
| 37 |
+
parser.add_argument('--save_npz', action='store_true', help='save depths as npz')
|
| 38 |
+
parser.add_argument('--save_exr', action='store_true', help='save depths as exr')
|
| 39 |
+
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
|
| 42 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 43 |
+
|
| 44 |
+
model_configs = {
|
| 45 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
| 46 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
|
| 50 |
+
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
|
| 51 |
+
video_depth_anything = video_depth_anything.to(DEVICE).eval()
|
| 52 |
+
|
| 53 |
+
# place input dir and out dir here
|
| 54 |
+
root_img_dir = "RORD/train/img"
|
| 55 |
+
root_gt_dir = "RORD/train/gt"
|
| 56 |
+
save_root_img_base = "RORD/val/img_depth"
|
| 57 |
+
save_root_gt_base = "RORD/val/gt_depth"
|
| 58 |
+
|
| 59 |
+
video_ids = sorted(os.listdir(root_img_dir))
|
| 60 |
+
|
| 61 |
+
for video_id in tqdm.tqdm(video_ids):
|
| 62 |
+
frame_dir = os.path.join(root_img_dir, video_id)
|
| 63 |
+
|
| 64 |
+
frame_paths = sorted([
|
| 65 |
+
os.path.join(frame_dir, fname) for fname in os.listdir(frame_dir)
|
| 66 |
+
if fname.endswith(".jpg") or fname.endswith(".png")
|
| 67 |
+
])
|
| 68 |
+
frames = [cv2.imread(p)[:, :, ::-1] for p in frame_paths]
|
| 69 |
+
gt_path = frame_paths[0].replace("/img/", "/gt/")
|
| 70 |
+
|
| 71 |
+
gt_img = cv2.imread(gt_path)[:, :, ::-1] # BGR to RGB
|
| 72 |
+
frames.append(gt_img)
|
| 73 |
+
|
| 74 |
+
resized_frames = []
|
| 75 |
+
max_res = 1280
|
| 76 |
+
for f in frames:
|
| 77 |
+
h, w = f.shape[:2]
|
| 78 |
+
if max(h, w) > max_res:
|
| 79 |
+
scale = max_res / max(h, w)
|
| 80 |
+
f = cv2.resize(f, (int(w * scale), int(h * scale)))
|
| 81 |
+
resized_frames.append(f)
|
| 82 |
+
|
| 83 |
+
resized_frames = np.stack(resized_frames, axis=0)
|
| 84 |
+
|
| 85 |
+
depths, _ = video_depth_anything.infer_video_depth(
|
| 86 |
+
resized_frames, 32, input_size=518, device=DEVICE, fp32=False
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
save_root_img = os.path.join(save_root_img_base, video_id)
|
| 90 |
+
save_root_gt = os.path.join(save_root_gt_base, video_id)
|
| 91 |
+
os.makedirs(save_root_img, exist_ok=True)
|
| 92 |
+
os.makedirs(save_root_gt, exist_ok=True)
|
| 93 |
+
|
| 94 |
+
colormap = np.array(cm.get_cmap("inferno").colors)
|
| 95 |
+
d_min, d_max = depths.min(), depths.max()
|
| 96 |
+
for i, path in enumerate(frame_paths):
|
| 97 |
+
fname = os.path.basename(path)
|
| 98 |
+
|
| 99 |
+
depth = depths[i]
|
| 100 |
+
depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
|
| 101 |
+
depth_vis = (colormap[depth_norm] * 255).astype(np.uint8) # shape: (H, W, 3), uint8
|
| 102 |
+
|
| 103 |
+
img_path = os.path.join(save_root_img, fname)
|
| 104 |
+
Image.fromarray(depth_vis).save(img_path)
|
| 105 |
+
|
| 106 |
+
gt_depth = depths[-1]
|
| 107 |
+
gt_norm = ((gt_depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
|
| 108 |
+
gt_vis = (colormap[gt_norm] * 255).astype(np.uint8)
|
| 109 |
+
|
| 110 |
+
gt_save_path = os.path.join(save_root_gt, fname)
|
| 111 |
+
Image.fromarray(gt_vis).save(gt_save_path)
|
| 112 |
+
|
code_depth/run_single_image.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import cv2
|
| 9 |
+
import matplotlib.cm as cm
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from video_depth_anything.video_depth import VideoDepthAnything
|
| 12 |
+
|
| 13 |
+
if __name__ == '__main__':
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
parser = argparse.ArgumentParser(description='Video Depth Anything')
|
| 17 |
+
parser.add_argument('--input_size', type=int, default=518)
|
| 18 |
+
parser.add_argument('--max_res', type=int, default=1280)
|
| 19 |
+
parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
|
| 20 |
+
parser.add_argument('--max_len', type=int, default=-1)
|
| 21 |
+
parser.add_argument('--target_fps', type=int, default=-1)
|
| 22 |
+
parser.add_argument('--fp32', action='store_true')
|
| 23 |
+
parser.add_argument('--grayscale', action='store_true')
|
| 24 |
+
parser.add_argument('--save_npz', action='store_true')
|
| 25 |
+
parser.add_argument('--save_exr', action='store_true')
|
| 26 |
+
args = parser.parse_args()
|
| 27 |
+
|
| 28 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 29 |
+
|
| 30 |
+
model_configs = {
|
| 31 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
| 32 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
|
| 36 |
+
video_depth_anything.load_state_dict(
|
| 37 |
+
torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'),
|
| 38 |
+
strict=True
|
| 39 |
+
)
|
| 40 |
+
video_depth_anything = video_depth_anything.to(DEVICE).eval()
|
| 41 |
+
|
| 42 |
+
# your image input and output path
|
| 43 |
+
input_path = ""
|
| 44 |
+
output_path = ""
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
img = cv2.imread(input_path)[:, :, ::-1]
|
| 48 |
+
h, w = img.shape[:2]
|
| 49 |
+
|
| 50 |
+
if max(h, w) > args.max_res:
|
| 51 |
+
scale = args.max_res / max(h, w)
|
| 52 |
+
img = cv2.resize(img, (int(w * scale), int(h * scale)))
|
| 53 |
+
|
| 54 |
+
frame_tensor = np.stack([img], axis=0)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
depths, _ = video_depth_anything.infer_video_depth(
|
| 58 |
+
frame_tensor, 32, input_size=518, device=DEVICE, fp32=False
|
| 59 |
+
)
|
| 60 |
+
depth = depths[0]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
colormap = np.array(cm.get_cmap("inferno").colors)
|
| 64 |
+
d_min, d_max = depth.min(), depth.max()
|
| 65 |
+
depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
|
| 66 |
+
depth_vis = (colormap[depth_norm] * 255).astype(np.uint8)
|
| 67 |
+
|
| 68 |
+
Image.fromarray(depth_vis).save(output_path)
|
| 69 |
+
print(f"Saved depth map to: {output_path}")
|
code_depth/utils/dc_utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter
|
| 2 |
+
# SPDX-License-Identifier: MIT License license
|
| 3 |
+
#
|
| 4 |
+
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
|
| 5 |
+
# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.cm as cm
|
| 8 |
+
import imageio
|
| 9 |
+
try:
|
| 10 |
+
from decord import VideoReader, cpu
|
| 11 |
+
DECORD_AVAILABLE = True
|
| 12 |
+
except:
|
| 13 |
+
import cv2
|
| 14 |
+
DECORD_AVAILABLE = False
|
| 15 |
+
|
| 16 |
+
def ensure_even(value):
|
| 17 |
+
return value if value % 2 == 0 else value + 1
|
| 18 |
+
|
| 19 |
+
def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1):
|
| 20 |
+
if DECORD_AVAILABLE:
|
| 21 |
+
vid = VideoReader(video_path, ctx=cpu(0))
|
| 22 |
+
original_height, original_width = vid.get_batch([0]).shape[1:3]
|
| 23 |
+
height = original_height
|
| 24 |
+
width = original_width
|
| 25 |
+
if max_res > 0 and max(height, width) > max_res:
|
| 26 |
+
scale = max_res / max(original_height, original_width)
|
| 27 |
+
height = ensure_even(round(original_height * scale))
|
| 28 |
+
width = ensure_even(round(original_width * scale))
|
| 29 |
+
|
| 30 |
+
vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
|
| 31 |
+
|
| 32 |
+
fps = vid.get_avg_fps() if target_fps == -1 else target_fps
|
| 33 |
+
stride = round(vid.get_avg_fps() / fps)
|
| 34 |
+
stride = max(stride, 1)
|
| 35 |
+
frames_idx = list(range(0, len(vid), stride))
|
| 36 |
+
if process_length != -1 and process_length < len(frames_idx):
|
| 37 |
+
frames_idx = frames_idx[:process_length]
|
| 38 |
+
frames = vid.get_batch(frames_idx).asnumpy()
|
| 39 |
+
else:
|
| 40 |
+
cap = cv2.VideoCapture(video_path)
|
| 41 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
| 42 |
+
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 43 |
+
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 44 |
+
|
| 45 |
+
if max_res > 0 and max(original_height, original_width) > max_res:
|
| 46 |
+
scale = max_res / max(original_height, original_width)
|
| 47 |
+
height = round(original_height * scale)
|
| 48 |
+
width = round(original_width * scale)
|
| 49 |
+
|
| 50 |
+
fps = original_fps if target_fps < 0 else target_fps
|
| 51 |
+
|
| 52 |
+
stride = max(round(original_fps / fps), 1)
|
| 53 |
+
|
| 54 |
+
frames = []
|
| 55 |
+
frame_count = 0
|
| 56 |
+
while cap.isOpened():
|
| 57 |
+
ret, frame = cap.read()
|
| 58 |
+
if not ret or (process_length > 0 and frame_count >= process_length):
|
| 59 |
+
break
|
| 60 |
+
if frame_count % stride == 0:
|
| 61 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
|
| 62 |
+
if max_res > 0 and max(original_height, original_width) > max_res:
|
| 63 |
+
frame = cv2.resize(frame, (width, height)) # Resize frame
|
| 64 |
+
frames.append(frame)
|
| 65 |
+
frame_count += 1
|
| 66 |
+
cap.release()
|
| 67 |
+
frames = np.stack(frames, axis=0)
|
| 68 |
+
|
| 69 |
+
return frames, fps
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def save_video(frames, output_video_path, fps=10, is_depths=False, grayscale=False):
|
| 73 |
+
writer = imageio.get_writer(output_video_path, fps=fps, macro_block_size=1, codec='libx264', ffmpeg_params=['-crf', '18'])
|
| 74 |
+
if is_depths:
|
| 75 |
+
colormap = np.array(cm.get_cmap("inferno").colors)
|
| 76 |
+
d_min, d_max = frames.min(), frames.max()
|
| 77 |
+
for i in range(frames.shape[0]):
|
| 78 |
+
depth = frames[i]
|
| 79 |
+
depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
|
| 80 |
+
depth_vis = (colormap[depth_norm] * 255).astype(np.uint8) if not grayscale else depth_norm
|
| 81 |
+
writer.append_data(depth_vis)
|
| 82 |
+
else:
|
| 83 |
+
for i in range(frames.shape[0]):
|
| 84 |
+
writer.append_data(frames[i])
|
| 85 |
+
|
| 86 |
+
writer.close()
|
code_depth/utils/util.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
def compute_scale_and_shift(prediction, target, mask, scale_only=False):
|
| 17 |
+
if scale_only:
|
| 18 |
+
return compute_scale(prediction, target, mask), 0
|
| 19 |
+
else:
|
| 20 |
+
return compute_scale_and_shift_full(prediction, target, mask)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def compute_scale(prediction, target, mask):
|
| 24 |
+
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
|
| 25 |
+
prediction = prediction.astype(np.float32)
|
| 26 |
+
target = target.astype(np.float32)
|
| 27 |
+
mask = mask.astype(np.float32)
|
| 28 |
+
|
| 29 |
+
a_00 = np.sum(mask * prediction * prediction)
|
| 30 |
+
a_01 = np.sum(mask * prediction)
|
| 31 |
+
a_11 = np.sum(mask)
|
| 32 |
+
|
| 33 |
+
# right hand side: b = [b_0, b_1]
|
| 34 |
+
b_0 = np.sum(mask * prediction * target)
|
| 35 |
+
|
| 36 |
+
x_0 = b_0 / (a_00 + 1e-6)
|
| 37 |
+
|
| 38 |
+
return x_0
|
| 39 |
+
|
| 40 |
+
def compute_scale_and_shift_full(prediction, target, mask):
|
| 41 |
+
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
|
| 42 |
+
prediction = prediction.astype(np.float32)
|
| 43 |
+
target = target.astype(np.float32)
|
| 44 |
+
mask = mask.astype(np.float32)
|
| 45 |
+
|
| 46 |
+
a_00 = np.sum(mask * prediction * prediction)
|
| 47 |
+
a_01 = np.sum(mask * prediction)
|
| 48 |
+
a_11 = np.sum(mask)
|
| 49 |
+
|
| 50 |
+
b_0 = np.sum(mask * prediction * target)
|
| 51 |
+
b_1 = np.sum(mask * target)
|
| 52 |
+
|
| 53 |
+
x_0 = 1
|
| 54 |
+
x_1 = 0
|
| 55 |
+
|
| 56 |
+
det = a_00 * a_11 - a_01 * a_01
|
| 57 |
+
|
| 58 |
+
if det != 0:
|
| 59 |
+
x_0 = (a_11 * b_0 - a_01 * b_1) / det
|
| 60 |
+
x_1 = (-a_01 * b_0 + a_00 * b_1) / det
|
| 61 |
+
|
| 62 |
+
return x_0, x_1
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_interpolate_frames(frame_list_pre, frame_list_post):
|
| 66 |
+
assert len(frame_list_pre) == len(frame_list_post)
|
| 67 |
+
min_w = 0.0
|
| 68 |
+
max_w = 1.0
|
| 69 |
+
step = (max_w - min_w) / (len(frame_list_pre)-1)
|
| 70 |
+
post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
|
| 71 |
+
interpolated_frames = []
|
| 72 |
+
for i in range(len(frame_list_pre)):
|
| 73 |
+
interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
|
| 74 |
+
return interpolated_frames
|
code_depth/video_depth_anything/dinov2.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.utils.checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
| 27 |
+
if not depth_first and include_root:
|
| 28 |
+
fn(module=module, name=name)
|
| 29 |
+
for child_name, child_module in module.named_children():
|
| 30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 32 |
+
if depth_first and include_root:
|
| 33 |
+
fn(module=module, name=name)
|
| 34 |
+
return module
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BlockChunk(nn.ModuleList):
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
for b in self:
|
| 40 |
+
x = b(x)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DinoVisionTransformer(nn.Module):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
img_size=224,
|
| 48 |
+
patch_size=16,
|
| 49 |
+
in_chans=3,
|
| 50 |
+
embed_dim=768,
|
| 51 |
+
depth=12,
|
| 52 |
+
num_heads=12,
|
| 53 |
+
mlp_ratio=4.0,
|
| 54 |
+
qkv_bias=True,
|
| 55 |
+
ffn_bias=True,
|
| 56 |
+
proj_bias=True,
|
| 57 |
+
drop_path_rate=0.0,
|
| 58 |
+
drop_path_uniform=False,
|
| 59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 60 |
+
embed_layer=PatchEmbed,
|
| 61 |
+
act_layer=nn.GELU,
|
| 62 |
+
block_fn=Block,
|
| 63 |
+
ffn_layer="mlp",
|
| 64 |
+
block_chunks=1,
|
| 65 |
+
num_register_tokens=0,
|
| 66 |
+
interpolate_antialias=False,
|
| 67 |
+
interpolate_offset=0.1,
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Args:
|
| 71 |
+
img_size (int, tuple): input image size
|
| 72 |
+
patch_size (int, tuple): patch size
|
| 73 |
+
in_chans (int): number of input channels
|
| 74 |
+
embed_dim (int): embedding dimension
|
| 75 |
+
depth (int): depth of transformer
|
| 76 |
+
num_heads (int): number of attention heads
|
| 77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 78 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 80 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 81 |
+
drop_path_rate (float): stochastic depth rate
|
| 82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 83 |
+
weight_init (str): weight init scheme
|
| 84 |
+
init_values (float): layer-scale init values
|
| 85 |
+
embed_layer (nn.Module): patch embedding layer
|
| 86 |
+
act_layer (nn.Module): MLP activation layer
|
| 87 |
+
block_fn (nn.Module): transformer block class
|
| 88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
| 92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 93 |
+
"""
|
| 94 |
+
super().__init__()
|
| 95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 96 |
+
|
| 97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 98 |
+
self.num_tokens = 1
|
| 99 |
+
self.n_blocks = depth
|
| 100 |
+
self.num_heads = num_heads
|
| 101 |
+
self.patch_size = patch_size
|
| 102 |
+
self.num_register_tokens = num_register_tokens
|
| 103 |
+
self.interpolate_antialias = interpolate_antialias
|
| 104 |
+
self.interpolate_offset = interpolate_offset
|
| 105 |
+
|
| 106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 107 |
+
num_patches = self.patch_embed.num_patches
|
| 108 |
+
|
| 109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 111 |
+
assert num_register_tokens >= 0
|
| 112 |
+
self.register_tokens = (
|
| 113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if drop_path_uniform is True:
|
| 117 |
+
dpr = [drop_path_rate] * depth
|
| 118 |
+
else:
|
| 119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 120 |
+
|
| 121 |
+
if ffn_layer == "mlp":
|
| 122 |
+
logger.info("using MLP layer as FFN")
|
| 123 |
+
ffn_layer = Mlp
|
| 124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 125 |
+
logger.info("using SwiGLU layer as FFN")
|
| 126 |
+
ffn_layer = SwiGLUFFNFused
|
| 127 |
+
elif ffn_layer == "identity":
|
| 128 |
+
logger.info("using Identity layer as FFN")
|
| 129 |
+
|
| 130 |
+
def f(*args, **kwargs):
|
| 131 |
+
return nn.Identity()
|
| 132 |
+
|
| 133 |
+
ffn_layer = f
|
| 134 |
+
else:
|
| 135 |
+
raise NotImplementedError
|
| 136 |
+
|
| 137 |
+
blocks_list = [
|
| 138 |
+
block_fn(
|
| 139 |
+
dim=embed_dim,
|
| 140 |
+
num_heads=num_heads,
|
| 141 |
+
mlp_ratio=mlp_ratio,
|
| 142 |
+
qkv_bias=qkv_bias,
|
| 143 |
+
proj_bias=proj_bias,
|
| 144 |
+
ffn_bias=ffn_bias,
|
| 145 |
+
drop_path=dpr[i],
|
| 146 |
+
norm_layer=norm_layer,
|
| 147 |
+
act_layer=act_layer,
|
| 148 |
+
ffn_layer=ffn_layer,
|
| 149 |
+
init_values=init_values,
|
| 150 |
+
)
|
| 151 |
+
for i in range(depth)
|
| 152 |
+
]
|
| 153 |
+
if block_chunks > 0:
|
| 154 |
+
self.chunked_blocks = True
|
| 155 |
+
chunked_blocks = []
|
| 156 |
+
chunksize = depth // block_chunks
|
| 157 |
+
for i in range(0, depth, chunksize):
|
| 158 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 161 |
+
else:
|
| 162 |
+
self.chunked_blocks = False
|
| 163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 164 |
+
|
| 165 |
+
self.norm = norm_layer(embed_dim)
|
| 166 |
+
self.head = nn.Identity()
|
| 167 |
+
|
| 168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 169 |
+
|
| 170 |
+
self.init_weights()
|
| 171 |
+
|
| 172 |
+
def init_weights(self):
|
| 173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 175 |
+
if self.register_tokens is not None:
|
| 176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 177 |
+
named_apply(init_weights_vit_timm, self)
|
| 178 |
+
|
| 179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 180 |
+
previous_dtype = x.dtype
|
| 181 |
+
npatch = x.shape[1] - 1
|
| 182 |
+
N = self.pos_embed.shape[1] - 1
|
| 183 |
+
if npatch == N and w == h:
|
| 184 |
+
return self.pos_embed
|
| 185 |
+
pos_embed = self.pos_embed.float()
|
| 186 |
+
class_pos_embed = pos_embed[:, 0]
|
| 187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 188 |
+
dim = x.shape[-1]
|
| 189 |
+
w0 = w // self.patch_size
|
| 190 |
+
h0 = h // self.patch_size
|
| 191 |
+
# we add a small number to avoid floating point error in the interpolation
|
| 192 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
| 193 |
+
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
|
| 194 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
| 195 |
+
# w0, h0 = w0 + 0.1, h0 + 0.1
|
| 196 |
+
|
| 197 |
+
sqrt_N = math.sqrt(N)
|
| 198 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
| 199 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 200 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
| 201 |
+
scale_factor=(sx, sy),
|
| 202 |
+
# (int(w0), int(h0)), # to solve the upsampling shape issue
|
| 203 |
+
mode="bicubic",
|
| 204 |
+
antialias=self.interpolate_antialias
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
| 208 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
| 209 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 210 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
| 211 |
+
|
| 212 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 213 |
+
B, nc, w, h = x.shape
|
| 214 |
+
x = self.patch_embed(x)
|
| 215 |
+
if masks is not None:
|
| 216 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 217 |
+
|
| 218 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 219 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 220 |
+
|
| 221 |
+
if self.register_tokens is not None:
|
| 222 |
+
x = torch.cat(
|
| 223 |
+
(
|
| 224 |
+
x[:, :1],
|
| 225 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
| 226 |
+
x[:, 1:],
|
| 227 |
+
),
|
| 228 |
+
dim=1,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return x
|
| 232 |
+
|
| 233 |
+
def forward_features_list(self, x_list, masks_list):
|
| 234 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 235 |
+
for blk in self.blocks:
|
| 236 |
+
x = blk(x)
|
| 237 |
+
|
| 238 |
+
all_x = x
|
| 239 |
+
output = []
|
| 240 |
+
for x, masks in zip(all_x, masks_list):
|
| 241 |
+
x_norm = self.norm(x)
|
| 242 |
+
output.append(
|
| 243 |
+
{
|
| 244 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 245 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 246 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 247 |
+
"x_prenorm": x,
|
| 248 |
+
"masks": masks,
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
return output
|
| 252 |
+
|
| 253 |
+
def forward_features(self, x, masks=None):
|
| 254 |
+
if isinstance(x, list):
|
| 255 |
+
return self.forward_features_list(x, masks)
|
| 256 |
+
|
| 257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 258 |
+
|
| 259 |
+
for blk in self.blocks:
|
| 260 |
+
x = blk(x)
|
| 261 |
+
|
| 262 |
+
x_norm = self.norm(x)
|
| 263 |
+
return {
|
| 264 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 265 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 266 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 267 |
+
"x_prenorm": x,
|
| 268 |
+
"masks": masks,
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 272 |
+
x = self.prepare_tokens_with_masks(x)
|
| 273 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 274 |
+
output, total_block_len = [], len(self.blocks)
|
| 275 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 276 |
+
for i, blk in enumerate(self.blocks):
|
| 277 |
+
x = blk(x)
|
| 278 |
+
if i in blocks_to_take:
|
| 279 |
+
output.append(x)
|
| 280 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 281 |
+
return output
|
| 282 |
+
|
| 283 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 284 |
+
x = self.prepare_tokens_with_masks(x)
|
| 285 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 286 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 287 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 288 |
+
for block_chunk in self.blocks:
|
| 289 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 290 |
+
x = blk(x)
|
| 291 |
+
if i in blocks_to_take:
|
| 292 |
+
output.append(x)
|
| 293 |
+
i += 1
|
| 294 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 295 |
+
return output
|
| 296 |
+
|
| 297 |
+
def get_intermediate_layers(
|
| 298 |
+
self,
|
| 299 |
+
x: torch.Tensor,
|
| 300 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 301 |
+
reshape: bool = False,
|
| 302 |
+
return_class_token: bool = False,
|
| 303 |
+
norm=True
|
| 304 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 305 |
+
if self.chunked_blocks:
|
| 306 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
| 307 |
+
else:
|
| 308 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 309 |
+
if norm:
|
| 310 |
+
outputs = [self.norm(out) for out in outputs]
|
| 311 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 312 |
+
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
| 313 |
+
if reshape:
|
| 314 |
+
B, _, w, h = x.shape
|
| 315 |
+
outputs = [
|
| 316 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 317 |
+
for out in outputs
|
| 318 |
+
]
|
| 319 |
+
if return_class_token:
|
| 320 |
+
return tuple(zip(outputs, class_tokens))
|
| 321 |
+
return tuple(outputs)
|
| 322 |
+
|
| 323 |
+
def forward(self, *args, is_training=False, **kwargs):
|
| 324 |
+
ret = self.forward_features(*args, **kwargs)
|
| 325 |
+
if is_training:
|
| 326 |
+
return ret
|
| 327 |
+
else:
|
| 328 |
+
return self.head(ret["x_norm_clstoken"])
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 332 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 333 |
+
if isinstance(module, nn.Linear):
|
| 334 |
+
trunc_normal_(module.weight, std=0.02)
|
| 335 |
+
if module.bias is not None:
|
| 336 |
+
nn.init.zeros_(module.bias)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
| 340 |
+
model = DinoVisionTransformer(
|
| 341 |
+
patch_size=patch_size,
|
| 342 |
+
embed_dim=384,
|
| 343 |
+
depth=12,
|
| 344 |
+
num_heads=6,
|
| 345 |
+
mlp_ratio=4,
|
| 346 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 347 |
+
num_register_tokens=num_register_tokens,
|
| 348 |
+
**kwargs,
|
| 349 |
+
)
|
| 350 |
+
return model
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 354 |
+
model = DinoVisionTransformer(
|
| 355 |
+
patch_size=patch_size,
|
| 356 |
+
embed_dim=768,
|
| 357 |
+
depth=12,
|
| 358 |
+
num_heads=12,
|
| 359 |
+
mlp_ratio=4,
|
| 360 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 361 |
+
num_register_tokens=num_register_tokens,
|
| 362 |
+
**kwargs,
|
| 363 |
+
)
|
| 364 |
+
return model
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 368 |
+
model = DinoVisionTransformer(
|
| 369 |
+
patch_size=patch_size,
|
| 370 |
+
embed_dim=1024,
|
| 371 |
+
depth=24,
|
| 372 |
+
num_heads=16,
|
| 373 |
+
mlp_ratio=4,
|
| 374 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 375 |
+
num_register_tokens=num_register_tokens,
|
| 376 |
+
**kwargs,
|
| 377 |
+
)
|
| 378 |
+
return model
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
| 382 |
+
"""
|
| 383 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 384 |
+
"""
|
| 385 |
+
model = DinoVisionTransformer(
|
| 386 |
+
patch_size=patch_size,
|
| 387 |
+
embed_dim=1536,
|
| 388 |
+
depth=40,
|
| 389 |
+
num_heads=24,
|
| 390 |
+
mlp_ratio=4,
|
| 391 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 392 |
+
num_register_tokens=num_register_tokens,
|
| 393 |
+
**kwargs,
|
| 394 |
+
)
|
| 395 |
+
return model
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def DINOv2(model_name):
|
| 399 |
+
model_zoo = {
|
| 400 |
+
"vits": vit_small,
|
| 401 |
+
"vitb": vit_base,
|
| 402 |
+
"vitl": vit_large,
|
| 403 |
+
"vitg": vit_giant2
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
return model_zoo[model_name](
|
| 407 |
+
img_size=518,
|
| 408 |
+
patch_size=14,
|
| 409 |
+
init_values=1.0,
|
| 410 |
+
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
|
| 411 |
+
block_chunks=0,
|
| 412 |
+
num_register_tokens=0,
|
| 413 |
+
interpolate_antialias=False,
|
| 414 |
+
interpolate_offset=0.1
|
| 415 |
+
)
|
code_depth/video_depth_anything/dinov2_layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
code_depth/video_depth_anything/dinov2_layers/attention.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
| 22 |
+
|
| 23 |
+
XFORMERS_AVAILABLE = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
logger.warning("xFormers not available")
|
| 26 |
+
XFORMERS_AVAILABLE = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Attention(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
dim: int,
|
| 33 |
+
num_heads: int = 8,
|
| 34 |
+
qkv_bias: bool = False,
|
| 35 |
+
proj_bias: bool = True,
|
| 36 |
+
attn_drop: float = 0.0,
|
| 37 |
+
proj_drop: float = 0.0,
|
| 38 |
+
) -> None:
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.num_heads = num_heads
|
| 41 |
+
head_dim = dim // num_heads
|
| 42 |
+
self.scale = head_dim**-0.5
|
| 43 |
+
|
| 44 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 48 |
+
|
| 49 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 50 |
+
B, N, C = x.shape
|
| 51 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 52 |
+
|
| 53 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 54 |
+
attn = q @ k.transpose(-2, -1)
|
| 55 |
+
|
| 56 |
+
attn = attn.softmax(dim=-1)
|
| 57 |
+
attn = self.attn_drop(attn)
|
| 58 |
+
|
| 59 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 60 |
+
x = self.proj(x)
|
| 61 |
+
x = self.proj_drop(x)
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MemEffAttention(Attention):
|
| 66 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 67 |
+
if not XFORMERS_AVAILABLE:
|
| 68 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
| 69 |
+
return super().forward(x)
|
| 70 |
+
|
| 71 |
+
B, N, C = x.shape
|
| 72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 73 |
+
|
| 74 |
+
q, k, v = unbind(qkv, 2)
|
| 75 |
+
|
| 76 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 77 |
+
x = x.reshape([B, N, C])
|
| 78 |
+
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
x = self.proj_drop(x)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
code_depth/video_depth_anything/dinov2_layers/block.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn, Tensor
|
| 16 |
+
|
| 17 |
+
from .attention import Attention, MemEffAttention
|
| 18 |
+
from .drop_path import DropPath
|
| 19 |
+
from .layer_scale import LayerScale
|
| 20 |
+
from .mlp import Mlp
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from xformers.ops import fmha
|
| 28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
| 29 |
+
|
| 30 |
+
XFORMERS_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
logger.warning("xFormers not available")
|
| 33 |
+
XFORMERS_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Block(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim: int,
|
| 40 |
+
num_heads: int,
|
| 41 |
+
mlp_ratio: float = 4.0,
|
| 42 |
+
qkv_bias: bool = False,
|
| 43 |
+
proj_bias: bool = True,
|
| 44 |
+
ffn_bias: bool = True,
|
| 45 |
+
drop: float = 0.0,
|
| 46 |
+
attn_drop: float = 0.0,
|
| 47 |
+
init_values=None,
|
| 48 |
+
drop_path: float = 0.0,
|
| 49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 53 |
+
) -> None:
|
| 54 |
+
super().__init__()
|
| 55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 56 |
+
self.norm1 = norm_layer(dim)
|
| 57 |
+
self.attn = attn_class(
|
| 58 |
+
dim,
|
| 59 |
+
num_heads=num_heads,
|
| 60 |
+
qkv_bias=qkv_bias,
|
| 61 |
+
proj_bias=proj_bias,
|
| 62 |
+
attn_drop=attn_drop,
|
| 63 |
+
proj_drop=drop,
|
| 64 |
+
)
|
| 65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 67 |
+
|
| 68 |
+
self.norm2 = norm_layer(dim)
|
| 69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 70 |
+
self.mlp = ffn_layer(
|
| 71 |
+
in_features=dim,
|
| 72 |
+
hidden_features=mlp_hidden_dim,
|
| 73 |
+
act_layer=act_layer,
|
| 74 |
+
drop=drop,
|
| 75 |
+
bias=ffn_bias,
|
| 76 |
+
)
|
| 77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 78 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 79 |
+
|
| 80 |
+
self.sample_drop_ratio = drop_path
|
| 81 |
+
|
| 82 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 83 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
| 84 |
+
return self.ls1(self.attn(self.norm1(x)))
|
| 85 |
+
|
| 86 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 87 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 88 |
+
|
| 89 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 90 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 91 |
+
x = drop_add_residual_stochastic_depth(
|
| 92 |
+
x,
|
| 93 |
+
residual_func=attn_residual_func,
|
| 94 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 95 |
+
)
|
| 96 |
+
x = drop_add_residual_stochastic_depth(
|
| 97 |
+
x,
|
| 98 |
+
residual_func=ffn_residual_func,
|
| 99 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 100 |
+
)
|
| 101 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 102 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
| 103 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 104 |
+
else:
|
| 105 |
+
x = x + attn_residual_func(x)
|
| 106 |
+
x = x + ffn_residual_func(x)
|
| 107 |
+
return x
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def drop_add_residual_stochastic_depth(
|
| 111 |
+
x: Tensor,
|
| 112 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 113 |
+
sample_drop_ratio: float = 0.0,
|
| 114 |
+
) -> Tensor:
|
| 115 |
+
# 1) extract subset using permutation
|
| 116 |
+
b, n, d = x.shape
|
| 117 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 118 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 119 |
+
x_subset = x[brange]
|
| 120 |
+
|
| 121 |
+
# 2) apply residual_func to get residual
|
| 122 |
+
residual = residual_func(x_subset)
|
| 123 |
+
|
| 124 |
+
x_flat = x.flatten(1)
|
| 125 |
+
residual = residual.flatten(1)
|
| 126 |
+
|
| 127 |
+
residual_scale_factor = b / sample_subset_size
|
| 128 |
+
|
| 129 |
+
# 3) add the residual
|
| 130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 131 |
+
return x_plus_residual.view_as(x)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 135 |
+
b, n, d = x.shape
|
| 136 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 137 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 138 |
+
residual_scale_factor = b / sample_subset_size
|
| 139 |
+
return brange, residual_scale_factor
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 143 |
+
if scaling_vector is None:
|
| 144 |
+
x_flat = x.flatten(1)
|
| 145 |
+
residual = residual.flatten(1)
|
| 146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 147 |
+
else:
|
| 148 |
+
x_plus_residual = scaled_index_add(
|
| 149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 150 |
+
)
|
| 151 |
+
return x_plus_residual
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 158 |
+
"""
|
| 159 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 160 |
+
"""
|
| 161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 162 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 163 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 164 |
+
seqlens = []
|
| 165 |
+
for b, x in zip(batch_sizes, x_list):
|
| 166 |
+
for _ in range(b):
|
| 167 |
+
seqlens.append(x.shape[1])
|
| 168 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 169 |
+
attn_bias._batch_sizes = batch_sizes
|
| 170 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 171 |
+
|
| 172 |
+
if branges is not None:
|
| 173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 174 |
+
else:
|
| 175 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 176 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 177 |
+
|
| 178 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def drop_add_residual_stochastic_depth_list(
|
| 182 |
+
x_list: List[Tensor],
|
| 183 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 184 |
+
sample_drop_ratio: float = 0.0,
|
| 185 |
+
scaling_vector=None,
|
| 186 |
+
) -> Tensor:
|
| 187 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 189 |
+
branges = [s[0] for s in branges_scales]
|
| 190 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 191 |
+
|
| 192 |
+
# 2) get attention bias and index+concat the tensors
|
| 193 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 194 |
+
|
| 195 |
+
# 3) apply residual_func to get residual, and split the result
|
| 196 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 197 |
+
|
| 198 |
+
outputs = []
|
| 199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 201 |
+
return outputs
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class NestedTensorBlock(Block):
|
| 205 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 206 |
+
"""
|
| 207 |
+
x_list contains a list of tensors to nest together and run
|
| 208 |
+
"""
|
| 209 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 210 |
+
|
| 211 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 212 |
+
|
| 213 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 214 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 215 |
+
|
| 216 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 217 |
+
return self.mlp(self.norm2(x))
|
| 218 |
+
|
| 219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 220 |
+
x_list,
|
| 221 |
+
residual_func=attn_residual_func,
|
| 222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 224 |
+
)
|
| 225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 226 |
+
x_list,
|
| 227 |
+
residual_func=ffn_residual_func,
|
| 228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 230 |
+
)
|
| 231 |
+
return x_list
|
| 232 |
+
else:
|
| 233 |
+
|
| 234 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 235 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 236 |
+
|
| 237 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 238 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 239 |
+
|
| 240 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 241 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 242 |
+
x = x + ffn_residual_func(x)
|
| 243 |
+
return attn_bias.split(x)
|
| 244 |
+
|
| 245 |
+
def forward(self, x_or_x_list):
|
| 246 |
+
if isinstance(x_or_x_list, Tensor):
|
| 247 |
+
return super().forward(x_or_x_list)
|
| 248 |
+
elif isinstance(x_or_x_list, list):
|
| 249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
| 250 |
+
return self.forward_nested(x_or_x_list)
|
| 251 |
+
else:
|
| 252 |
+
raise AssertionError
|
code_depth/video_depth_anything/dinov2_layers/drop_path.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 16 |
+
if drop_prob == 0.0 or not training:
|
| 17 |
+
return x
|
| 18 |
+
keep_prob = 1 - drop_prob
|
| 19 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 20 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 21 |
+
if keep_prob > 0.0:
|
| 22 |
+
random_tensor.div_(keep_prob)
|
| 23 |
+
output = x * random_tensor
|
| 24 |
+
return output
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DropPath(nn.Module):
|
| 28 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, drop_prob=None):
|
| 31 |
+
super(DropPath, self).__init__()
|
| 32 |
+
self.drop_prob = drop_prob
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
return drop_path(x, self.drop_prob, self.training)
|
code_depth/video_depth_anything/dinov2_layers/layer_scale.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 8 |
+
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LayerScale(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
dim: int,
|
| 20 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 21 |
+
inplace: bool = False,
|
| 22 |
+
) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.inplace = inplace
|
| 25 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 26 |
+
|
| 27 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 28 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
code_depth/video_depth_anything/dinov2_layers/mlp.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from typing import Callable, Optional
|
| 13 |
+
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Mlp(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
in_features: int,
|
| 21 |
+
hidden_features: Optional[int] = None,
|
| 22 |
+
out_features: Optional[int] = None,
|
| 23 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 24 |
+
drop: float = 0.0,
|
| 25 |
+
bias: bool = True,
|
| 26 |
+
) -> None:
|
| 27 |
+
super().__init__()
|
| 28 |
+
out_features = out_features or in_features
|
| 29 |
+
hidden_features = hidden_features or in_features
|
| 30 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 31 |
+
self.act = act_layer()
|
| 32 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 33 |
+
self.drop = nn.Dropout(drop)
|
| 34 |
+
|
| 35 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 36 |
+
x = self.fc1(x)
|
| 37 |
+
x = self.act(x)
|
| 38 |
+
x = self.drop(x)
|
| 39 |
+
x = self.fc2(x)
|
| 40 |
+
x = self.drop(x)
|
| 41 |
+
return x
|
code_depth/video_depth_anything/dinov2_layers/patch_embed.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def make_2tuple(x):
|
| 18 |
+
if isinstance(x, tuple):
|
| 19 |
+
assert len(x) == 2
|
| 20 |
+
return x
|
| 21 |
+
|
| 22 |
+
assert isinstance(x, int)
|
| 23 |
+
return (x, x)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PatchEmbed(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
img_size: Image size.
|
| 32 |
+
patch_size: Patch token size.
|
| 33 |
+
in_chans: Number of input image channels.
|
| 34 |
+
embed_dim: Number of linear projection output channels.
|
| 35 |
+
norm_layer: Normalization layer.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 41 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 42 |
+
in_chans: int = 3,
|
| 43 |
+
embed_dim: int = 768,
|
| 44 |
+
norm_layer: Optional[Callable] = None,
|
| 45 |
+
flatten_embedding: bool = True,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
image_HW = make_2tuple(img_size)
|
| 50 |
+
patch_HW = make_2tuple(patch_size)
|
| 51 |
+
patch_grid_size = (
|
| 52 |
+
image_HW[0] // patch_HW[0],
|
| 53 |
+
image_HW[1] // patch_HW[1],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.img_size = image_HW
|
| 57 |
+
self.patch_size = patch_HW
|
| 58 |
+
self.patches_resolution = patch_grid_size
|
| 59 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 60 |
+
|
| 61 |
+
self.in_chans = in_chans
|
| 62 |
+
self.embed_dim = embed_dim
|
| 63 |
+
|
| 64 |
+
self.flatten_embedding = flatten_embedding
|
| 65 |
+
|
| 66 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 67 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 68 |
+
|
| 69 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 70 |
+
_, _, H, W = x.shape
|
| 71 |
+
patch_H, patch_W = self.patch_size
|
| 72 |
+
|
| 73 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 74 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 75 |
+
|
| 76 |
+
x = self.proj(x) # B C H W
|
| 77 |
+
H, W = x.size(2), x.size(3)
|
| 78 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 79 |
+
x = self.norm(x)
|
| 80 |
+
if not self.flatten_embedding:
|
| 81 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
def flops(self) -> float:
|
| 85 |
+
Ho, Wo = self.patches_resolution
|
| 86 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 87 |
+
if self.norm is not None:
|
| 88 |
+
flops += Ho * Wo * self.embed_dim
|
| 89 |
+
return flops
|
code_depth/video_depth_anything/dinov2_layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SwiGLUFFN(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_features: int,
|
| 17 |
+
hidden_features: Optional[int] = None,
|
| 18 |
+
out_features: Optional[int] = None,
|
| 19 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 20 |
+
drop: float = 0.0,
|
| 21 |
+
bias: bool = True,
|
| 22 |
+
) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
out_features = out_features or in_features
|
| 25 |
+
hidden_features = hidden_features or in_features
|
| 26 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 27 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 28 |
+
|
| 29 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 30 |
+
x12 = self.w12(x)
|
| 31 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 32 |
+
hidden = F.silu(x1) * x2
|
| 33 |
+
return self.w3(hidden)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from xformers.ops import SwiGLU
|
| 38 |
+
|
| 39 |
+
XFORMERS_AVAILABLE = True
|
| 40 |
+
except ImportError:
|
| 41 |
+
SwiGLU = SwiGLUFFN
|
| 42 |
+
XFORMERS_AVAILABLE = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
in_features: int,
|
| 49 |
+
hidden_features: Optional[int] = None,
|
| 50 |
+
out_features: Optional[int] = None,
|
| 51 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 52 |
+
drop: float = 0.0,
|
| 53 |
+
bias: bool = True,
|
| 54 |
+
) -> None:
|
| 55 |
+
out_features = out_features or in_features
|
| 56 |
+
hidden_features = hidden_features or in_features
|
| 57 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 58 |
+
super().__init__(
|
| 59 |
+
in_features=in_features,
|
| 60 |
+
hidden_features=hidden_features,
|
| 61 |
+
out_features=out_features,
|
| 62 |
+
bias=bias,
|
| 63 |
+
)
|
code_depth/video_depth_anything/dpt.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _make_fusion_block(features, use_bn, size=None):
|
| 22 |
+
return FeatureFusionBlock(
|
| 23 |
+
features,
|
| 24 |
+
nn.ReLU(False),
|
| 25 |
+
deconv=False,
|
| 26 |
+
bn=use_bn,
|
| 27 |
+
expand=False,
|
| 28 |
+
align_corners=True,
|
| 29 |
+
size=size,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ConvBlock(nn.Module):
|
| 34 |
+
def __init__(self, in_feature, out_feature):
|
| 35 |
+
super().__init__()
|
| 36 |
+
|
| 37 |
+
self.conv_block = nn.Sequential(
|
| 38 |
+
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
|
| 39 |
+
nn.BatchNorm2d(out_feature),
|
| 40 |
+
nn.ReLU(True)
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
return self.conv_block(x)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class DPTHead(nn.Module):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
in_channels,
|
| 51 |
+
features=256,
|
| 52 |
+
use_bn=False,
|
| 53 |
+
out_channels=[256, 512, 1024, 1024],
|
| 54 |
+
use_clstoken=False
|
| 55 |
+
):
|
| 56 |
+
super(DPTHead, self).__init__()
|
| 57 |
+
|
| 58 |
+
self.use_clstoken = use_clstoken
|
| 59 |
+
|
| 60 |
+
self.projects = nn.ModuleList([
|
| 61 |
+
nn.Conv2d(
|
| 62 |
+
in_channels=in_channels,
|
| 63 |
+
out_channels=out_channel,
|
| 64 |
+
kernel_size=1,
|
| 65 |
+
stride=1,
|
| 66 |
+
padding=0,
|
| 67 |
+
) for out_channel in out_channels
|
| 68 |
+
])
|
| 69 |
+
|
| 70 |
+
self.resize_layers = nn.ModuleList([
|
| 71 |
+
nn.ConvTranspose2d(
|
| 72 |
+
in_channels=out_channels[0],
|
| 73 |
+
out_channels=out_channels[0],
|
| 74 |
+
kernel_size=4,
|
| 75 |
+
stride=4,
|
| 76 |
+
padding=0),
|
| 77 |
+
nn.ConvTranspose2d(
|
| 78 |
+
in_channels=out_channels[1],
|
| 79 |
+
out_channels=out_channels[1],
|
| 80 |
+
kernel_size=2,
|
| 81 |
+
stride=2,
|
| 82 |
+
padding=0),
|
| 83 |
+
nn.Identity(),
|
| 84 |
+
nn.Conv2d(
|
| 85 |
+
in_channels=out_channels[3],
|
| 86 |
+
out_channels=out_channels[3],
|
| 87 |
+
kernel_size=3,
|
| 88 |
+
stride=2,
|
| 89 |
+
padding=1)
|
| 90 |
+
])
|
| 91 |
+
|
| 92 |
+
if use_clstoken:
|
| 93 |
+
self.readout_projects = nn.ModuleList()
|
| 94 |
+
for _ in range(len(self.projects)):
|
| 95 |
+
self.readout_projects.append(
|
| 96 |
+
nn.Sequential(
|
| 97 |
+
nn.Linear(2 * in_channels, in_channels),
|
| 98 |
+
nn.GELU()))
|
| 99 |
+
|
| 100 |
+
self.scratch = _make_scratch(
|
| 101 |
+
out_channels,
|
| 102 |
+
features,
|
| 103 |
+
groups=1,
|
| 104 |
+
expand=False,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.scratch.stem_transpose = None
|
| 108 |
+
|
| 109 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
| 110 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
| 111 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
| 112 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
| 113 |
+
|
| 114 |
+
head_features_1 = features
|
| 115 |
+
head_features_2 = 32
|
| 116 |
+
|
| 117 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
| 118 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 119 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 120 |
+
nn.ReLU(True),
|
| 121 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
| 122 |
+
nn.ReLU(True),
|
| 123 |
+
nn.Identity(),
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def forward(self, out_features, patch_h, patch_w):
|
| 127 |
+
out = []
|
| 128 |
+
for i, x in enumerate(out_features):
|
| 129 |
+
if self.use_clstoken:
|
| 130 |
+
x, cls_token = x[0], x[1]
|
| 131 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
| 132 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
| 133 |
+
else:
|
| 134 |
+
x = x[0]
|
| 135 |
+
|
| 136 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 137 |
+
|
| 138 |
+
x = self.projects[i](x)
|
| 139 |
+
x = self.resize_layers[i](x)
|
| 140 |
+
|
| 141 |
+
out.append(x)
|
| 142 |
+
|
| 143 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
| 144 |
+
|
| 145 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 146 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 147 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 148 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 149 |
+
|
| 150 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 151 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 152 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 153 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
| 154 |
+
|
| 155 |
+
out = self.scratch.output_conv1(path_1)
|
| 156 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
| 157 |
+
out = self.scratch.output_conv2(out)
|
| 158 |
+
|
| 159 |
+
return out
|
| 160 |
+
|
code_depth/video_depth_anything/dpt_temporal.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from .dpt import DPTHead
|
| 18 |
+
from .motion_module.motion_module import TemporalModule
|
| 19 |
+
from easydict import EasyDict
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DPTHeadTemporal(DPTHead):
|
| 23 |
+
def __init__(self,
|
| 24 |
+
in_channels,
|
| 25 |
+
features=256,
|
| 26 |
+
use_bn=False,
|
| 27 |
+
out_channels=[256, 512, 1024, 1024],
|
| 28 |
+
use_clstoken=False,
|
| 29 |
+
num_frames=32,
|
| 30 |
+
pe='ape'
|
| 31 |
+
):
|
| 32 |
+
super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)
|
| 33 |
+
|
| 34 |
+
assert num_frames > 0
|
| 35 |
+
motion_module_kwargs = EasyDict(num_attention_heads = 8,
|
| 36 |
+
num_transformer_block = 1,
|
| 37 |
+
num_attention_blocks = 2,
|
| 38 |
+
temporal_max_len = num_frames,
|
| 39 |
+
zero_initialize = True,
|
| 40 |
+
pos_embedding_type = pe)
|
| 41 |
+
|
| 42 |
+
self.motion_modules = nn.ModuleList([
|
| 43 |
+
TemporalModule(in_channels=out_channels[2],
|
| 44 |
+
**motion_module_kwargs),
|
| 45 |
+
TemporalModule(in_channels=out_channels[3],
|
| 46 |
+
**motion_module_kwargs),
|
| 47 |
+
TemporalModule(in_channels=features,
|
| 48 |
+
**motion_module_kwargs),
|
| 49 |
+
TemporalModule(in_channels=features,
|
| 50 |
+
**motion_module_kwargs)
|
| 51 |
+
])
|
| 52 |
+
|
| 53 |
+
def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size=4):
|
| 54 |
+
out = []
|
| 55 |
+
for i, x in enumerate(out_features):
|
| 56 |
+
if self.use_clstoken:
|
| 57 |
+
x, cls_token = x[0], x[1]
|
| 58 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
| 59 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
| 60 |
+
else:
|
| 61 |
+
x = x[0]
|
| 62 |
+
|
| 63 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()
|
| 64 |
+
|
| 65 |
+
B, T = x.shape[0] // frame_length, frame_length
|
| 66 |
+
x = self.projects[i](x)
|
| 67 |
+
x = self.resize_layers[i](x)
|
| 68 |
+
|
| 69 |
+
out.append(x)
|
| 70 |
+
|
| 71 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
| 72 |
+
|
| 73 |
+
B, T = layer_1.shape[0] // frame_length, frame_length
|
| 74 |
+
|
| 75 |
+
layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
| 76 |
+
layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
| 77 |
+
|
| 78 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 79 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 80 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 81 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 82 |
+
|
| 83 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 84 |
+
path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
| 85 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 86 |
+
path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
| 87 |
+
|
| 88 |
+
batch_size = layer_1_rn.shape[0]
|
| 89 |
+
if batch_size <= micro_batch_size or batch_size % micro_batch_size != 0:
|
| 90 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 91 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
| 92 |
+
|
| 93 |
+
out = self.scratch.output_conv1(path_1)
|
| 94 |
+
out = F.interpolate(
|
| 95 |
+
out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
|
| 96 |
+
)
|
| 97 |
+
ori_type = out.dtype
|
| 98 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
| 99 |
+
out = self.scratch.output_conv2(out.float())
|
| 100 |
+
return out.to(ori_type)
|
| 101 |
+
else:
|
| 102 |
+
ret = []
|
| 103 |
+
for i in range(0, batch_size, micro_batch_size):
|
| 104 |
+
path_2 = self.scratch.refinenet2(path_3[i:i + micro_batch_size], layer_2_rn[i:i + micro_batch_size], size=layer_1_rn[i:i + micro_batch_size].shape[2:])
|
| 105 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn[i:i + micro_batch_size])
|
| 106 |
+
out = self.scratch.output_conv1(path_1)
|
| 107 |
+
out = F.interpolate(
|
| 108 |
+
out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
|
| 109 |
+
)
|
| 110 |
+
ori_type = out.dtype
|
| 111 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
| 112 |
+
out = self.scratch.output_conv2(out.float())
|
| 113 |
+
ret.append(out.to(ori_type))
|
| 114 |
+
return torch.cat(ret, dim=0)
|
code_depth/video_depth_anything/motion_module/attention.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Optional, Tuple
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import xformers
|
| 22 |
+
import xformers.ops
|
| 23 |
+
|
| 24 |
+
XFORMERS_AVAILABLE = True
|
| 25 |
+
except ImportError:
|
| 26 |
+
print("xFormers not available")
|
| 27 |
+
XFORMERS_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CrossAttention(nn.Module):
|
| 31 |
+
r"""
|
| 32 |
+
A cross attention layer.
|
| 33 |
+
|
| 34 |
+
Parameters:
|
| 35 |
+
query_dim (`int`): The number of channels in the query.
|
| 36 |
+
cross_attention_dim (`int`, *optional*):
|
| 37 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
| 38 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
| 39 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
| 40 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 41 |
+
bias (`bool`, *optional*, defaults to False):
|
| 42 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
query_dim: int,
|
| 48 |
+
cross_attention_dim: Optional[int] = None,
|
| 49 |
+
heads: int = 8,
|
| 50 |
+
dim_head: int = 64,
|
| 51 |
+
dropout: float = 0.0,
|
| 52 |
+
bias=False,
|
| 53 |
+
upcast_attention: bool = False,
|
| 54 |
+
upcast_softmax: bool = False,
|
| 55 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 56 |
+
norm_num_groups: Optional[int] = None,
|
| 57 |
+
):
|
| 58 |
+
super().__init__()
|
| 59 |
+
inner_dim = dim_head * heads
|
| 60 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 61 |
+
self.upcast_attention = upcast_attention
|
| 62 |
+
self.upcast_softmax = upcast_softmax
|
| 63 |
+
self.upcast_efficient_attention = False
|
| 64 |
+
|
| 65 |
+
self.scale = dim_head**-0.5
|
| 66 |
+
|
| 67 |
+
self.heads = heads
|
| 68 |
+
# for slice_size > 0 the attention score computation
|
| 69 |
+
# is split across the batch axis to save memory
|
| 70 |
+
# You can set slice_size with `set_attention_slice`
|
| 71 |
+
self.sliceable_head_dim = heads
|
| 72 |
+
self._slice_size = None
|
| 73 |
+
self._use_memory_efficient_attention_xformers = False
|
| 74 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 75 |
+
|
| 76 |
+
if norm_num_groups is not None:
|
| 77 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
| 78 |
+
else:
|
| 79 |
+
self.group_norm = None
|
| 80 |
+
|
| 81 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
| 82 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
| 83 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
| 84 |
+
|
| 85 |
+
if self.added_kv_proj_dim is not None:
|
| 86 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
| 87 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
| 88 |
+
|
| 89 |
+
self.to_out = nn.ModuleList([])
|
| 90 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
| 91 |
+
self.to_out.append(nn.Dropout(dropout))
|
| 92 |
+
|
| 93 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
| 94 |
+
batch_size, seq_len, dim = tensor.shape
|
| 95 |
+
head_size = self.heads
|
| 96 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
|
| 97 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
|
| 98 |
+
return tensor
|
| 99 |
+
|
| 100 |
+
def reshape_heads_to_4d(self, tensor):
|
| 101 |
+
batch_size, seq_len, dim = tensor.shape
|
| 102 |
+
head_size = self.heads
|
| 103 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
|
| 104 |
+
return tensor
|
| 105 |
+
|
| 106 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
| 107 |
+
batch_size, seq_len, dim = tensor.shape
|
| 108 |
+
head_size = self.heads
|
| 109 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
|
| 110 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
|
| 111 |
+
return tensor
|
| 112 |
+
|
| 113 |
+
def reshape_4d_to_heads(self, tensor):
|
| 114 |
+
batch_size, seq_len, head_size, dim = tensor.shape
|
| 115 |
+
head_size = self.heads
|
| 116 |
+
tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()
|
| 117 |
+
return tensor
|
| 118 |
+
|
| 119 |
+
def set_attention_slice(self, slice_size):
|
| 120 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
| 121 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
| 122 |
+
|
| 123 |
+
self._slice_size = slice_size
|
| 124 |
+
|
| 125 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
| 126 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 127 |
+
|
| 128 |
+
encoder_hidden_states = encoder_hidden_states
|
| 129 |
+
|
| 130 |
+
if self.group_norm is not None:
|
| 131 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 132 |
+
|
| 133 |
+
query = self.to_q(hidden_states)
|
| 134 |
+
dim = query.shape[-1]
|
| 135 |
+
query = self.reshape_heads_to_batch_dim(query)
|
| 136 |
+
|
| 137 |
+
if self.added_kv_proj_dim is not None:
|
| 138 |
+
key = self.to_k(hidden_states)
|
| 139 |
+
value = self.to_v(hidden_states)
|
| 140 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
| 141 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
| 142 |
+
|
| 143 |
+
key = self.reshape_heads_to_batch_dim(key)
|
| 144 |
+
value = self.reshape_heads_to_batch_dim(value)
|
| 145 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
| 146 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
| 147 |
+
|
| 148 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
| 149 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
| 150 |
+
else:
|
| 151 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
| 152 |
+
key = self.to_k(encoder_hidden_states)
|
| 153 |
+
value = self.to_v(encoder_hidden_states)
|
| 154 |
+
|
| 155 |
+
key = self.reshape_heads_to_batch_dim(key)
|
| 156 |
+
value = self.reshape_heads_to_batch_dim(value)
|
| 157 |
+
|
| 158 |
+
if attention_mask is not None:
|
| 159 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
| 160 |
+
target_length = query.shape[1]
|
| 161 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
| 162 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
| 163 |
+
|
| 164 |
+
# attention, what we cannot get enough of
|
| 165 |
+
if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:
|
| 166 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
| 167 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
| 168 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 169 |
+
else:
|
| 170 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
| 171 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
| 172 |
+
else:
|
| 173 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
| 174 |
+
|
| 175 |
+
# linear proj
|
| 176 |
+
hidden_states = self.to_out[0](hidden_states)
|
| 177 |
+
|
| 178 |
+
# dropout
|
| 179 |
+
hidden_states = self.to_out[1](hidden_states)
|
| 180 |
+
return hidden_states
|
| 181 |
+
|
| 182 |
+
def _attention(self, query, key, value, attention_mask=None):
|
| 183 |
+
if self.upcast_attention:
|
| 184 |
+
query = query.float()
|
| 185 |
+
key = key.float()
|
| 186 |
+
|
| 187 |
+
attention_scores = torch.baddbmm(
|
| 188 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
| 189 |
+
query,
|
| 190 |
+
key.transpose(-1, -2),
|
| 191 |
+
beta=0,
|
| 192 |
+
alpha=self.scale,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if attention_mask is not None:
|
| 196 |
+
attention_scores = attention_scores + attention_mask
|
| 197 |
+
|
| 198 |
+
if self.upcast_softmax:
|
| 199 |
+
attention_scores = attention_scores.float()
|
| 200 |
+
|
| 201 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 202 |
+
|
| 203 |
+
# cast back to the original dtype
|
| 204 |
+
attention_probs = attention_probs.to(value.dtype)
|
| 205 |
+
|
| 206 |
+
# compute attention output
|
| 207 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 208 |
+
|
| 209 |
+
# reshape hidden_states
|
| 210 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
| 211 |
+
return hidden_states
|
| 212 |
+
|
| 213 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
| 214 |
+
batch_size_attention = query.shape[0]
|
| 215 |
+
hidden_states = torch.zeros(
|
| 216 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
| 217 |
+
)
|
| 218 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
| 219 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
| 220 |
+
start_idx = i * slice_size
|
| 221 |
+
end_idx = (i + 1) * slice_size
|
| 222 |
+
|
| 223 |
+
query_slice = query[start_idx:end_idx]
|
| 224 |
+
key_slice = key[start_idx:end_idx]
|
| 225 |
+
|
| 226 |
+
if self.upcast_attention:
|
| 227 |
+
query_slice = query_slice.float()
|
| 228 |
+
key_slice = key_slice.float()
|
| 229 |
+
|
| 230 |
+
attn_slice = torch.baddbmm(
|
| 231 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
| 232 |
+
query_slice,
|
| 233 |
+
key_slice.transpose(-1, -2),
|
| 234 |
+
beta=0,
|
| 235 |
+
alpha=self.scale,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
if attention_mask is not None:
|
| 239 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
| 240 |
+
|
| 241 |
+
if self.upcast_softmax:
|
| 242 |
+
attn_slice = attn_slice.float()
|
| 243 |
+
|
| 244 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
| 245 |
+
|
| 246 |
+
# cast back to the original dtype
|
| 247 |
+
attn_slice = attn_slice.to(value.dtype)
|
| 248 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
| 249 |
+
|
| 250 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
| 251 |
+
|
| 252 |
+
# reshape hidden_states
|
| 253 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
| 254 |
+
return hidden_states
|
| 255 |
+
|
| 256 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
| 257 |
+
if self.upcast_efficient_attention:
|
| 258 |
+
org_dtype = query.dtype
|
| 259 |
+
query = query.float()
|
| 260 |
+
key = key.float()
|
| 261 |
+
value = value.float()
|
| 262 |
+
if attention_mask is not None:
|
| 263 |
+
attention_mask = attention_mask.float()
|
| 264 |
+
hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)
|
| 265 |
+
|
| 266 |
+
if self.upcast_efficient_attention:
|
| 267 |
+
hidden_states = hidden_states.to(org_dtype)
|
| 268 |
+
|
| 269 |
+
hidden_states = self.reshape_4d_to_heads(hidden_states)
|
| 270 |
+
return hidden_states
|
| 271 |
+
|
| 272 |
+
# print("Errror: no xformers")
|
| 273 |
+
# raise NotImplementedError
|
| 274 |
+
|
| 275 |
+
def _memory_efficient_attention_split(self, query, key, value, attention_mask):
|
| 276 |
+
batch_size = query.shape[0]
|
| 277 |
+
max_batch_size = 65535
|
| 278 |
+
num_batches = (batch_size + max_batch_size - 1) // max_batch_size
|
| 279 |
+
results = []
|
| 280 |
+
for i in range(num_batches):
|
| 281 |
+
start_idx = i * max_batch_size
|
| 282 |
+
end_idx = min((i + 1) * max_batch_size, batch_size)
|
| 283 |
+
query_batch = query[start_idx:end_idx]
|
| 284 |
+
key_batch = key[start_idx:end_idx]
|
| 285 |
+
value_batch = value[start_idx:end_idx]
|
| 286 |
+
if attention_mask is not None:
|
| 287 |
+
attention_mask_batch = attention_mask[start_idx:end_idx]
|
| 288 |
+
else:
|
| 289 |
+
attention_mask_batch = None
|
| 290 |
+
result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
|
| 291 |
+
results.append(result)
|
| 292 |
+
full_result = torch.cat(results, dim=0)
|
| 293 |
+
return full_result
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class FeedForward(nn.Module):
|
| 297 |
+
r"""
|
| 298 |
+
A feed-forward layer.
|
| 299 |
+
|
| 300 |
+
Parameters:
|
| 301 |
+
dim (`int`): The number of channels in the input.
|
| 302 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
| 303 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
| 304 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 305 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 306 |
+
"""
|
| 307 |
+
|
| 308 |
+
def __init__(
|
| 309 |
+
self,
|
| 310 |
+
dim: int,
|
| 311 |
+
dim_out: Optional[int] = None,
|
| 312 |
+
mult: int = 4,
|
| 313 |
+
dropout: float = 0.0,
|
| 314 |
+
activation_fn: str = "geglu",
|
| 315 |
+
):
|
| 316 |
+
super().__init__()
|
| 317 |
+
inner_dim = int(dim * mult)
|
| 318 |
+
dim_out = dim_out if dim_out is not None else dim
|
| 319 |
+
|
| 320 |
+
if activation_fn == "gelu":
|
| 321 |
+
act_fn = GELU(dim, inner_dim)
|
| 322 |
+
elif activation_fn == "geglu":
|
| 323 |
+
act_fn = GEGLU(dim, inner_dim)
|
| 324 |
+
elif activation_fn == "geglu-approximate":
|
| 325 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
| 326 |
+
|
| 327 |
+
self.net = nn.ModuleList([])
|
| 328 |
+
# project in
|
| 329 |
+
self.net.append(act_fn)
|
| 330 |
+
# project dropout
|
| 331 |
+
self.net.append(nn.Dropout(dropout))
|
| 332 |
+
# project out
|
| 333 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
| 334 |
+
|
| 335 |
+
def forward(self, hidden_states):
|
| 336 |
+
for module in self.net:
|
| 337 |
+
hidden_states = module(hidden_states)
|
| 338 |
+
return hidden_states
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class GELU(nn.Module):
|
| 342 |
+
r"""
|
| 343 |
+
GELU activation function
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(self, dim_in: int, dim_out: int):
|
| 347 |
+
super().__init__()
|
| 348 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
| 349 |
+
|
| 350 |
+
def gelu(self, gate):
|
| 351 |
+
if gate.device.type != "mps":
|
| 352 |
+
return F.gelu(gate)
|
| 353 |
+
# mps: gelu is not implemented for float16
|
| 354 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
| 355 |
+
|
| 356 |
+
def forward(self, hidden_states):
|
| 357 |
+
hidden_states = self.proj(hidden_states)
|
| 358 |
+
hidden_states = self.gelu(hidden_states)
|
| 359 |
+
return hidden_states
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# feedforward
|
| 363 |
+
class GEGLU(nn.Module):
|
| 364 |
+
r"""
|
| 365 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
| 366 |
+
|
| 367 |
+
Parameters:
|
| 368 |
+
dim_in (`int`): The number of channels in the input.
|
| 369 |
+
dim_out (`int`): The number of channels in the output.
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
def __init__(self, dim_in: int, dim_out: int):
|
| 373 |
+
super().__init__()
|
| 374 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
| 375 |
+
|
| 376 |
+
def gelu(self, gate):
|
| 377 |
+
if gate.device.type != "mps":
|
| 378 |
+
return F.gelu(gate)
|
| 379 |
+
# mps: gelu is not implemented for float16
|
| 380 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
| 381 |
+
|
| 382 |
+
def forward(self, hidden_states):
|
| 383 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
| 384 |
+
return hidden_states * self.gelu(gate)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class ApproximateGELU(nn.Module):
|
| 388 |
+
"""
|
| 389 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
| 390 |
+
|
| 391 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
| 392 |
+
"""
|
| 393 |
+
|
| 394 |
+
def __init__(self, dim_in: int, dim_out: int):
|
| 395 |
+
super().__init__()
|
| 396 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
| 397 |
+
|
| 398 |
+
def forward(self, x):
|
| 399 |
+
x = self.proj(x)
|
| 400 |
+
return x * torch.sigmoid(1.702 * x)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
| 404 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 405 |
+
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
| 406 |
+
freqs = torch.outer(t, freqs)
|
| 407 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 408 |
+
return freqs_cis
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
| 412 |
+
ndim = x.ndim
|
| 413 |
+
assert 0 <= 1 < ndim
|
| 414 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
| 415 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
| 416 |
+
return freqs_cis.view(*shape)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def apply_rotary_emb(
|
| 420 |
+
xq: torch.Tensor,
|
| 421 |
+
xk: torch.Tensor,
|
| 422 |
+
freqs_cis: torch.Tensor,
|
| 423 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 424 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())
|
| 425 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())
|
| 426 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 427 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
|
| 428 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
|
| 429 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
code_depth/video_depth_anything/motion_module/motion_module.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0 license
|
| 3 |
+
#
|
| 4 |
+
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
|
| 5 |
+
# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
|
| 11 |
+
|
| 12 |
+
from einops import rearrange, repeat
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import xformers
|
| 17 |
+
import xformers.ops
|
| 18 |
+
|
| 19 |
+
XFORMERS_AVAILABLE = True
|
| 20 |
+
except ImportError:
|
| 21 |
+
print("xFormers not available")
|
| 22 |
+
XFORMERS_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def zero_module(module):
|
| 26 |
+
# Zero out the parameters of a module and return it.
|
| 27 |
+
for p in module.parameters():
|
| 28 |
+
p.detach().zero_()
|
| 29 |
+
return module
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TemporalModule(nn.Module):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
in_channels,
|
| 36 |
+
num_attention_heads = 8,
|
| 37 |
+
num_transformer_block = 2,
|
| 38 |
+
num_attention_blocks = 2,
|
| 39 |
+
norm_num_groups = 32,
|
| 40 |
+
temporal_max_len = 32,
|
| 41 |
+
zero_initialize = True,
|
| 42 |
+
pos_embedding_type = "ape",
|
| 43 |
+
):
|
| 44 |
+
super().__init__()
|
| 45 |
+
|
| 46 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
| 47 |
+
in_channels=in_channels,
|
| 48 |
+
num_attention_heads=num_attention_heads,
|
| 49 |
+
attention_head_dim=in_channels // num_attention_heads,
|
| 50 |
+
num_layers=num_transformer_block,
|
| 51 |
+
num_attention_blocks=num_attention_blocks,
|
| 52 |
+
norm_num_groups=norm_num_groups,
|
| 53 |
+
temporal_max_len=temporal_max_len,
|
| 54 |
+
pos_embedding_type=pos_embedding_type,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if zero_initialize:
|
| 58 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
| 59 |
+
|
| 60 |
+
def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
|
| 61 |
+
hidden_states = input_tensor
|
| 62 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
| 63 |
+
|
| 64 |
+
output = hidden_states
|
| 65 |
+
return output
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TemporalTransformer3DModel(nn.Module):
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
in_channels,
|
| 72 |
+
num_attention_heads,
|
| 73 |
+
attention_head_dim,
|
| 74 |
+
num_layers,
|
| 75 |
+
num_attention_blocks = 2,
|
| 76 |
+
norm_num_groups = 32,
|
| 77 |
+
temporal_max_len = 32,
|
| 78 |
+
pos_embedding_type = "ape",
|
| 79 |
+
):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 83 |
+
|
| 84 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
| 85 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
| 86 |
+
|
| 87 |
+
self.transformer_blocks = nn.ModuleList(
|
| 88 |
+
[
|
| 89 |
+
TemporalTransformerBlock(
|
| 90 |
+
dim=inner_dim,
|
| 91 |
+
num_attention_heads=num_attention_heads,
|
| 92 |
+
attention_head_dim=attention_head_dim,
|
| 93 |
+
num_attention_blocks=num_attention_blocks,
|
| 94 |
+
temporal_max_len=temporal_max_len,
|
| 95 |
+
pos_embedding_type=pos_embedding_type,
|
| 96 |
+
)
|
| 97 |
+
for d in range(num_layers)
|
| 98 |
+
]
|
| 99 |
+
)
|
| 100 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
| 101 |
+
|
| 102 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
| 103 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
| 104 |
+
video_length = hidden_states.shape[2]
|
| 105 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
| 106 |
+
|
| 107 |
+
batch, channel, height, width = hidden_states.shape
|
| 108 |
+
residual = hidden_states
|
| 109 |
+
|
| 110 |
+
hidden_states = self.norm(hidden_states)
|
| 111 |
+
inner_dim = hidden_states.shape[1]
|
| 112 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()
|
| 113 |
+
hidden_states = self.proj_in(hidden_states)
|
| 114 |
+
|
| 115 |
+
# Transformer Blocks
|
| 116 |
+
for block in self.transformer_blocks:
|
| 117 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
|
| 118 |
+
|
| 119 |
+
# output
|
| 120 |
+
hidden_states = self.proj_out(hidden_states)
|
| 121 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
| 122 |
+
|
| 123 |
+
output = hidden_states + residual
|
| 124 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
| 125 |
+
|
| 126 |
+
return output
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class TemporalTransformerBlock(nn.Module):
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
dim,
|
| 133 |
+
num_attention_heads,
|
| 134 |
+
attention_head_dim,
|
| 135 |
+
num_attention_blocks = 2,
|
| 136 |
+
temporal_max_len = 32,
|
| 137 |
+
pos_embedding_type = "ape",
|
| 138 |
+
):
|
| 139 |
+
super().__init__()
|
| 140 |
+
|
| 141 |
+
self.attention_blocks = nn.ModuleList(
|
| 142 |
+
[
|
| 143 |
+
TemporalAttention(
|
| 144 |
+
query_dim=dim,
|
| 145 |
+
heads=num_attention_heads,
|
| 146 |
+
dim_head=attention_head_dim,
|
| 147 |
+
temporal_max_len=temporal_max_len,
|
| 148 |
+
pos_embedding_type=pos_embedding_type,
|
| 149 |
+
)
|
| 150 |
+
for i in range(num_attention_blocks)
|
| 151 |
+
]
|
| 152 |
+
)
|
| 153 |
+
self.norms = nn.ModuleList(
|
| 154 |
+
[
|
| 155 |
+
nn.LayerNorm(dim)
|
| 156 |
+
for i in range(num_attention_blocks)
|
| 157 |
+
]
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
|
| 161 |
+
self.ff_norm = nn.LayerNorm(dim)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
| 165 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
| 166 |
+
norm_hidden_states = norm(hidden_states)
|
| 167 |
+
hidden_states = attention_block(
|
| 168 |
+
norm_hidden_states,
|
| 169 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 170 |
+
video_length=video_length,
|
| 171 |
+
attention_mask=attention_mask,
|
| 172 |
+
) + hidden_states
|
| 173 |
+
|
| 174 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
| 175 |
+
|
| 176 |
+
output = hidden_states
|
| 177 |
+
return output
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class PositionalEncoding(nn.Module):
|
| 181 |
+
def __init__(
|
| 182 |
+
self,
|
| 183 |
+
d_model,
|
| 184 |
+
dropout = 0.,
|
| 185 |
+
max_len = 32
|
| 186 |
+
):
|
| 187 |
+
super().__init__()
|
| 188 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 189 |
+
position = torch.arange(max_len).unsqueeze(1)
|
| 190 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
| 191 |
+
pe = torch.zeros(1, max_len, d_model)
|
| 192 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
| 193 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
| 194 |
+
self.register_buffer('pe', pe)
|
| 195 |
+
|
| 196 |
+
def forward(self, x):
|
| 197 |
+
x = x + self.pe[:, :x.size(1)].to(x.dtype)
|
| 198 |
+
return self.dropout(x)
|
| 199 |
+
|
| 200 |
+
class TemporalAttention(CrossAttention):
|
| 201 |
+
def __init__(
|
| 202 |
+
self,
|
| 203 |
+
temporal_max_len = 32,
|
| 204 |
+
pos_embedding_type = "ape",
|
| 205 |
+
*args, **kwargs
|
| 206 |
+
):
|
| 207 |
+
super().__init__(*args, **kwargs)
|
| 208 |
+
|
| 209 |
+
self.pos_embedding_type = pos_embedding_type
|
| 210 |
+
self._use_memory_efficient_attention_xformers = True
|
| 211 |
+
|
| 212 |
+
self.pos_encoder = None
|
| 213 |
+
self.freqs_cis = None
|
| 214 |
+
if self.pos_embedding_type == "ape":
|
| 215 |
+
self.pos_encoder = PositionalEncoding(
|
| 216 |
+
kwargs["query_dim"],
|
| 217 |
+
dropout=0.,
|
| 218 |
+
max_len=temporal_max_len
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
elif self.pos_embedding_type == "rope":
|
| 222 |
+
self.freqs_cis = precompute_freqs_cis(
|
| 223 |
+
kwargs["query_dim"],
|
| 224 |
+
temporal_max_len
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
else:
|
| 228 |
+
raise NotImplementedError
|
| 229 |
+
|
| 230 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
| 231 |
+
d = hidden_states.shape[1]
|
| 232 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
| 233 |
+
|
| 234 |
+
if self.pos_encoder is not None:
|
| 235 |
+
hidden_states = self.pos_encoder(hidden_states)
|
| 236 |
+
|
| 237 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
| 238 |
+
|
| 239 |
+
if self.group_norm is not None:
|
| 240 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 241 |
+
|
| 242 |
+
query = self.to_q(hidden_states)
|
| 243 |
+
dim = query.shape[-1]
|
| 244 |
+
|
| 245 |
+
if self.added_kv_proj_dim is not None:
|
| 246 |
+
raise NotImplementedError
|
| 247 |
+
|
| 248 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
| 249 |
+
key = self.to_k(encoder_hidden_states)
|
| 250 |
+
value = self.to_v(encoder_hidden_states)
|
| 251 |
+
|
| 252 |
+
if self.freqs_cis is not None:
|
| 253 |
+
seq_len = query.shape[1]
|
| 254 |
+
freqs_cis = self.freqs_cis[:seq_len].to(query.device)
|
| 255 |
+
query, key = apply_rotary_emb(query, key, freqs_cis)
|
| 256 |
+
|
| 257 |
+
if attention_mask is not None:
|
| 258 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
| 259 |
+
target_length = query.shape[1]
|
| 260 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
| 261 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers
|
| 265 |
+
if use_memory_efficient and (dim // self.heads) % 8 != 0:
|
| 266 |
+
# print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
|
| 267 |
+
use_memory_efficient = False
|
| 268 |
+
|
| 269 |
+
# attention, what we cannot get enough of
|
| 270 |
+
if use_memory_efficient:
|
| 271 |
+
query = self.reshape_heads_to_4d(query)
|
| 272 |
+
key = self.reshape_heads_to_4d(key)
|
| 273 |
+
value = self.reshape_heads_to_4d(value)
|
| 274 |
+
|
| 275 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
| 276 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
| 277 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 278 |
+
else:
|
| 279 |
+
query = self.reshape_heads_to_batch_dim(query)
|
| 280 |
+
key = self.reshape_heads_to_batch_dim(key)
|
| 281 |
+
value = self.reshape_heads_to_batch_dim(value)
|
| 282 |
+
|
| 283 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
| 284 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
| 285 |
+
else:
|
| 286 |
+
raise NotImplementedError
|
| 287 |
+
# hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
| 288 |
+
|
| 289 |
+
# linear proj
|
| 290 |
+
hidden_states = self.to_out[0](hidden_states)
|
| 291 |
+
|
| 292 |
+
# dropout
|
| 293 |
+
hidden_states = self.to_out[1](hidden_states)
|
| 294 |
+
|
| 295 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
| 296 |
+
|
| 297 |
+
return hidden_states
|
code_depth/video_depth_anything/util/blocks.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
| 5 |
+
scratch = nn.Module()
|
| 6 |
+
|
| 7 |
+
out_shape1 = out_shape
|
| 8 |
+
out_shape2 = out_shape
|
| 9 |
+
out_shape3 = out_shape
|
| 10 |
+
if len(in_shape) >= 4:
|
| 11 |
+
out_shape4 = out_shape
|
| 12 |
+
|
| 13 |
+
if expand:
|
| 14 |
+
out_shape1 = out_shape
|
| 15 |
+
out_shape2 = out_shape * 2
|
| 16 |
+
out_shape3 = out_shape * 4
|
| 17 |
+
if len(in_shape) >= 4:
|
| 18 |
+
out_shape4 = out_shape * 8
|
| 19 |
+
|
| 20 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 21 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 22 |
+
)
|
| 23 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 24 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 25 |
+
)
|
| 26 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 27 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 28 |
+
)
|
| 29 |
+
if len(in_shape) >= 4:
|
| 30 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 31 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
return scratch
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ResidualConvUnit(nn.Module):
|
| 38 |
+
"""Residual convolution module."""
|
| 39 |
+
|
| 40 |
+
def __init__(self, features, activation, bn):
|
| 41 |
+
"""Init.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
features (int): number of features
|
| 45 |
+
"""
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
self.bn = bn
|
| 49 |
+
|
| 50 |
+
self.groups = 1
|
| 51 |
+
|
| 52 |
+
self.conv1 = nn.Conv2d(
|
| 53 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.conv2 = nn.Conv2d(
|
| 57 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
if self.bn is True:
|
| 61 |
+
self.bn1 = nn.BatchNorm2d(features)
|
| 62 |
+
self.bn2 = nn.BatchNorm2d(features)
|
| 63 |
+
|
| 64 |
+
self.activation = activation
|
| 65 |
+
|
| 66 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
"""Forward pass.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
x (tensor): input
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
tensor: output
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
out = self.activation(x)
|
| 79 |
+
out = self.conv1(out)
|
| 80 |
+
if self.bn is True:
|
| 81 |
+
out = self.bn1(out)
|
| 82 |
+
|
| 83 |
+
out = self.activation(out)
|
| 84 |
+
out = self.conv2(out)
|
| 85 |
+
if self.bn is True:
|
| 86 |
+
out = self.bn2(out)
|
| 87 |
+
|
| 88 |
+
if self.groups > 1:
|
| 89 |
+
out = self.conv_merge(out)
|
| 90 |
+
|
| 91 |
+
return self.skip_add.add(out, x)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class FeatureFusionBlock(nn.Module):
|
| 95 |
+
"""Feature fusion block."""
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
features,
|
| 100 |
+
activation,
|
| 101 |
+
deconv=False,
|
| 102 |
+
bn=False,
|
| 103 |
+
expand=False,
|
| 104 |
+
align_corners=True,
|
| 105 |
+
size=None,
|
| 106 |
+
):
|
| 107 |
+
"""Init.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
features (int): number of features
|
| 111 |
+
"""
|
| 112 |
+
super().__init__()
|
| 113 |
+
|
| 114 |
+
self.deconv = deconv
|
| 115 |
+
self.align_corners = align_corners
|
| 116 |
+
|
| 117 |
+
self.groups = 1
|
| 118 |
+
|
| 119 |
+
self.expand = expand
|
| 120 |
+
out_features = features
|
| 121 |
+
if self.expand is True:
|
| 122 |
+
out_features = features // 2
|
| 123 |
+
|
| 124 |
+
self.out_conv = nn.Conv2d(
|
| 125 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
| 129 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
| 130 |
+
|
| 131 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 132 |
+
|
| 133 |
+
self.size = size
|
| 134 |
+
|
| 135 |
+
def forward(self, *xs, size=None):
|
| 136 |
+
"""Forward pass.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
tensor: output
|
| 140 |
+
"""
|
| 141 |
+
output = xs[0]
|
| 142 |
+
|
| 143 |
+
if len(xs) == 2:
|
| 144 |
+
res = self.resConfUnit1(xs[1])
|
| 145 |
+
output = self.skip_add.add(output, res)
|
| 146 |
+
|
| 147 |
+
output = self.resConfUnit2(output)
|
| 148 |
+
|
| 149 |
+
if (size is None) and (self.size is None):
|
| 150 |
+
modifier = {"scale_factor": 2}
|
| 151 |
+
elif size is None:
|
| 152 |
+
modifier = {"size": self.size}
|
| 153 |
+
else:
|
| 154 |
+
modifier = {"size": size}
|
| 155 |
+
|
| 156 |
+
output = nn.functional.interpolate(
|
| 157 |
+
output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
output = self.out_conv(output)
|
| 161 |
+
|
| 162 |
+
return output
|
code_depth/video_depth_anything/util/transform.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Resize(object):
|
| 6 |
+
"""Resize sample to given size (width, height).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
width,
|
| 12 |
+
height,
|
| 13 |
+
resize_target=True,
|
| 14 |
+
keep_aspect_ratio=False,
|
| 15 |
+
ensure_multiple_of=1,
|
| 16 |
+
resize_method="lower_bound",
|
| 17 |
+
image_interpolation_method=cv2.INTER_AREA,
|
| 18 |
+
):
|
| 19 |
+
"""Init.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
width (int): desired output width
|
| 23 |
+
height (int): desired output height
|
| 24 |
+
resize_target (bool, optional):
|
| 25 |
+
True: Resize the full sample (image, mask, target).
|
| 26 |
+
False: Resize image only.
|
| 27 |
+
Defaults to True.
|
| 28 |
+
keep_aspect_ratio (bool, optional):
|
| 29 |
+
True: Keep the aspect ratio of the input sample.
|
| 30 |
+
Output sample might not have the given width and height, and
|
| 31 |
+
resize behaviour depends on the parameter 'resize_method'.
|
| 32 |
+
Defaults to False.
|
| 33 |
+
ensure_multiple_of (int, optional):
|
| 34 |
+
Output width and height is constrained to be multiple of this parameter.
|
| 35 |
+
Defaults to 1.
|
| 36 |
+
resize_method (str, optional):
|
| 37 |
+
"lower_bound": Output will be at least as large as the given size.
|
| 38 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
| 39 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
| 40 |
+
Defaults to "lower_bound".
|
| 41 |
+
"""
|
| 42 |
+
self.__width = width
|
| 43 |
+
self.__height = height
|
| 44 |
+
|
| 45 |
+
self.__resize_target = resize_target
|
| 46 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
| 47 |
+
self.__multiple_of = ensure_multiple_of
|
| 48 |
+
self.__resize_method = resize_method
|
| 49 |
+
self.__image_interpolation_method = image_interpolation_method
|
| 50 |
+
|
| 51 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
| 52 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
| 53 |
+
|
| 54 |
+
if max_val is not None and y > max_val:
|
| 55 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
| 56 |
+
|
| 57 |
+
if y < min_val:
|
| 58 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
| 59 |
+
|
| 60 |
+
return y
|
| 61 |
+
|
| 62 |
+
def get_size(self, width, height):
|
| 63 |
+
# determine new height and width
|
| 64 |
+
scale_height = self.__height / height
|
| 65 |
+
scale_width = self.__width / width
|
| 66 |
+
|
| 67 |
+
if self.__keep_aspect_ratio:
|
| 68 |
+
if self.__resize_method == "lower_bound":
|
| 69 |
+
# scale such that output size is lower bound
|
| 70 |
+
if scale_width > scale_height:
|
| 71 |
+
# fit width
|
| 72 |
+
scale_height = scale_width
|
| 73 |
+
else:
|
| 74 |
+
# fit height
|
| 75 |
+
scale_width = scale_height
|
| 76 |
+
elif self.__resize_method == "upper_bound":
|
| 77 |
+
# scale such that output size is upper bound
|
| 78 |
+
if scale_width < scale_height:
|
| 79 |
+
# fit width
|
| 80 |
+
scale_height = scale_width
|
| 81 |
+
else:
|
| 82 |
+
# fit height
|
| 83 |
+
scale_width = scale_height
|
| 84 |
+
elif self.__resize_method == "minimal":
|
| 85 |
+
# scale as least as possbile
|
| 86 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
| 87 |
+
# fit width
|
| 88 |
+
scale_height = scale_width
|
| 89 |
+
else:
|
| 90 |
+
# fit height
|
| 91 |
+
scale_width = scale_height
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
| 94 |
+
|
| 95 |
+
if self.__resize_method == "lower_bound":
|
| 96 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
| 97 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
| 98 |
+
elif self.__resize_method == "upper_bound":
|
| 99 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
| 100 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
| 101 |
+
elif self.__resize_method == "minimal":
|
| 102 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
| 103 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
| 104 |
+
else:
|
| 105 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
| 106 |
+
|
| 107 |
+
return (new_width, new_height)
|
| 108 |
+
|
| 109 |
+
def __call__(self, sample):
|
| 110 |
+
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
| 111 |
+
|
| 112 |
+
# resize sample
|
| 113 |
+
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
|
| 114 |
+
|
| 115 |
+
if self.__resize_target:
|
| 116 |
+
if "depth" in sample:
|
| 117 |
+
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
| 118 |
+
|
| 119 |
+
if "mask" in sample:
|
| 120 |
+
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
|
| 121 |
+
|
| 122 |
+
return sample
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class NormalizeImage(object):
|
| 126 |
+
"""Normlize image by given mean and std.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, mean, std):
|
| 130 |
+
self.__mean = mean
|
| 131 |
+
self.__std = std
|
| 132 |
+
|
| 133 |
+
def __call__(self, sample):
|
| 134 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
| 135 |
+
|
| 136 |
+
return sample
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class PrepareForNet(object):
|
| 140 |
+
"""Prepare sample for usage as network input.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
def __init__(self):
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
def __call__(self, sample):
|
| 147 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
| 148 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
| 149 |
+
|
| 150 |
+
if "depth" in sample:
|
| 151 |
+
depth = sample["depth"].astype(np.float32)
|
| 152 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
| 153 |
+
|
| 154 |
+
if "mask" in sample:
|
| 155 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
| 156 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
| 157 |
+
|
| 158 |
+
return sample
|
code_depth/video_depth_anything/video_depth.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torchvision.transforms import Compose
|
| 18 |
+
import cv2
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import numpy as np
|
| 21 |
+
import gc
|
| 22 |
+
|
| 23 |
+
from .dinov2 import DINOv2
|
| 24 |
+
from .dpt_temporal import DPTHeadTemporal
|
| 25 |
+
from .util.transform import Resize, NormalizeImage, PrepareForNet
|
| 26 |
+
|
| 27 |
+
from utils.util import compute_scale_and_shift, get_interpolate_frames
|
| 28 |
+
|
| 29 |
+
# infer settings, do not change
|
| 30 |
+
INFER_LEN = 32
|
| 31 |
+
OVERLAP = 10
|
| 32 |
+
KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
|
| 33 |
+
INTERP_LEN = 8
|
| 34 |
+
|
| 35 |
+
class VideoDepthAnything(nn.Module):
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
encoder='vitl',
|
| 39 |
+
features=256,
|
| 40 |
+
out_channels=[256, 512, 1024, 1024],
|
| 41 |
+
use_bn=False,
|
| 42 |
+
use_clstoken=False,
|
| 43 |
+
num_frames=32,
|
| 44 |
+
pe='ape'
|
| 45 |
+
):
|
| 46 |
+
super(VideoDepthAnything, self).__init__()
|
| 47 |
+
|
| 48 |
+
self.intermediate_layer_idx = {
|
| 49 |
+
'vits': [2, 5, 8, 11],
|
| 50 |
+
'vitl': [4, 11, 17, 23]
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
self.encoder = encoder
|
| 54 |
+
self.pretrained = DINOv2(model_name=encoder)
|
| 55 |
+
|
| 56 |
+
self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
|
| 57 |
+
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
B, T, C, H, W = x.shape
|
| 60 |
+
patch_h, patch_w = H // 14, W // 14
|
| 61 |
+
features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
|
| 62 |
+
depth = self.head(features, patch_h, patch_w, T)
|
| 63 |
+
depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
|
| 64 |
+
depth = F.relu(depth)
|
| 65 |
+
return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]
|
| 66 |
+
|
| 67 |
+
def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', fp32=False):
|
| 68 |
+
frame_height, frame_width = frames[0].shape[:2]
|
| 69 |
+
ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
|
| 70 |
+
if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation
|
| 71 |
+
input_size = int(input_size * 1.777 / ratio)
|
| 72 |
+
input_size = round(input_size / 14) * 14
|
| 73 |
+
|
| 74 |
+
transform = Compose([
|
| 75 |
+
Resize(
|
| 76 |
+
width=input_size,
|
| 77 |
+
height=input_size,
|
| 78 |
+
resize_target=False,
|
| 79 |
+
keep_aspect_ratio=True,
|
| 80 |
+
ensure_multiple_of=14,
|
| 81 |
+
resize_method='lower_bound',
|
| 82 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
| 83 |
+
),
|
| 84 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 85 |
+
PrepareForNet(),
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
frame_list = [frames[i] for i in range(frames.shape[0])]
|
| 89 |
+
frame_step = INFER_LEN - OVERLAP
|
| 90 |
+
org_video_len = len(frame_list)
|
| 91 |
+
append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
|
| 92 |
+
frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
|
| 93 |
+
|
| 94 |
+
depth_list = []
|
| 95 |
+
pre_input = None
|
| 96 |
+
for frame_id in tqdm(range(0, org_video_len, frame_step)):
|
| 97 |
+
cur_list = []
|
| 98 |
+
for i in range(INFER_LEN):
|
| 99 |
+
cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
|
| 100 |
+
cur_input = torch.cat(cur_list, dim=1).to(device)
|
| 101 |
+
if pre_input is not None:
|
| 102 |
+
cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]
|
| 103 |
+
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
with torch.autocast(device_type=device, enabled=(not fp32)):
|
| 106 |
+
depth = self.forward(cur_input) # depth shape: [1, T, H, W]
|
| 107 |
+
|
| 108 |
+
depth = depth.to(cur_input.dtype)
|
| 109 |
+
depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
|
| 110 |
+
depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]
|
| 111 |
+
|
| 112 |
+
pre_input = cur_input
|
| 113 |
+
|
| 114 |
+
del frame_list
|
| 115 |
+
gc.collect()
|
| 116 |
+
|
| 117 |
+
depth_list_aligned = []
|
| 118 |
+
ref_align = []
|
| 119 |
+
align_len = OVERLAP - INTERP_LEN
|
| 120 |
+
kf_align_list = KEYFRAMES[:align_len]
|
| 121 |
+
|
| 122 |
+
for frame_id in range(0, len(depth_list), INFER_LEN):
|
| 123 |
+
if len(depth_list_aligned) == 0:
|
| 124 |
+
depth_list_aligned += depth_list[:INFER_LEN]
|
| 125 |
+
for kf_id in kf_align_list:
|
| 126 |
+
ref_align.append(depth_list[frame_id+kf_id])
|
| 127 |
+
else:
|
| 128 |
+
curr_align = []
|
| 129 |
+
for i in range(len(kf_align_list)):
|
| 130 |
+
curr_align.append(depth_list[frame_id+i])
|
| 131 |
+
scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
|
| 132 |
+
np.concatenate(ref_align),
|
| 133 |
+
np.concatenate(np.ones_like(ref_align)==1))
|
| 134 |
+
|
| 135 |
+
pre_depth_list = depth_list_aligned[-INTERP_LEN:]
|
| 136 |
+
post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
|
| 137 |
+
for i in range(len(post_depth_list)):
|
| 138 |
+
post_depth_list[i] = post_depth_list[i] * scale + shift
|
| 139 |
+
post_depth_list[i][post_depth_list[i]<0] = 0
|
| 140 |
+
depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
|
| 141 |
+
|
| 142 |
+
for i in range(OVERLAP, INFER_LEN):
|
| 143 |
+
new_depth = depth_list[frame_id+i] * scale + shift
|
| 144 |
+
new_depth[new_depth<0] = 0
|
| 145 |
+
depth_list_aligned.append(new_depth)
|
| 146 |
+
|
| 147 |
+
ref_align = ref_align[:1]
|
| 148 |
+
for kf_id in kf_align_list[1:]:
|
| 149 |
+
new_depth = depth_list[frame_id+kf_id] * scale + shift
|
| 150 |
+
new_depth[new_depth<0] = 0
|
| 151 |
+
ref_align.append(new_depth)
|
| 152 |
+
|
| 153 |
+
depth_list = depth_list_aligned
|
| 154 |
+
|
| 155 |
+
return np.stack(depth_list[:org_video_len], axis=0), target_fps
|
| 156 |
+
|
code_edit/.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
code_edit/Flux_fill_d2i.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import FluxFillPipeline_token12_depth as FluxFillPipeline
|
| 3 |
+
from diffusers.utils import load_image
|
| 4 |
+
import os, glob
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
image_path = ["example_data/I-210618_I01001_W01_I-210618_I01001_W01_F0153_img.jpg"]
|
| 10 |
+
|
| 11 |
+
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
|
| 12 |
+
pipe.load_lora_weights("stage2/checkpoint-20000")
|
| 13 |
+
for image_ep in image_path:
|
| 14 |
+
image = Image.open(image_ep)
|
| 15 |
+
mask = Image.new("L", image.size, 0) # place_hold
|
| 16 |
+
depth_path = image_ep.replace("_img.jpg", "_depth_img.png")
|
| 17 |
+
depth_image = Image.open(depth_path)
|
| 18 |
+
depth = Image.open(depth_path.replace("_img", "_img_fill_in"))
|
| 19 |
+
image_name = os.path.basename(image_ep)
|
| 20 |
+
|
| 21 |
+
orig_w, orig_h = image.size
|
| 22 |
+
w, h = image.size
|
| 23 |
+
MAX_SIZE = 1024
|
| 24 |
+
if max(w, h) > MAX_SIZE:
|
| 25 |
+
factor = MAX_SIZE / max(w, h)
|
| 26 |
+
w = int(factor * w)
|
| 27 |
+
h = int(factor * h)
|
| 28 |
+
width, height = map(lambda x: x - x % 64, (w, h))
|
| 29 |
+
# # Resize to 1024 × 1024
|
| 30 |
+
target_size = (width, height)
|
| 31 |
+
# target_size = (1024, 1024)
|
| 32 |
+
# image_resized = image.resize(target_size, Image.BICUBIC)
|
| 33 |
+
# mask_resized = mask.resize(target_size, Image.NEAREST)
|
| 34 |
+
# depth_resized = depth.resize(target_size, Image.BICUBIC)
|
| 35 |
+
# depth_image_resized = depth_image.resize(target_size, Image.BICUBIC)
|
| 36 |
+
|
| 37 |
+
image = pipe(
|
| 38 |
+
prompt="A beautiful scene",
|
| 39 |
+
image=image,
|
| 40 |
+
mask_image=mask,
|
| 41 |
+
width=target_size[0],
|
| 42 |
+
height=target_size[1],
|
| 43 |
+
guidance_scale=30,
|
| 44 |
+
num_inference_steps=50,
|
| 45 |
+
max_sequence_length=512,
|
| 46 |
+
generator=torch.Generator("cpu").manual_seed(0),
|
| 47 |
+
depth=depth,
|
| 48 |
+
depth_image=depth_image,
|
| 49 |
+
).images[0]
|
| 50 |
+
image_final = image.resize((orig_w * 3, orig_h), Image.BICUBIC)
|
| 51 |
+
output_dir = "./test_images/"
|
| 52 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 53 |
+
image_final.save(os.path.join(output_dir,image_name))
|
code_edit/Flux_fill_infer_depth.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import FluxFillPipeline_token12_depth_only as FluxFillPipeline
|
| 3 |
+
from diffusers.utils import load_image
|
| 4 |
+
import os, glob
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from PIL import Image, ImageOps
|
| 8 |
+
|
| 9 |
+
image_path = ["example_data/I-210618_I01001_W01_I-210618_I01001_W01_F0153_img.jpg"]
|
| 10 |
+
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
|
| 11 |
+
pipe.load_lora_weights("stage1/checkpoint-4800")
|
| 12 |
+
for image_ep in image_path:
|
| 13 |
+
mask_path = image_ep.replace("_img.jpg","_mask.png")
|
| 14 |
+
image = Image.open(image_ep) # place_hold
|
| 15 |
+
depth = Image.open(image_ep.replace("_img.jpg",
|
| 16 |
+
"_depth_img.png"))
|
| 17 |
+
image_name = os.path.basename(image_ep)
|
| 18 |
+
mask = Image.open(mask_path).convert("L")
|
| 19 |
+
mask = ImageOps.invert(mask) # inverse rord_mask
|
| 20 |
+
|
| 21 |
+
# mask_np = np.array(mask)
|
| 22 |
+
|
| 23 |
+
# # mask dilation
|
| 24 |
+
# dilation_px = 32
|
| 25 |
+
# kernel = np.ones((3, 3), np.uint8)
|
| 26 |
+
# iterations = dilation_px // 2
|
| 27 |
+
# dilated_mask = cv2.dilate(mask_np, kernel, iterations=iterations)
|
| 28 |
+
# mask = Image.fromarray(dilated_mask)
|
| 29 |
+
|
| 30 |
+
orig_w, orig_h = image.size
|
| 31 |
+
|
| 32 |
+
# Resize to 1024 × 1024
|
| 33 |
+
# target_size = (1024, 1024)
|
| 34 |
+
# image_resized = image.resize(target_size, Image.BICUBIC)
|
| 35 |
+
# mask_resized = mask.resize(target_size, Image.NEAREST)
|
| 36 |
+
# depth_resized = depth.resize(target_size, Image.BICUBIC)
|
| 37 |
+
|
| 38 |
+
w, h = image.size
|
| 39 |
+
MAX_SIZE = 1024
|
| 40 |
+
if max(w, h) > MAX_SIZE:
|
| 41 |
+
factor = MAX_SIZE / max(w, h)
|
| 42 |
+
w = int(factor * w)
|
| 43 |
+
h = int(factor * h)
|
| 44 |
+
width, height = map(lambda x: x - x % 64, (w, h))
|
| 45 |
+
image_out = pipe(
|
| 46 |
+
prompt="A beautiful scene",
|
| 47 |
+
image=image,
|
| 48 |
+
mask_image=mask,
|
| 49 |
+
width=width,
|
| 50 |
+
height=height,
|
| 51 |
+
guidance_scale=30,
|
| 52 |
+
num_inference_steps=50,
|
| 53 |
+
max_sequence_length=512,
|
| 54 |
+
generator=torch.Generator("cpu").manual_seed(0),
|
| 55 |
+
depth=depth
|
| 56 |
+
).images[0]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
image_final = image_out.resize((orig_w, orig_h), Image.BICUBIC)
|
| 60 |
+
|
| 61 |
+
output_dir = "./depth_fillin_results"
|
| 62 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 63 |
+
image_final.save(os.path.join(output_dir, image_name))
|
| 64 |
+
|
code_edit/README.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The official implementation of the **NeurIPS 2025** paper:
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<h1>
|
| 5 |
+
<b>
|
| 6 |
+
GeoRemover: Removing Objects and Their Causal Visual Artifacts, NeurIPS, 2025 (Spotlight)
|
| 7 |
+
</b>
|
| 8 |
+
</h1>
|
| 9 |
+
</div>
|
| 10 |
+
|
| 11 |
+
<p align="center"><img src="docs/teaser.png" width="800"/></p>
|
| 12 |
+
|
| 13 |
+
> [**GeoRemover: Removing Objects and Their Causal Visual Artifacts**](https://arxiv.org/abs/2509.18538)
|
| 14 |
+
>
|
| 15 |
+
> Zixin Zhu, Haoxiang Li, Xuelu Feng, He Wu, Chunming Qiao, Junsong Yuan
|
| 16 |
+
|
| 17 |
+
> **Abstract:** *Towards intelligent image editing, object removal should eliminate both the target object and its causal visual artifacts, such as shadows and reflections. However, existing image appearance-based methods either follow strictly mask-aligned training and fail to remove these casual effects which are not explicitly masked, or adopt loosely mask-aligned strategies that lack controllability and may unintentionally over-erase other objects. We identify that these limitations stem from ignoring the causal relationship between an object’s geometry presence and its visual effects. To address this limitation, we propose a geometry-aware two-stage framework that decouples object removal into (1) geometry removal and (2) appearance rendering. In the first stage, we remove the object directly from the geometry (e.g., depth) using strictly mask-aligned supervision, enabling structure-aware editing with strong geometric constraints. In the second stage, we render a photorealistic RGB image conditioned on the updated geometry, where causal visual effects are considered implicitly as a result of the modified 3D geometry. To guide learning in the geometry removal stage, we introduce a preference-driven objective based on positive and negative sample pairs, encouraging the model to remove objects as well as their causal visual artifacts while avoiding new structural insertions. Extensive experiments demonstrate that our method achieves state-of-the-art performance in removing both objects and their associated artifacts on two popular benchmarks.*
|
| 18 |
+
|
| 19 |
+
### Installing the dependencies
|
| 20 |
+
|
| 21 |
+
Before running the scripts, make sure to install the library's training dependencies:
|
| 22 |
+
|
| 23 |
+
**Important**
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
bash env.sh
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
accelerate config
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Or for a default accelerate configuration without answering questions about your environment
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
accelerate config default
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Data prepare
|
| 42 |
+
Download the images on [RORD](https://github.com/Forty-lock/RORD) and generate depth maps with [Video-Depth-Anythingv2](https://github.com/DepthAnything/Video-Depth-Anything). (The code for VideoDepthAnything v2 can be found in the same repository, on the `depth` branch, using the [script](https://github.com/buxiangzhiren/GeoRemover/blob/depth/run_images_rord.py))
|
| 43 |
+
|
| 44 |
+
### Training
|
| 45 |
+
You should build your own *train_images_and_rord_masks.csv* first. The file in the repo is not the full RORD—it's just an example.
|
| 46 |
+
|
| 47 |
+
For stage1:geometry removal
|
| 48 |
+
```bash
|
| 49 |
+
bash train_stage1.sh
|
| 50 |
+
```
|
| 51 |
+
For stage2:appearance rendering
|
| 52 |
+
```bash
|
| 53 |
+
bash train_stage2.sh
|
| 54 |
+
```
|
| 55 |
+
### Inference
|
| 56 |
+
First, use https://github.com/buxiangzhiren/GeoRemover/blob/depth/run_single_image.py to get the depth of a image
|
| 57 |
+
|
| 58 |
+
For stage1:geometry removal
|
| 59 |
+
```bash
|
| 60 |
+
python Flux_fill_infer_depth.py
|
| 61 |
+
```
|
| 62 |
+
For stage2:appearance rendering
|
| 63 |
+
```bash
|
| 64 |
+
python Flux_fill_d2i.py
|
| 65 |
+
```
|
| 66 |
+
### Checkpoints
|
| 67 |
+
Hugging Face:
|
| 68 |
+
[stage1:geometry removal and stage2:appearance rendering](https://huggingface.co/buxiangzhiren/GeoRemover)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Google drive:
|
| 72 |
+
[stage1:geometry removal](https://drive.google.com/file/d/1y6vnxqnFTiO6sxoKDBkvFbAeniHFka89/view?usp=sharing)
|
| 73 |
+
and [stage2:appearance rendering](https://drive.google.com/file/d/1U8rp1hqOswQB-0T0fh2aDQu-o1GLfd6E/view?usp=sharing)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
### Acknowledgement
|
| 77 |
+
|
| 78 |
+
This repo is based on [RORD](https://github.com/Forty-lock/RORD), [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) and [Video-Depth-Anythingv2](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for their wonderful works.
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
### Citation
|
| 82 |
+
|
| 83 |
+
```
|
| 84 |
+
@misc{zhu2025georemoverremovingobjectscausal,
|
| 85 |
+
title={GeoRemover: Removing Objects and Their Causal Visual Artifacts},
|
| 86 |
+
author={Zixin Zhu and Haoxiang Li and Xuelu Feng and He Wu and Chunming Qiao and Junsong Yuan},
|
| 87 |
+
year={2025},
|
| 88 |
+
eprint={2509.18538},
|
| 89 |
+
archivePrefix={arXiv},
|
| 90 |
+
primaryClass={cs.CV},
|
| 91 |
+
url={https://arxiv.org/abs/2509.18538},
|
| 92 |
+
}
|
| 93 |
+
```
|