toshas commited on
Commit
c732904
·
1 Parent(s): e8416d0

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .idea
2
+ .DS_Store
3
+ __pycache__
README.md CHANGED
@@ -1,13 +1,29 @@
1
  ---
2
- title: Marigold Lcm
3
- emoji: 🚀
4
- colorFrom: indigo
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.22.0
8
  app_file: app.py
9
- pinned: false
10
- license: apache-2.0
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Marigold-LCM Depth Estimation
3
+ emoji: 🏵️
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.22.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: cc-by-sa-4.0
11
+ models:
12
+ - prs-eth/marigold-v1-0
13
+ - prs-eth/marigold-lcm-v1-0
14
  ---
15
 
16
+ This is a demo of Marigold-LCM, the state-of-the-art depth estimator for images in the wild.
17
+ It combines the power of the original Marigold 10-step estimator and the Latent Consistency Models, delivering high-quality results in as little as one step.
18
+ Find out more in our paper titled ["Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation"](https://arxiv.org/abs/2312.02145)
19
+
20
+ ```
21
+ @misc{ke2023repurposing,
22
+ title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
23
+ author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
24
+ year={2023},
25
+ eprint={2312.02145},
26
+ archivePrefix={arXiv},
27
+ primaryClass={cs.CV}
28
+ }
29
+ ```
app.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import shutil
4
+ import zipfile
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import imageio as imageio
9
+ import numpy as np
10
+ import torch as torch
11
+ from PIL import Image
12
+ from diffusers import UNet2DConditionModel, LCMScheduler
13
+ from gradio_imageslider import ImageSlider
14
+ from huggingface_hub import login
15
+ from tqdm import tqdm
16
+
17
+ from extrude import extrude_depth_3d
18
+ from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline
19
+
20
+ default_seed = 2024
21
+
22
+ default_image_denoise_steps = 4
23
+ default_image_ensemble_size = 1
24
+ default_image_processing_res = 768
25
+ default_image_reproducuble = True
26
+
27
+ default_video_depth_latent_init_strength = 0.1
28
+ default_video_denoise_steps = 1
29
+ default_video_ensemble_size = 1
30
+ default_video_processing_res = 768
31
+ default_video_out_fps = 15
32
+ default_video_out_max_frames = 100
33
+
34
+ default_bas_plane_near = 0.0
35
+ default_bas_plane_far = 1.0
36
+ default_bas_embossing = 20
37
+ default_bas_denoise_steps = 4
38
+ default_bas_ensemble_size = 1
39
+ default_bas_processing_res = 768
40
+ default_bas_size_longest_px = 512
41
+ default_bas_size_longest_cm = 10
42
+ default_bas_filter_size = 3
43
+ default_bas_frame_thickness = 5
44
+ default_bas_frame_near = 1
45
+ default_bas_frame_far = 1
46
+
47
+
48
+ def process_image(
49
+ pipe,
50
+ path_input,
51
+ denoise_steps=default_image_denoise_steps,
52
+ ensemble_size=default_image_ensemble_size,
53
+ processing_res=default_image_processing_res,
54
+ reproducible=default_image_reproducuble,
55
+ ):
56
+ input_image = Image.open(path_input)
57
+
58
+ pipe_out = pipe(
59
+ input_image,
60
+ denoising_steps=denoise_steps,
61
+ ensemble_size=ensemble_size,
62
+ processing_res=processing_res,
63
+ batch_size=1 if processing_res == 0 else 0,
64
+ seed=default_seed if reproducible else None,
65
+ show_progress_bar=False,
66
+ )
67
+
68
+ depth_pred = pipe_out.depth_np
69
+ depth_colored = pipe_out.depth_colored
70
+ depth_16bit = (depth_pred * 65535.0).astype(np.uint16)
71
+
72
+ path_output_dir = os.path.splitext(path_input)[0] + "_output"
73
+ os.makedirs(path_output_dir, exist_ok=True)
74
+
75
+ name_base = os.path.splitext(os.path.basename(path_input))[0]
76
+ path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
77
+ path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
78
+ path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
79
+
80
+ np.save(path_out_fp32, depth_pred)
81
+ Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
82
+ depth_colored.save(path_out_vis)
83
+
84
+ return (
85
+ [path_out_16bit, path_out_vis],
86
+ [path_out_16bit, path_out_fp32, path_out_vis],
87
+ )
88
+
89
+
90
+ def process_video(
91
+ pipe,
92
+ path_input,
93
+ depth_latent_init_strength=default_video_depth_latent_init_strength,
94
+ denoise_steps=default_video_denoise_steps,
95
+ ensemble_size=default_video_ensemble_size,
96
+ processing_res=default_video_processing_res,
97
+ out_fps=default_video_out_fps,
98
+ out_max_frames=default_video_out_max_frames,
99
+ progress=gr.Progress(),
100
+ ):
101
+ path_output_dir = os.path.splitext(path_input)[0] + "_output"
102
+ os.makedirs(path_output_dir, exist_ok=True)
103
+
104
+ name_base = os.path.splitext(os.path.basename(path_input))[0]
105
+ path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
106
+ path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")
107
+
108
+ reader = imageio.get_reader(path_input)
109
+
110
+ meta_data = reader.get_meta_data()
111
+ fps = meta_data["fps"]
112
+ size = meta_data["size"]
113
+ duration_sec = meta_data["duration"]
114
+
115
+ if fps <= out_fps:
116
+ frame_interval, out_fps = 1, fps
117
+ else:
118
+ frame_interval = round(fps / out_fps)
119
+ out_fps = fps / frame_interval
120
+
121
+ out_duration_sec = out_max_frames / out_fps
122
+ if duration_sec > out_duration_sec:
123
+ gr.Warning(
124
+ f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
125
+ f"use alternative setups for full processing"
126
+ )
127
+
128
+ writer = imageio.get_writer(path_out_vis, fps=out_fps)
129
+ zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)
130
+ prev_depth_latent = None
131
+
132
+ pbar = tqdm(desc="Processing Video", total=out_max_frames)
133
+
134
+ out_frame_id = 0
135
+ for frame_id, frame in enumerate(reader):
136
+ if not (frame_id % frame_interval == 0):
137
+ continue
138
+ out_frame_id += 1
139
+ pbar.update(1)
140
+ if out_frame_id > out_max_frames:
141
+ break
142
+
143
+ frame_pil = Image.fromarray(frame)
144
+
145
+ pipe_out = pipe(
146
+ frame_pil,
147
+ denoising_steps=denoise_steps,
148
+ ensemble_size=ensemble_size,
149
+ processing_res=processing_res,
150
+ match_input_res=False,
151
+ batch_size=0,
152
+ depth_latent_init=prev_depth_latent,
153
+ depth_latent_init_strength=depth_latent_init_strength,
154
+ seed=default_seed,
155
+ show_progress_bar=False,
156
+ )
157
+
158
+ prev_depth_latent = pipe_out.depth_latent
159
+
160
+ processed_frame = pipe_out.depth_colored
161
+ processed_frame = imageio.core.util.Array(np.array(processed_frame))
162
+ writer.append_data(processed_frame)
163
+
164
+ processed_frame = (65535 * np.clip(pipe_out.depth_np, 0.0, 1.0)).astype(
165
+ np.uint16
166
+ )
167
+ processed_frame = Image.fromarray(processed_frame, mode="I;16")
168
+
169
+ archive_path = os.path.join(
170
+ f"{name_base}_depth_16bit", f"{out_frame_id:05d}.png"
171
+ )
172
+ img_byte_arr = BytesIO()
173
+ processed_frame.save(img_byte_arr, format="png")
174
+ img_byte_arr.seek(0)
175
+ zipf.writestr(archive_path, img_byte_arr.read())
176
+
177
+ reader.close()
178
+ writer.close()
179
+ zipf.close()
180
+
181
+ return (
182
+ path_out_vis,
183
+ [path_out_vis, path_out_16bit],
184
+ )
185
+
186
+
187
+ def process_bas(
188
+ pipe,
189
+ path_input,
190
+ plane_near=default_bas_plane_near,
191
+ plane_far=default_bas_plane_far,
192
+ embossing=default_bas_embossing,
193
+ denoise_steps=default_bas_denoise_steps,
194
+ ensemble_size=default_bas_ensemble_size,
195
+ processing_res=default_bas_processing_res,
196
+ size_longest_px=default_bas_size_longest_px,
197
+ size_longest_cm=default_bas_size_longest_cm,
198
+ filter_size=default_bas_filter_size,
199
+ frame_thickness=default_bas_frame_thickness,
200
+ frame_near=default_bas_frame_near,
201
+ frame_far=default_bas_frame_far,
202
+ ):
203
+ if plane_near >= plane_far:
204
+ raise gr.Error("NEAR plane must have a value smaller than the FAR plane")
205
+
206
+ path_output_dir = os.path.splitext(path_input)[0] + "_output"
207
+ os.makedirs(path_output_dir, exist_ok=True)
208
+
209
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
210
+
211
+ input_image = Image.open(path_input)
212
+
213
+ pipe_out = pipe(
214
+ input_image,
215
+ denoising_steps=denoise_steps,
216
+ ensemble_size=ensemble_size,
217
+ processing_res=processing_res,
218
+ seed=default_seed,
219
+ show_progress_bar=False,
220
+ )
221
+
222
+ depth_pred = pipe_out.depth_np * 65535
223
+
224
+ def _process_3d(
225
+ size_longest_px,
226
+ filter_size,
227
+ vertex_colors,
228
+ scene_lights,
229
+ output_model_scale=None,
230
+ prepare_for_3d_printing=False,
231
+ ):
232
+ image_rgb_w, image_rgb_h = input_image.width, input_image.height
233
+ image_rgb_d = max(image_rgb_w, image_rgb_h)
234
+ image_new_w = size_longest_px * image_rgb_w // image_rgb_d
235
+ image_new_h = size_longest_px * image_rgb_h // image_rgb_d
236
+
237
+ image_rgb_new = os.path.join(
238
+ path_output_dir, f"{name_base}_rgb_{size_longest_px}{name_ext}"
239
+ )
240
+ image_depth_new = os.path.join(
241
+ path_output_dir, f"{name_base}_depth_{size_longest_px}.png"
242
+ )
243
+ input_image.resize((image_new_w, image_new_h), Image.LANCZOS).save(
244
+ image_rgb_new
245
+ )
246
+ Image.fromarray(depth_pred).convert(mode="F").resize(
247
+ (image_new_w, image_new_h), Image.BILINEAR
248
+ ).convert("I").save(image_depth_new)
249
+
250
+ path_glb, path_stl = extrude_depth_3d(
251
+ image_rgb_new,
252
+ image_depth_new,
253
+ output_model_scale=size_longest_cm * 10
254
+ if output_model_scale is None
255
+ else output_model_scale,
256
+ filter_size=filter_size,
257
+ coef_near=plane_near,
258
+ coef_far=plane_far,
259
+ emboss=embossing / 100,
260
+ f_thic=frame_thickness / 100,
261
+ f_near=frame_near / 100,
262
+ f_back=frame_far / 100,
263
+ vertex_colors=vertex_colors,
264
+ scene_lights=scene_lights,
265
+ prepare_for_3d_printing=prepare_for_3d_printing,
266
+ )
267
+
268
+ return path_glb, path_stl
269
+
270
+ path_viewer_glb, _ = _process_3d(
271
+ 256, filter_size, vertex_colors=False, scene_lights=True, output_model_scale=1
272
+ )
273
+ path_files_glb, path_files_stl = _process_3d(
274
+ size_longest_px, filter_size, vertex_colors=True, scene_lights=False, prepare_for_3d_printing=True
275
+ )
276
+
277
+ return path_viewer_glb, [path_files_glb, path_files_stl]
278
+
279
+
280
+ def run_demo_server(pipe):
281
+ process_pipe_image = functools.partial(process_image, pipe)
282
+ process_pipe_video = functools.partial(process_video, pipe)
283
+ process_pipe_bas = functools.partial(process_bas, pipe)
284
+ os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
285
+
286
+ gradio_theme = gr.themes.Default()
287
+ # gradio_theme.set(
288
+ # section_header_text_size="20px",
289
+ # section_header_text_weight="bold",
290
+ # )
291
+
292
+ with gr.Blocks(
293
+ theme=gradio_theme,
294
+ title="Marigold-LCM Depth Estimation",
295
+ css="""
296
+ #download {
297
+ height: 118px;
298
+ }
299
+ .slider .inner {
300
+ width: 5px;
301
+ background: #FFF;
302
+ }
303
+ .viewport {
304
+ aspect-ratio: 4/3;
305
+ }
306
+ .tabs button.selected {
307
+ font-size: 20px !important;
308
+ color: crimson !important;
309
+ }
310
+ """,
311
+ head="""
312
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
313
+ <script>
314
+ window.dataLayer = window.dataLayer || [];
315
+ function gtag() {dataLayer.push(arguments);}
316
+ gtag('js', new Date());
317
+ gtag('config', 'G-1FWSVCGZTG');
318
+ </script>
319
+ """,
320
+ ) as demo:
321
+ gr.Markdown(
322
+ """
323
+ <h1 align="center">Marigold-LCM Depth Estimation</h1>
324
+ <p align="center">
325
+ <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
326
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
327
+ </a>
328
+ <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
329
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
330
+ </a>
331
+ <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
332
+ <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
333
+ </a>
334
+ <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
335
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
336
+ </a>
337
+ </p>
338
+ <p align="justify">
339
+ Marigold-LCM is the fast version of Marigold, the state-of-the-art depth estimator for images in the wild.
340
+ It combines the power of the original Marigold 10-step estimator and the Latent Consistency Models, delivering high-quality results in as little as <b>one step</b>.
341
+ We provide three functions in this demo: Image, Video, and Bas-relief 3D processing — <b>see the tabs below</b>.
342
+ Upload your content into the <b>left</b> side, or click any of the <b>examples</b> below.
343
+ Wait a second (for images and 3D) or a minute (for videos), and interact with the result in the <b>right</b> side.
344
+ To avoid queuing, fork the demo into your profile.
345
+ </p>
346
+ """
347
+ )
348
+
349
+ with gr.Tabs(elem_classes=["tabs"]):
350
+ with gr.Tab("Image"):
351
+ with gr.Row():
352
+ with gr.Column():
353
+ image_input = gr.Image(
354
+ label="Input Image",
355
+ type="filepath",
356
+ )
357
+ with gr.Row():
358
+ image_submit_btn = gr.Button(
359
+ value="Compute Depth", variant="primary"
360
+ )
361
+ image_reset_btn = gr.Button(value="Reset")
362
+ with gr.Accordion("Advanced options", open=False):
363
+ image_denoise_steps = gr.Slider(
364
+ label="Number of denoising steps",
365
+ minimum=1,
366
+ maximum=4,
367
+ step=1,
368
+ value=default_image_denoise_steps,
369
+ )
370
+ image_ensemble_size = gr.Slider(
371
+ label="Ensemble size",
372
+ minimum=1,
373
+ maximum=10,
374
+ step=1,
375
+ value=default_image_ensemble_size,
376
+ )
377
+ image_processing_res = gr.Radio(
378
+ [
379
+ ("Native", 0),
380
+ ("Recommended", 768),
381
+ ],
382
+ label="Processing resolution",
383
+ value=default_image_processing_res,
384
+ )
385
+ with gr.Column():
386
+ image_output_slider = ImageSlider(
387
+ label="Predicted depth (red-near, blue-far)",
388
+ type="filepath",
389
+ show_download_button=True,
390
+ show_share_button=True,
391
+ interactive=False,
392
+ elem_classes="slider",
393
+ position=0.25,
394
+ )
395
+ image_output_files = gr.Files(
396
+ label="Depth outputs",
397
+ elem_id="download",
398
+ interactive=False,
399
+ )
400
+ gr.Examples(
401
+ fn=process_pipe_image,
402
+ examples=[
403
+ os.path.join("files", "image", name)
404
+ for name in [
405
+ "arc.jpeg",
406
+ "berries.jpeg",
407
+ "butterfly.jpeg",
408
+ "cat.jpg",
409
+ "concert.jpeg",
410
+ "dog.jpeg",
411
+ "doughnuts.jpeg",
412
+ "einstein.jpg",
413
+ "food.jpeg",
414
+ "glasses.jpeg",
415
+ "house.jpg",
416
+ "lake.jpeg",
417
+ "marigold.jpeg",
418
+ "portrait_1.jpeg",
419
+ "portrait_2.jpeg",
420
+ "pumpkins.jpg",
421
+ "puzzle.jpeg",
422
+ "road.jpg",
423
+ "scientists.jpg",
424
+ "surfboards.jpeg",
425
+ "surfer.jpeg",
426
+ "swings.jpg",
427
+ "switzerland.jpeg",
428
+ "teamwork.jpeg",
429
+ "wave.jpeg",
430
+ ]
431
+ ],
432
+ inputs=[image_input],
433
+ outputs=[image_output_slider, image_output_files],
434
+ cache_examples=True,
435
+ )
436
+
437
+ with gr.Tab("Video"):
438
+ with gr.Row():
439
+ with gr.Column():
440
+ video_input = gr.Video(
441
+ label="Input Video",
442
+ sources=["upload"],
443
+ )
444
+ with gr.Row():
445
+ video_submit_btn = gr.Button(
446
+ value="Compute Depth", variant="primary"
447
+ )
448
+ video_reset_btn = gr.Button(value="Reset")
449
+ with gr.Column():
450
+ video_output_video = gr.Video(
451
+ label="Output video depth (red-near, blue-far)",
452
+ interactive=False,
453
+ )
454
+ video_output_files = gr.Files(
455
+ label="Depth outputs",
456
+ elem_id="download",
457
+ interactive=False,
458
+ )
459
+ gr.Examples(
460
+ fn=process_pipe_video,
461
+ examples=[
462
+ os.path.join("files", "video", name)
463
+ for name in [
464
+ "cab.mp4",
465
+ "elephant.mp4",
466
+ "obama.mp4",
467
+ ]
468
+ ],
469
+ inputs=[video_input],
470
+ outputs=[video_output_video, video_output_files],
471
+ cache_examples=True,
472
+ )
473
+
474
+ with gr.Tab("Bas-relief (3D)"):
475
+ gr.Markdown(
476
+ """
477
+ <p align="justify">
478
+ This part of the demo uses Marigold-LCM to create a bas-relief model.
479
+ The models are watertight, with correct normals, and exported in the STL format, which makes them <b>3D-printable</b>.
480
+ Start by uploading the image and click "Create" with the default parameters.
481
+ To improve the result, click "Clear", adjust the geometry sliders below, and click "Create" again.
482
+ </p>
483
+ """,
484
+ )
485
+ with gr.Row():
486
+ with gr.Column():
487
+ bas_input = gr.Image(
488
+ label="Input Image",
489
+ type="filepath",
490
+ )
491
+ with gr.Row():
492
+ bas_submit_btn = gr.Button(value="Create 3D", variant="primary")
493
+ bas_clear_btn = gr.Button(value="Clear")
494
+ bas_reset_btn = gr.Button(value="Reset")
495
+ with gr.Accordion("3D printing demo: Main options", open=True):
496
+ bas_plane_near = gr.Slider(
497
+ label="Relative position of the near plane (between 0 and 1)",
498
+ minimum=0.0,
499
+ maximum=1.0,
500
+ step=0.001,
501
+ value=default_bas_plane_near,
502
+ )
503
+ bas_plane_far = gr.Slider(
504
+ label="Relative position of the far plane (between near and 1)",
505
+ minimum=0.0,
506
+ maximum=1.0,
507
+ step=0.001,
508
+ value=default_bas_plane_far,
509
+ )
510
+ bas_embossing = gr.Slider(
511
+ label="Embossing level",
512
+ minimum=0,
513
+ maximum=100,
514
+ step=1,
515
+ value=default_bas_embossing,
516
+ )
517
+ with gr.Accordion("3D printing demo: Advanced options", open=False):
518
+ bas_denoise_steps = gr.Slider(
519
+ label="Number of denoising steps",
520
+ minimum=1,
521
+ maximum=4,
522
+ step=1,
523
+ value=default_bas_denoise_steps,
524
+ )
525
+ bas_ensemble_size = gr.Slider(
526
+ label="Ensemble size",
527
+ minimum=1,
528
+ maximum=10,
529
+ step=1,
530
+ value=default_bas_ensemble_size,
531
+ )
532
+ bas_processing_res = gr.Radio(
533
+ [
534
+ ("Native", 0),
535
+ ("Recommended", 768),
536
+ ],
537
+ label="Processing resolution",
538
+ value=default_bas_processing_res,
539
+ )
540
+ bas_size_longest_px = gr.Slider(
541
+ label="Size (px) of the longest side",
542
+ minimum=256,
543
+ maximum=1024,
544
+ step=256,
545
+ value=default_bas_size_longest_px,
546
+ )
547
+ bas_size_longest_cm = gr.Slider(
548
+ label="Size (cm) of the longest side",
549
+ minimum=1,
550
+ maximum=100,
551
+ step=1,
552
+ value=default_bas_size_longest_cm,
553
+ )
554
+ bas_filter_size = gr.Slider(
555
+ label="Size (px) of the smoothing filter",
556
+ minimum=1,
557
+ maximum=5,
558
+ step=2,
559
+ value=default_bas_filter_size,
560
+ )
561
+ bas_frame_thickness = gr.Slider(
562
+ label="Frame thickness",
563
+ minimum=0,
564
+ maximum=100,
565
+ step=1,
566
+ value=default_bas_frame_thickness,
567
+ )
568
+ bas_frame_near = gr.Slider(
569
+ label="Frame's near plane offset",
570
+ minimum=-100,
571
+ maximum=100,
572
+ step=1,
573
+ value=default_bas_frame_near,
574
+ )
575
+ bas_frame_far = gr.Slider(
576
+ label="Frame's far plane offset",
577
+ minimum=1,
578
+ maximum=10,
579
+ step=1,
580
+ value=default_bas_frame_far,
581
+ )
582
+ with gr.Column():
583
+ bas_output_viewer = gr.Model3D(
584
+ camera_position=(75.0, 90.0, 1.25),
585
+ elem_classes="viewport",
586
+ label="3D preview (low-res, relief highlight)",
587
+ interactive=False,
588
+ )
589
+ bas_output_files = gr.Files(
590
+ label="3D model outputs (high-res)",
591
+ elem_id="download",
592
+ interactive=False,
593
+ )
594
+ gr.Examples(
595
+ fn=process_pipe_bas,
596
+ examples=[
597
+ [
598
+ "files/basrelief/coin.jpg", # input
599
+ 0.0, # plane_near
600
+ 0.66, # plane_far
601
+ 15, # embossing
602
+ 4, # denoise_steps
603
+ 4, # ensemble_size
604
+ 768, # processing_res
605
+ 512, # size_longest_px
606
+ 10, # size_longest_cm
607
+ 3, # filter_size
608
+ 5, # frame_thickness
609
+ 0, # frame_near
610
+ 1, # frame_far
611
+ ],
612
+ [
613
+ "files/basrelief/einstein.jpg", # input
614
+ 0.0, # plane_near
615
+ 0.5, # plane_far
616
+ 50, # embossing
617
+ 2, # denoise_steps
618
+ 1, # ensemble_size
619
+ 768, # processing_res
620
+ 512, # size_longest_px
621
+ 10, # size_longest_cm
622
+ 3, # filter_size
623
+ 5, # frame_thickness
624
+ -15, # frame_near
625
+ 1, # frame_far
626
+ ],
627
+ [
628
+ "files/basrelief/food.jpeg", # input
629
+ 0.0, # plane_near
630
+ 1.0, # plane_far
631
+ 20, # embossing
632
+ 2, # denoise_steps
633
+ 4, # ensemble_size
634
+ 768, # processing_res
635
+ 512, # size_longest_px
636
+ 10, # size_longest_cm
637
+ 3, # filter_size
638
+ 5, # frame_thickness
639
+ -5, # frame_near
640
+ 1, # frame_far
641
+ ],
642
+ ],
643
+ inputs=[
644
+ bas_input,
645
+ bas_plane_near,
646
+ bas_plane_far,
647
+ bas_embossing,
648
+ bas_denoise_steps,
649
+ bas_ensemble_size,
650
+ bas_processing_res,
651
+ bas_size_longest_px,
652
+ bas_size_longest_cm,
653
+ bas_filter_size,
654
+ bas_frame_thickness,
655
+ bas_frame_near,
656
+ bas_frame_far,
657
+ ],
658
+ outputs=[bas_output_viewer, bas_output_files],
659
+ cache_examples=True,
660
+ )
661
+
662
+ image_submit_btn.click(
663
+ fn=process_pipe_image,
664
+ inputs=[
665
+ image_input,
666
+ image_denoise_steps,
667
+ image_ensemble_size,
668
+ image_processing_res,
669
+ ],
670
+ outputs=[image_output_slider, image_output_files],
671
+ concurrency_limit=1,
672
+ )
673
+
674
+ image_reset_btn.click(
675
+ fn=lambda: (
676
+ None,
677
+ None,
678
+ None,
679
+ default_image_ensemble_size,
680
+ default_image_denoise_steps,
681
+ default_image_processing_res,
682
+ ),
683
+ inputs=[],
684
+ outputs=[
685
+ image_input,
686
+ image_output_slider,
687
+ image_output_files,
688
+ image_ensemble_size,
689
+ image_denoise_steps,
690
+ image_processing_res,
691
+ ],
692
+ concurrency_limit=1,
693
+ )
694
+
695
+ video_submit_btn.click(
696
+ fn=process_pipe_video,
697
+ inputs=[video_input],
698
+ outputs=[video_output_video, video_output_files],
699
+ concurrency_limit=1,
700
+ )
701
+
702
+ video_reset_btn.click(
703
+ fn=lambda: (None, None, None),
704
+ inputs=[],
705
+ outputs=[video_input, video_output_video, video_output_files],
706
+ concurrency_limit=1,
707
+ )
708
+
709
+ def wrapper_process_pipe_bas(*args, **kwargs):
710
+ out = list(process_pipe_bas(*args, **kwargs))
711
+ out = [gr.Button(interactive=False), gr.Image(interactive=False)] + out
712
+ return out
713
+
714
+ bas_submit_btn.click(
715
+ fn=wrapper_process_pipe_bas,
716
+ inputs=[
717
+ bas_input,
718
+ bas_plane_near,
719
+ bas_plane_far,
720
+ bas_embossing,
721
+ bas_denoise_steps,
722
+ bas_ensemble_size,
723
+ bas_processing_res,
724
+ bas_size_longest_px,
725
+ bas_size_longest_cm,
726
+ bas_filter_size,
727
+ bas_frame_thickness,
728
+ bas_frame_near,
729
+ bas_frame_far,
730
+ ],
731
+ outputs=[bas_submit_btn, bas_input, bas_output_viewer, bas_output_files],
732
+ concurrency_limit=1,
733
+ )
734
+
735
+ bas_clear_btn.click(
736
+ fn=lambda: (gr.Button(interactive=True), None, None),
737
+ inputs=[],
738
+ outputs=[
739
+ bas_submit_btn,
740
+ bas_output_viewer,
741
+ bas_output_files,
742
+ ],
743
+ concurrency_limit=1,
744
+ )
745
+
746
+ bas_reset_btn.click(
747
+ fn=lambda: (
748
+ gr.Button(interactive=True),
749
+ None,
750
+ None,
751
+ None,
752
+ default_bas_plane_near,
753
+ default_bas_plane_far,
754
+ default_bas_embossing,
755
+ default_bas_denoise_steps,
756
+ default_bas_ensemble_size,
757
+ default_bas_processing_res,
758
+ default_bas_size_longest_px,
759
+ default_bas_size_longest_cm,
760
+ default_bas_filter_size,
761
+ default_bas_frame_thickness,
762
+ default_bas_frame_near,
763
+ default_bas_frame_far,
764
+ ),
765
+ inputs=[],
766
+ outputs=[
767
+ bas_submit_btn,
768
+ bas_input,
769
+ bas_output_viewer,
770
+ bas_output_files,
771
+ bas_plane_near,
772
+ bas_plane_far,
773
+ bas_embossing,
774
+ bas_denoise_steps,
775
+ bas_ensemble_size,
776
+ bas_processing_res,
777
+ bas_size_longest_px,
778
+ bas_size_longest_cm,
779
+ bas_filter_size,
780
+ bas_frame_thickness,
781
+ bas_frame_near,
782
+ bas_frame_far,
783
+ ],
784
+ concurrency_limit=1,
785
+ )
786
+
787
+ demo.queue(
788
+ api_open=False,
789
+ ).launch(
790
+ server_name="0.0.0.0",
791
+ server_port=7860,
792
+ )
793
+
794
+
795
+ def prefetch_hf_cache(pipe):
796
+ process_image(pipe, "files/image/bee.jpg", 1, 1, 64)
797
+ shutil.rmtree("files/image/bee_output")
798
+
799
+
800
+ def main():
801
+ CHECKPOINT = "prs-eth/marigold-v1-0"
802
+ CHECKPOINT_UNET_LCM = "prs-eth/marigold-lcm-v1-0"
803
+
804
+ login(token=os.environ["HF_TOKEN_COLAB_RO"])
805
+
806
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
807
+
808
+ pipe = MarigoldDepthConsistencyPipeline.from_pretrained(
809
+ CHECKPOINT,
810
+ unet=UNet2DConditionModel.from_pretrained(
811
+ CHECKPOINT_UNET_LCM, subfolder="unet", use_auth_token=True
812
+ ),
813
+ )
814
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
815
+ try:
816
+ import xformers
817
+
818
+ pipe.enable_xformers_memory_efficient_attention()
819
+ except:
820
+ pass # run without xformers
821
+
822
+ pipe = pipe.to(device)
823
+ prefetch_hf_cache(pipe)
824
+ run_demo_server(pipe)
825
+
826
+
827
+ if __name__ == "__main__":
828
+ main()
extrude.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import numpy as np
5
+ import pygltflib
6
+ import trimesh
7
+ from PIL import Image, ImageFilter
8
+
9
+
10
+ def quaternion_multiply(q1, q2):
11
+ x1, y1, z1, w1 = q1
12
+ x2, y2, z2, w2 = q2
13
+ return [
14
+ w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
15
+ w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
16
+ w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
17
+ w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
18
+ ]
19
+
20
+
21
+ def glb_add_lights(path_input, path_output):
22
+ """
23
+ Adds directional lights in the horizontal plane to the glb file.
24
+ :param path_input: path to input glb
25
+ :param path_output: path to output glb
26
+ :return: None
27
+ """
28
+ glb = pygltflib.GLTF2().load(path_input)
29
+
30
+ N = 3 # default max num lights in Babylon.js is 4
31
+ angle_step = 2 * math.pi / N
32
+ elevation_angle = math.radians(75)
33
+
34
+ light_colors = [
35
+ [1.0, 0.0, 0.0],
36
+ [0.0, 1.0, 0.0],
37
+ [0.0, 0.0, 1.0],
38
+ ]
39
+
40
+ lights_extension = {
41
+ "lights": [
42
+ {"type": "directional", "color": light_colors[i], "intensity": 2.0}
43
+ for i in range(N)
44
+ ]
45
+ }
46
+
47
+ if "KHR_lights_punctual" not in glb.extensionsUsed:
48
+ glb.extensionsUsed.append("KHR_lights_punctual")
49
+ glb.extensions["KHR_lights_punctual"] = lights_extension
50
+
51
+ light_nodes = []
52
+ for i in range(N):
53
+ angle = i * angle_step
54
+
55
+ pos_rot = [0.0, 0.0, math.sin(angle / 2), math.cos(angle / 2)]
56
+ elev_rot = [
57
+ math.sin(elevation_angle / 2),
58
+ 0.0,
59
+ 0.0,
60
+ math.cos(elevation_angle / 2),
61
+ ]
62
+ rotation = quaternion_multiply(pos_rot, elev_rot)
63
+
64
+ node = {
65
+ "rotation": rotation,
66
+ "extensions": {"KHR_lights_punctual": {"light": i}},
67
+ }
68
+ light_nodes.append(node)
69
+
70
+ light_node_indices = list(range(len(glb.nodes), len(glb.nodes) + N))
71
+ glb.nodes.extend(light_nodes)
72
+
73
+ root_node_index = glb.scenes[glb.scene].nodes[0]
74
+ root_node = glb.nodes[root_node_index]
75
+ if hasattr(root_node, "children"):
76
+ root_node.children.extend(light_node_indices)
77
+ else:
78
+ root_node.children = light_node_indices
79
+
80
+ glb.save(path_output)
81
+
82
+
83
+ def extrude_depth_3d(
84
+ path_rgb,
85
+ path_depth,
86
+ output_model_scale=100,
87
+ filter_size=3,
88
+ coef_near=0.0,
89
+ coef_far=1.0,
90
+ emboss=0.3,
91
+ f_thic=0.05,
92
+ f_near=-0.15,
93
+ f_back=0.01,
94
+ vertex_colors=True,
95
+ scene_lights=True,
96
+ prepare_for_3d_printing=False,
97
+ ):
98
+ f_far_inner = -emboss
99
+ f_far_outer = f_far_inner - f_back
100
+
101
+ f_near = max(f_near, f_far_inner)
102
+
103
+ depth_image = Image.open(path_depth)
104
+ assert depth_image.mode == "I", depth_image.mode
105
+ depth_image = depth_image.filter(ImageFilter.MedianFilter(size=filter_size))
106
+
107
+ w, h = depth_image.size
108
+ d_max = max(w, h)
109
+ depth_image = np.array(depth_image).astype(np.double)
110
+ z_min, z_max = np.min(depth_image), np.max(depth_image)
111
+ depth_image = (depth_image.astype(np.double) - z_min) / (z_max - z_min)
112
+ depth_image[depth_image < coef_near] = coef_near
113
+ depth_image[depth_image > coef_far] = coef_far
114
+ depth_image = emboss * (depth_image - coef_near) / (coef_far - coef_near)
115
+ rgb_image = np.array(
116
+ Image.open(path_rgb).convert("RGB").resize((w, h), Image.Resampling.LANCZOS)
117
+ )
118
+
119
+ w_norm = w / float(d_max - 1)
120
+ h_norm = h / float(d_max - 1)
121
+ w_half = w_norm / 2
122
+ h_half = h_norm / 2
123
+
124
+ x, y = np.meshgrid(np.arange(w), np.arange(h))
125
+ x = x / float(d_max - 1) - w_half # [-w_half, w_half]
126
+ y = -y / float(d_max - 1) + h_half # [-h_half, h_half]
127
+ z = -depth_image # -depth_emboss (far) - 0 (near)
128
+ vertices_2d = np.stack((x, y, z), axis=-1)
129
+ vertices = vertices_2d.reshape(-1, 3)
130
+ colors = rgb_image[:, :, :3].reshape(-1, 3) / 255.0
131
+
132
+ faces = []
133
+ for y in range(h - 1):
134
+ for x in range(w - 1):
135
+ idx = y * w + x
136
+ faces.append([idx, idx + w, idx + 1])
137
+ faces.append([idx + 1, idx + w, idx + 1 + w])
138
+
139
+ # OUTER frame
140
+
141
+ nv = len(vertices)
142
+ vertices = np.append(
143
+ vertices,
144
+ [
145
+ [-w_half - f_thic, -h_half - f_thic, f_near], # 00
146
+ [-w_half - f_thic, -h_half - f_thic, f_far_outer], # 01
147
+ [w_half + f_thic, -h_half - f_thic, f_near], # 02
148
+ [w_half + f_thic, -h_half - f_thic, f_far_outer], # 03
149
+ [w_half + f_thic, h_half + f_thic, f_near], # 04
150
+ [w_half + f_thic, h_half + f_thic, f_far_outer], # 05
151
+ [-w_half - f_thic, h_half + f_thic, f_near], # 06
152
+ [-w_half - f_thic, h_half + f_thic, f_far_outer], # 07
153
+ ],
154
+ axis=0,
155
+ )
156
+ faces.extend(
157
+ [
158
+ [nv + 0, nv + 1, nv + 2],
159
+ [nv + 2, nv + 1, nv + 3],
160
+ [nv + 2, nv + 3, nv + 4],
161
+ [nv + 4, nv + 3, nv + 5],
162
+ [nv + 4, nv + 5, nv + 6],
163
+ [nv + 6, nv + 5, nv + 7],
164
+ [nv + 6, nv + 7, nv + 0],
165
+ [nv + 0, nv + 7, nv + 1],
166
+ ]
167
+ )
168
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * 8, axis=0)
169
+
170
+ # INNER frame
171
+
172
+ nv = len(vertices)
173
+ vertices_left_data = vertices_2d[:, 0] # H x 3
174
+ vertices_left_frame = vertices_2d[:, 0].copy() # H x 3
175
+ vertices_left_frame[:, 2] = f_near
176
+ vertices = np.append(vertices, vertices_left_data, axis=0)
177
+ vertices = np.append(vertices, vertices_left_frame, axis=0)
178
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 * h), axis=0)
179
+ for i in range(h - 1):
180
+ nvi_d = nv + i
181
+ nvi_f = nvi_d + h
182
+ faces.append([nvi_d, nvi_f, nvi_d + 1])
183
+ faces.append([nvi_d + 1, nvi_f, nvi_f + 1])
184
+
185
+ nv = len(vertices)
186
+ vertices_right_data = vertices_2d[:, -1] # H x 3
187
+ vertices_right_frame = vertices_2d[:, -1].copy() # H x 3
188
+ vertices_right_frame[:, 2] = f_near
189
+ vertices = np.append(vertices, vertices_right_data, axis=0)
190
+ vertices = np.append(vertices, vertices_right_frame, axis=0)
191
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 * h), axis=0)
192
+ for i in range(h - 1):
193
+ nvi_d = nv + i
194
+ nvi_f = nvi_d + h
195
+ faces.append([nvi_d, nvi_d + 1, nvi_f])
196
+ faces.append([nvi_d + 1, nvi_f + 1, nvi_f])
197
+
198
+ nv = len(vertices)
199
+ vertices_top_data = vertices_2d[0, :] # H x 3
200
+ vertices_top_frame = vertices_2d[0, :].copy() # H x 3
201
+ vertices_top_frame[:, 2] = f_near
202
+ vertices = np.append(vertices, vertices_top_data, axis=0)
203
+ vertices = np.append(vertices, vertices_top_frame, axis=0)
204
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 * w), axis=0)
205
+ for i in range(w - 1):
206
+ nvi_d = nv + i
207
+ nvi_f = nvi_d + w
208
+ faces.append([nvi_d, nvi_d + 1, nvi_f])
209
+ faces.append([nvi_d + 1, nvi_f + 1, nvi_f])
210
+
211
+ nv = len(vertices)
212
+ vertices_bottom_data = vertices_2d[-1, :] # H x 3
213
+ vertices_bottom_frame = vertices_2d[-1, :].copy() # H x 3
214
+ vertices_bottom_frame[:, 2] = f_near
215
+ vertices = np.append(vertices, vertices_bottom_data, axis=0)
216
+ vertices = np.append(vertices, vertices_bottom_frame, axis=0)
217
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 * w), axis=0)
218
+ for i in range(w - 1):
219
+ nvi_d = nv + i
220
+ nvi_f = nvi_d + w
221
+ faces.append([nvi_d, nvi_f, nvi_d + 1])
222
+ faces.append([nvi_d + 1, nvi_f, nvi_f + 1])
223
+
224
+ # FRONT frame
225
+
226
+ nv = len(vertices)
227
+ vertices = np.append(
228
+ vertices,
229
+ [
230
+ [-w_half - f_thic, -h_half - f_thic, f_near],
231
+ [-w_half - f_thic, h_half + f_thic, f_near],
232
+ ],
233
+ axis=0,
234
+ )
235
+ vertices = np.append(vertices, vertices_left_frame, axis=0)
236
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 + h), axis=0)
237
+ for i in range(h - 1):
238
+ faces.append([nv, nv + 2 + i + 1, nv + 2 + i])
239
+ faces.append([nv, nv + 2, nv + 1])
240
+
241
+ nv = len(vertices)
242
+ vertices = np.append(
243
+ vertices,
244
+ [
245
+ [w_half + f_thic, h_half + f_thic, f_near],
246
+ [w_half + f_thic, -h_half - f_thic, f_near],
247
+ ],
248
+ axis=0,
249
+ )
250
+ vertices = np.append(vertices, vertices_right_frame, axis=0)
251
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 + h), axis=0)
252
+ for i in range(h - 1):
253
+ faces.append([nv, nv + 2 + i, nv + 2 + i + 1])
254
+ faces.append([nv, nv + h + 1, nv + 1])
255
+
256
+ nv = len(vertices)
257
+ vertices = np.append(
258
+ vertices,
259
+ [
260
+ [w_half + f_thic, h_half + f_thic, f_near],
261
+ [-w_half - f_thic, h_half + f_thic, f_near],
262
+ ],
263
+ axis=0,
264
+ )
265
+ vertices = np.append(vertices, vertices_top_frame, axis=0)
266
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 + w), axis=0)
267
+ for i in range(w - 1):
268
+ faces.append([nv, nv + 2 + i, nv + 2 + i + 1])
269
+ faces.append([nv, nv + 1, nv + 2])
270
+
271
+ nv = len(vertices)
272
+ vertices = np.append(
273
+ vertices,
274
+ [
275
+ [-w_half - f_thic, -h_half - f_thic, f_near],
276
+ [w_half + f_thic, -h_half - f_thic, f_near],
277
+ ],
278
+ axis=0,
279
+ )
280
+ vertices = np.append(vertices, vertices_bottom_frame, axis=0)
281
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * (2 + w), axis=0)
282
+ for i in range(w - 1):
283
+ faces.append([nv, nv + 2 + i + 1, nv + 2 + i])
284
+ faces.append([nv, nv + 1, nv + w + 1])
285
+
286
+ # BACK frame
287
+
288
+ nv = len(vertices)
289
+ vertices = np.append(
290
+ vertices,
291
+ [
292
+ [-w_half - f_thic, -h_half - f_thic, f_far_outer], # 00
293
+ [w_half + f_thic, -h_half - f_thic, f_far_outer], # 01
294
+ [w_half + f_thic, h_half + f_thic, f_far_outer], # 02
295
+ [-w_half - f_thic, h_half + f_thic, f_far_outer], # 03
296
+ ],
297
+ axis=0,
298
+ )
299
+ faces.extend(
300
+ [
301
+ [nv + 0, nv + 2, nv + 1],
302
+ [nv + 2, nv + 0, nv + 3],
303
+ ]
304
+ )
305
+ colors = np.append(colors, [[0.5, 0.5, 0.5]] * 4, axis=0)
306
+
307
+ trimesh_kwargs = {}
308
+ if vertex_colors:
309
+ trimesh_kwargs["vertex_colors"] = colors
310
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces, **trimesh_kwargs)
311
+
312
+ mesh.merge_vertices()
313
+
314
+ current_max_dimension = max(mesh.extents)
315
+ scaling_factor = output_model_scale / current_max_dimension
316
+ mesh.apply_scale(scaling_factor)
317
+
318
+ if prepare_for_3d_printing:
319
+ rotation_mat = trimesh.transformations.rotation_matrix(np.radians(90), [-1, 0, 0])
320
+ mesh.apply_transform(rotation_mat)
321
+
322
+ path_out_base = os.path.splitext(path_depth)[0].replace("_16bit", "")
323
+ path_out_glb = path_out_base + ".glb"
324
+ path_out_stl = path_out_base + ".stl"
325
+
326
+ mesh.export(path_out_glb, file_type="glb")
327
+ if scene_lights:
328
+ glb_add_lights(path_out_glb, path_out_glb)
329
+
330
+ mesh.export(path_out_stl, file_type="stl")
331
+
332
+ return path_out_glb, path_out_stl
files/basrelief/coin.jpg ADDED

Git LFS Details

  • SHA256: d5295c5cb301ef73099e3dd91f80916e7b013f6b04d75759df57081b16a18adc
  • Pointer size: 131 Bytes
  • Size of remote file: 632 kB
files/basrelief/einstein.jpg ADDED

Git LFS Details

  • SHA256: d4a4543c0fffb2ca5ea3c17e23e88fcfcf66eae8b487173fbc5c25d0d614bdb6
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
files/basrelief/food.jpeg ADDED

Git LFS Details

  • SHA256: a26151050a574b0dc0014e9c4806da3d6f6bc1297ee1035a16b9ace007a179af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
files/image/arc.jpeg ADDED

Git LFS Details

  • SHA256: f888e3770134e2073459026f58c568f7cf30524dd26a9182413c84b709e1b63e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
files/image/bee.jpg ADDED

Git LFS Details

  • SHA256: 7643ccdbc9550e2bf6ebdd5c768db5bc829ef719b0d1a91b4f6f9184b52f4751
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
files/image/berries.jpeg ADDED

Git LFS Details

  • SHA256: dac1411ea48cf83b7a59c6424032f95b2ff496b3a98cdccf168bbed1c8f0aed4
  • Pointer size: 131 Bytes
  • Size of remote file: 940 kB
files/image/butterfly.jpeg ADDED

Git LFS Details

  • SHA256: e0364b8eec31d2c113c15c2b6c892754130765e8e2c960adc87d51ca5c0ea8f9
  • Pointer size: 131 Bytes
  • Size of remote file: 878 kB
files/image/cat.jpg ADDED

Git LFS Details

  • SHA256: 794796a86e56a4b372287661dc934daa2d15e988d01afe88afc50b32644c007a
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
files/image/concert.jpeg ADDED

Git LFS Details

  • SHA256: fc746e234cb8a3e483999ee4c4f4d22b4e6c48cb2655eaa47c0936f3a37b61dc
  • Pointer size: 131 Bytes
  • Size of remote file: 420 kB
files/image/dog.jpeg ADDED

Git LFS Details

  • SHA256: c932a965dfe63c8c6dbc1bb48f7ea245a6a6dd2fb40fd243545e908b3aa7aa62
  • Pointer size: 131 Bytes
  • Size of remote file: 672 kB
files/image/doughnuts.jpeg ADDED

Git LFS Details

  • SHA256: 2ede4170b4a17f0c076c1a336eb4d3c03d64688997a986e3a8101972016b799a
  • Pointer size: 131 Bytes
  • Size of remote file: 607 kB
files/image/einstein.jpg ADDED

Git LFS Details

  • SHA256: d4a4543c0fffb2ca5ea3c17e23e88fcfcf66eae8b487173fbc5c25d0d614bdb6
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
files/image/food.jpeg ADDED

Git LFS Details

  • SHA256: a26151050a574b0dc0014e9c4806da3d6f6bc1297ee1035a16b9ace007a179af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
files/image/glasses.jpeg ADDED

Git LFS Details

  • SHA256: de8c0c20adb7c187357c21e467d3f178888574962027cdd366c390b63913ffec
  • Pointer size: 131 Bytes
  • Size of remote file: 677 kB
files/image/house.jpg ADDED

Git LFS Details

  • SHA256: 4087027e84a6323099fc839fd0b6816fd614814e92d12df21051cff3ed472819
  • Pointer size: 133 Bytes
  • Size of remote file: 14.9 MB
files/image/lake.jpeg ADDED

Git LFS Details

  • SHA256: 181dc0f684f0f3b94bc4bec829becd3dec817f69032731edf55ee8370c6898f0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
files/image/marigold.jpeg ADDED

Git LFS Details

  • SHA256: 575c1a7bc1199d86b5ec305b4efc12286842dee4a189e8699dcf8a6d0276807c
  • Pointer size: 131 Bytes
  • Size of remote file: 416 kB
files/image/portrait_1.jpeg ADDED

Git LFS Details

  • SHA256: 76e3ad74311975f0db43cdebd4202d1464e19b6950cc3e7c5aa0a160f95493c3
  • Pointer size: 131 Bytes
  • Size of remote file: 506 kB
files/image/portrait_2.jpeg ADDED

Git LFS Details

  • SHA256: 805ad1127b0d9d09068df70e3ab7aa7450ff802fa5464db8430787dfee1ec6a0
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
files/image/pumpkins.jpg ADDED

Git LFS Details

  • SHA256: 92f03bc05dc882231bce735f2afb8c27eb9d0616166abe3794b39ff24314fd0a
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
files/image/puzzle.jpeg ADDED

Git LFS Details

  • SHA256: 60b66432124a0936c6143301a9f9b793af4184bc9340c567d11fdd5a22cc98cc
  • Pointer size: 131 Bytes
  • Size of remote file: 374 kB
files/image/road.jpg ADDED

Git LFS Details

  • SHA256: 58bb01aea37f6e1206260eddb6d003589d779e8b3fb3ef0a0f1e2e38a8fa3925
  • Pointer size: 133 Bytes
  • Size of remote file: 13.1 MB
files/image/scientists.jpg ADDED

Git LFS Details

  • SHA256: 7b164dfbc4ab6e491ce81972b8c0e076fdc4af622289d0aa3cb43ee3c2be4030
  • Pointer size: 131 Bytes
  • Size of remote file: 444 kB
files/image/surfboards.jpeg ADDED

Git LFS Details

  • SHA256: 326f9ffd3b85b29b971205eb87c2d0c9b5e4409b496be1eb961b46d5f7c5d6c6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
files/image/surfer.jpeg ADDED

Git LFS Details

  • SHA256: 52827abf2c3951b752d4e58c88fff7ab907672c58fda70b813df3922650c7495
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
files/image/swings.jpg ADDED

Git LFS Details

  • SHA256: cae2ac669c948313eae8aca53017f10b64b42f87c53b9c34639962b218fdf1f1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
files/image/switzerland.jpeg ADDED

Git LFS Details

  • SHA256: 81e35ba90f7736167ea3e8a0a58f932ecded07b00b012a5bd7df5dabbe0eb3ce
  • Pointer size: 131 Bytes
  • Size of remote file: 847 kB
files/image/teamwork.jpeg ADDED

Git LFS Details

  • SHA256: 3cd48af8f3db4d89760cd6f40f2716570e697ae74a9bd88ed1ba36c0e68326b3
  • Pointer size: 131 Bytes
  • Size of remote file: 700 kB
files/image/wave.jpeg ADDED

Git LFS Details

  • SHA256: 7f14e77f7990d75104d6e3447077eb176d6437c58f5fb0fffcdb6015193b2d03
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
files/video/cab.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7857328de30257e2985e0218e18e35f0dbc6ca9dd9f89b28687881f13ca0a4a
3
+ size 3268179
files/video/elephant.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d198ec2e3e5a308c5eeb18c9f3a882f6c5812d329d9e8497e1bf79ff466dd84
3
+ size 3078416
files/video/obama.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4aa0ac19460e0966139247cc180f98398fb11a35e3ca5c90cb70f0c4704904de
3
+ size 955458
marigold_depth_estimation_lcm.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Anton Obukhov, Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import math
22
+ from typing import Dict, Union, Tuple
23
+
24
+ import matplotlib
25
+ import numpy as np
26
+ import torch
27
+ from PIL import Image
28
+ from scipy.optimize import minimize
29
+ from torch.utils.data import DataLoader, TensorDataset
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from diffusers import (
34
+ AutoencoderKL,
35
+ DDIMScheduler,
36
+ DiffusionPipeline,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.utils import BaseOutput, check_min_version
40
+
41
+
42
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
43
+ check_min_version("0.27.0.dev0")
44
+
45
+
46
+ class MarigoldDepthConsistencyOutput(BaseOutput):
47
+ """
48
+ Output class for Marigold monocular depth prediction pipeline.
49
+
50
+ Args:
51
+ depth_np (`np.ndarray`):
52
+ Predicted depth map, with depth values in the range of [0, 1].
53
+ depth_colored (`None` or `PIL.Image.Image`):
54
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
55
+ depth_latent (`torch.Tensor`):
56
+ Depth map's latent, with the shape of [4, h, w].
57
+ uncertainty (`None` or `np.ndarray`):
58
+ Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
59
+ """
60
+
61
+ depth_np: np.ndarray
62
+ depth_colored: Union[None, Image.Image]
63
+ depth_latent: torch.Tensor
64
+ uncertainty: Union[None, np.ndarray]
65
+
66
+
67
+ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
68
+ """
69
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
70
+
71
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
72
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
73
+
74
+ Args:
75
+ unet (`UNet2DConditionModel`):
76
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
77
+ vae (`AutoencoderKL`):
78
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
79
+ to and from latent representations.
80
+ scheduler (`DDIMScheduler`):
81
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
82
+ text_encoder (`CLIPTextModel`):
83
+ Text-encoder, for empty text embedding.
84
+ tokenizer (`CLIPTokenizer`):
85
+ CLIP tokenizer.
86
+ """
87
+
88
+ rgb_latent_scale_factor = 0.18215
89
+ depth_latent_scale_factor = 0.18215
90
+
91
+ def __init__(
92
+ self,
93
+ unet: UNet2DConditionModel,
94
+ vae: AutoencoderKL,
95
+ scheduler: DDIMScheduler,
96
+ text_encoder: CLIPTextModel,
97
+ tokenizer: CLIPTokenizer,
98
+ ):
99
+ super().__init__()
100
+
101
+ self.register_modules(
102
+ unet=unet,
103
+ vae=vae,
104
+ scheduler=scheduler,
105
+ text_encoder=text_encoder,
106
+ tokenizer=tokenizer,
107
+ )
108
+
109
+ self.empty_text_embed = None
110
+
111
+ @torch.no_grad()
112
+ def __call__(
113
+ self,
114
+ input_image: Image,
115
+ denoising_steps: int = 1,
116
+ ensemble_size: int = 1,
117
+ processing_res: int = 768,
118
+ match_input_res: bool = True,
119
+ batch_size: int = 0,
120
+ depth_latent_init: torch.Tensor = None,
121
+ depth_latent_init_strength: float = 0.1,
122
+ seed: int = None,
123
+ color_map: str = "Spectral",
124
+ show_progress_bar: bool = True,
125
+ ensemble_kwargs: Dict = None,
126
+ ) -> MarigoldDepthConsistencyOutput:
127
+ """
128
+ Function invoked when calling the pipeline.
129
+
130
+ Args:
131
+ input_image (`Image`):
132
+ Input RGB (or gray-scale) image.
133
+ processing_res (`int`, *optional*, defaults to `768`):
134
+ Maximum resolution of processing.
135
+ If set to 0: will not resize at all.
136
+ match_input_res (`bool`, *optional*, defaults to `True`):
137
+ Resize depth prediction to match input resolution.
138
+ Only valid if `limit_input_res` is not None.
139
+ denoising_steps (`int`, *optional*, defaults to `1`):
140
+ Number of diffusion denoising steps (consistency) during inference.
141
+ ensemble_size (`int`, *optional*, defaults to `1`):
142
+ Number of predictions to be ensembled.
143
+ batch_size (`int`, *optional*, defaults to `0`):
144
+ Inference batch size, no bigger than `num_ensemble`.
145
+ If set to 0, the script will automatically decide the proper batch size.
146
+ depth_latent_init (`torch.Tensor`, *optional*, defaults to `None`):
147
+ Initial depth map latent for better temporal consistency.
148
+ depth_latent_init_strength (`float`, *optional*, defaults to `0.1`)
149
+ Degree of initial depth latent influence, must be between 0 and 1.
150
+ seed (`int`, *optional*, defaults to `None`)
151
+ Reproducibility seed.
152
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
153
+ Display a progress bar of diffusion denoising.
154
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
155
+ Colormap used to colorize the depth map.
156
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
157
+ Arguments for detailed ensembling settings.
158
+ Returns:
159
+ `MarigoldDepthConsistencyOutput`: Output class for Marigold monocular depth prediction pipeline, including:
160
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
161
+ - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
162
+ values in [0, 1]. None if `color_map` is `None`
163
+ - **depth_latent** (`torch.Tensor`) Predicted depth map latent
164
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
165
+ coming from ensembling. None if `ensemble_size = 1`
166
+ """
167
+
168
+ device = self.device
169
+ input_size = input_image.size
170
+
171
+ if not match_input_res:
172
+ assert (
173
+ processing_res is not None
174
+ ), "Value error: `resize_output_back` is only valid with "
175
+ assert processing_res >= 0, "Value error: `processing_res` must be non-negative"
176
+ assert (
177
+ 1 <= denoising_steps <= 10
178
+ ), "Value error: This model degrades with large number of steps"
179
+ assert ensemble_size >= 1
180
+
181
+ # ----------------- Image Preprocess -----------------
182
+ # Resize image
183
+ if processing_res > 0:
184
+ input_image = self.resize_max_res(
185
+ input_image, max_edge_resolution=processing_res
186
+ )
187
+ # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
188
+ input_image = input_image.convert("RGB")
189
+ image = np.asarray(input_image)
190
+
191
+ # Normalize rgb values
192
+ rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
193
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
194
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
195
+ rgb_norm = rgb_norm.to(device)
196
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
197
+
198
+ # ----------------- Predicting depth -----------------
199
+ # Batch repeated input image
200
+ duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
201
+ batch_dataset = TensorDataset(duplicated_rgb)
202
+ if batch_size > 0:
203
+ _bs = batch_size
204
+ else:
205
+ _bs = self._find_batch_size(
206
+ ensemble_size=ensemble_size,
207
+ input_res=max(duplicated_rgb.shape[-2:]),
208
+ dtype=self.dtype,
209
+ )
210
+
211
+ batch_loader = DataLoader(batch_dataset, batch_size=_bs, shuffle=False)
212
+
213
+ # Predict depth maps (batched)
214
+ depth_pred_ls = []
215
+ if show_progress_bar:
216
+ iterable = tqdm(
217
+ batch_loader, desc=" " * 2 + "Inference batches", leave=False
218
+ )
219
+ else:
220
+ iterable = batch_loader
221
+ depth_latent = None
222
+ for batch in iterable:
223
+ (batched_img,) = batch
224
+ depth_pred_raw, depth_latent = self.single_infer(
225
+ rgb_in=batched_img,
226
+ num_inference_steps=denoising_steps,
227
+ depth_latent_init=depth_latent_init,
228
+ depth_latent_init_strength=depth_latent_init_strength,
229
+ seed=seed,
230
+ show_pbar=show_progress_bar,
231
+ )
232
+ depth_pred_ls.append(depth_pred_raw.detach())
233
+ depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
234
+ torch.cuda.empty_cache() # clear vram cache for ensembling
235
+
236
+ # ----------------- Test-time ensembling -----------------
237
+ if ensemble_size > 1:
238
+ depth_pred, pred_uncert = self.ensemble_depths(
239
+ depth_preds, **(ensemble_kwargs or {})
240
+ )
241
+ else:
242
+ depth_pred = depth_preds
243
+ pred_uncert = None
244
+
245
+ # ----------------- Post processing -----------------
246
+ # Scale prediction to [0, 1]
247
+ min_d = torch.min(depth_pred)
248
+ max_d = torch.max(depth_pred)
249
+ depth_pred = (depth_pred - min_d) / (max_d - min_d)
250
+ if ensemble_size > 1:
251
+ depth_latent = self._encode_depth(2 * depth_pred - 1)
252
+
253
+ # Convert to numpy
254
+ depth_pred = depth_pred.cpu().numpy().astype(np.float32)
255
+
256
+ # Resize back to original resolution
257
+ if match_input_res:
258
+ pred_img = Image.fromarray(depth_pred)
259
+ pred_img = pred_img.resize(input_size)
260
+ depth_pred = np.asarray(pred_img)
261
+
262
+ # Clip output range
263
+ depth_pred = depth_pred.clip(0, 1)
264
+
265
+ # Colorize
266
+ if color_map is not None:
267
+ depth_colored = self.colorize_depth_maps(
268
+ depth_pred, 0, 1, cmap=color_map
269
+ ).squeeze() # [3, H, W], value in (0, 1)
270
+ depth_colored = (depth_colored * 255).astype(np.uint8)
271
+ depth_colored_hwc = self.chw2hwc(depth_colored)
272
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
273
+ else:
274
+ depth_colored_img = None
275
+ return MarigoldDepthConsistencyOutput(
276
+ depth_np=depth_pred,
277
+ depth_colored=depth_colored_img,
278
+ depth_latent=depth_latent,
279
+ uncertainty=pred_uncert,
280
+ )
281
+
282
+ def _encode_empty_text(self):
283
+ """
284
+ Encode text embedding for empty prompt.
285
+ """
286
+ prompt = ""
287
+ text_inputs = self.tokenizer(
288
+ prompt,
289
+ padding="do_not_pad",
290
+ max_length=self.tokenizer.model_max_length,
291
+ truncation=True,
292
+ return_tensors="pt",
293
+ )
294
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
295
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
296
+
297
+ @torch.no_grad()
298
+ def single_infer(
299
+ self,
300
+ rgb_in: torch.Tensor,
301
+ num_inference_steps: int,
302
+ depth_latent_init: torch.Tensor,
303
+ depth_latent_init_strength: float,
304
+ seed: int,
305
+ show_pbar: bool,
306
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
307
+ """
308
+ Perform an individual depth prediction without ensembling.
309
+
310
+ Args:
311
+ rgb_in (`torch.Tensor`):
312
+ Input RGB image.
313
+ num_inference_steps (`int`):
314
+ Number of diffusion denoisign steps (DDIM) during inference.
315
+ depth_latent_init (`torch.Tensor`, `optional`):
316
+ Initial depth latent
317
+ depth_latent_init_strength (`float`, `optional`):
318
+ Degree of initial depth latent influence, must be between 0 and 1
319
+ seed (`int`, *optional*, defaults to `None`)
320
+ Reproducibility seed.
321
+ show_pbar (`bool`):
322
+ Display a progress bar of diffusion denoising.
323
+ Returns:
324
+ `torch.Tensor`: Predicted depth map.
325
+ """
326
+ device = rgb_in.device
327
+
328
+ # Set timesteps
329
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
330
+ timesteps = self.scheduler.timesteps # [T]
331
+
332
+ # Encode image
333
+ rgb_latent = self._encode_rgb(rgb_in)
334
+
335
+ # Initial depth map (noise)
336
+ if seed is None:
337
+ rng = None
338
+ else:
339
+ rng = torch.Generator(device=device)
340
+ rng.manual_seed(seed)
341
+ depth_latent = torch.randn(
342
+ rgb_latent.shape, device=device, dtype=self.dtype, generator=rng
343
+ ) # [B, 4, h, w]
344
+
345
+ if depth_latent_init is not None:
346
+ assert 0.0 <= depth_latent_init_strength <= 1.0
347
+ assert (
348
+ depth_latent_init.dim() == 4
349
+ and depth_latent.dim() == 4
350
+ and depth_latent_init.shape[0] == 1
351
+ )
352
+ if depth_latent.shape[0] != 1:
353
+ depth_latent_init = depth_latent_init.repeat(
354
+ depth_latent.shape[0], 1, 1, 1
355
+ )
356
+ depth_latent *= 1.0 - depth_latent_init_strength
357
+ depth_latent = depth_latent + depth_latent_init * depth_latent_init_strength
358
+
359
+ # Batched empty text embedding
360
+ if self.empty_text_embed is None:
361
+ self._encode_empty_text()
362
+ batch_empty_text_embed = self.empty_text_embed.repeat(
363
+ (rgb_latent.shape[0], 1, 1)
364
+ ) # [B, 2, 1024]
365
+
366
+ # Denoising loop
367
+ if show_pbar:
368
+ iterable = tqdm(
369
+ enumerate(timesteps),
370
+ total=len(timesteps),
371
+ leave=False,
372
+ desc=" " * 4 + "Diffusion denoising",
373
+ )
374
+ else:
375
+ iterable = enumerate(timesteps)
376
+
377
+ for i, t in iterable:
378
+ unet_input = torch.cat(
379
+ [rgb_latent, depth_latent], dim=1
380
+ ) # this order is important
381
+
382
+ # predict the noise residual
383
+ noise_pred = self.unet(
384
+ unet_input, t, encoder_hidden_states=batch_empty_text_embed
385
+ ).sample # [B, 4, h, w]
386
+
387
+ # compute the previous noisy sample x_t -> x_t-1
388
+ depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
389
+
390
+ depth = self._decode_depth(depth_latent)
391
+
392
+ # clip prediction
393
+ depth = torch.clip(depth, -1.0, 1.0)
394
+ # shift to [0, 1]
395
+ depth = (depth + 1.0) / 2.0
396
+
397
+ return depth, depth_latent
398
+
399
+ def _encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
400
+ """
401
+ Encode depth image into latent.
402
+
403
+ Args:
404
+ depth_in (`torch.Tensor`):
405
+ Input Depth image to be encoded.
406
+
407
+ Returns:
408
+ `torch.Tensor`: Depth latent.
409
+ """
410
+ # encode
411
+ dims = depth_in.squeeze().shape
412
+ h = self.vae.encoder(depth_in.reshape(1, 1, *dims).repeat(1, 3, 1, 1))
413
+ moments = self.vae.quant_conv(h)
414
+ mean, _ = torch.chunk(moments, 2, dim=1)
415
+ depth_latent = mean * self.depth_latent_scale_factor
416
+ return depth_latent
417
+
418
+ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
419
+ """
420
+ Encode RGB image into latent.
421
+
422
+ Args:
423
+ rgb_in (`torch.Tensor`):
424
+ Input RGB image to be encoded.
425
+
426
+ Returns:
427
+ `torch.Tensor`: Image latent.
428
+ """
429
+ # encode
430
+ h = self.vae.encoder(rgb_in)
431
+ moments = self.vae.quant_conv(h)
432
+ mean, logvar = torch.chunk(moments, 2, dim=1)
433
+ # scale latent
434
+ rgb_latent = mean * self.rgb_latent_scale_factor
435
+ return rgb_latent
436
+
437
+ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
438
+ """
439
+ Decode depth latent into depth map.
440
+
441
+ Args:
442
+ depth_latent (`torch.Tensor`):
443
+ Depth latent to be decoded.
444
+
445
+ Returns:
446
+ `torch.Tensor`: Decoded depth map.
447
+ """
448
+ # scale latent
449
+ depth_latent = depth_latent / self.depth_latent_scale_factor
450
+ # decode
451
+ z = self.vae.post_quant_conv(depth_latent)
452
+ stacked = self.vae.decoder(z)
453
+ # mean of output channels
454
+ depth_mean = stacked.mean(dim=1, keepdim=True)
455
+ return depth_mean
456
+
457
+ @staticmethod
458
+ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
459
+ """
460
+ Resize image to limit maximum edge length while keeping aspect ratio.
461
+
462
+ Args:
463
+ img (`Image.Image`):
464
+ Image to be resized.
465
+ max_edge_resolution (`int`):
466
+ Maximum edge length (pixel).
467
+
468
+ Returns:
469
+ `Image.Image`: Resized image.
470
+ """
471
+ original_width, original_height = img.size
472
+ downscale_factor = min(
473
+ max_edge_resolution / original_width, max_edge_resolution / original_height
474
+ )
475
+
476
+ new_width = int(original_width * downscale_factor)
477
+ new_height = int(original_height * downscale_factor)
478
+
479
+ resized_img = img.resize((new_width, new_height))
480
+ return resized_img
481
+
482
+ @staticmethod
483
+ def colorize_depth_maps(
484
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
485
+ ):
486
+ """
487
+ Colorize depth maps.
488
+ """
489
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
490
+
491
+ if isinstance(depth_map, torch.Tensor):
492
+ depth = depth_map.detach().squeeze().numpy()
493
+ elif isinstance(depth_map, np.ndarray):
494
+ depth = depth_map.copy().squeeze()
495
+ # reshape to [ (B,) H, W ]
496
+ if depth.ndim < 3:
497
+ depth = depth[np.newaxis, :, :]
498
+
499
+ # colorize
500
+ cm = matplotlib.colormaps[cmap]
501
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
502
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
503
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
504
+
505
+ if valid_mask is not None:
506
+ if isinstance(depth_map, torch.Tensor):
507
+ valid_mask = valid_mask.detach().numpy()
508
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
509
+ if valid_mask.ndim < 3:
510
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
511
+ else:
512
+ valid_mask = valid_mask[:, np.newaxis, :, :]
513
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
514
+ img_colored_np[~valid_mask] = 0
515
+
516
+ if isinstance(depth_map, torch.Tensor):
517
+ img_colored = torch.from_numpy(img_colored_np).float()
518
+ elif isinstance(depth_map, np.ndarray):
519
+ img_colored = img_colored_np
520
+
521
+ return img_colored
522
+
523
+ @staticmethod
524
+ def chw2hwc(chw):
525
+ assert 3 == len(chw.shape)
526
+ if isinstance(chw, torch.Tensor):
527
+ hwc = torch.permute(chw, (1, 2, 0))
528
+ elif isinstance(chw, np.ndarray):
529
+ hwc = np.moveaxis(chw, 0, -1)
530
+ return hwc
531
+
532
+ @staticmethod
533
+ def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
534
+ """
535
+ Automatically search for suitable operating batch size.
536
+
537
+ Args:
538
+ ensemble_size (`int`):
539
+ Number of predictions to be ensembled.
540
+ input_res (`int`):
541
+ Operating resolution of the input image.
542
+
543
+ Returns:
544
+ `int`: Operating batch size.
545
+ """
546
+ # Search table for suggested max. inference batch size
547
+ bs_search_table = [
548
+ # tested on A100-PCIE-80GB
549
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
550
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
551
+ # tested on A100-PCIE-40GB
552
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
553
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
554
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
555
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
556
+ # tested on RTX3090, RTX4090
557
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
558
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
559
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
560
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
561
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
562
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
563
+ # tested on GTX1080Ti
564
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
565
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
566
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
567
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
568
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
569
+ ]
570
+
571
+ if not torch.cuda.is_available():
572
+ return 1
573
+
574
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
575
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
576
+ for settings in sorted(
577
+ filtered_bs_search_table,
578
+ key=lambda k: (k["res"], -k["total_vram"]),
579
+ ):
580
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
581
+ bs = settings["bs"]
582
+ if bs > ensemble_size:
583
+ bs = ensemble_size
584
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
585
+ bs = math.ceil(ensemble_size / 2)
586
+ return bs
587
+
588
+ return 1
589
+
590
+ @staticmethod
591
+ def ensemble_depths(
592
+ input_images: torch.Tensor,
593
+ regularizer_strength: float = 0.02,
594
+ max_iter: int = 2,
595
+ tol: float = 1e-3,
596
+ reduction: str = "median",
597
+ max_res: int = None,
598
+ ):
599
+ """
600
+ To ensemble multiple affine-invariant depth images (up to scale and shift),
601
+ by aligning estimating the scale and shift
602
+ """
603
+
604
+ def inter_distances(tensors: torch.Tensor):
605
+ """
606
+ To calculate the distance between each two depth maps.
607
+ """
608
+ distances = []
609
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
610
+ arr1 = tensors[i : i + 1]
611
+ arr2 = tensors[j : j + 1]
612
+ distances.append(arr1 - arr2)
613
+ dist = torch.concatenate(distances, dim=0)
614
+ return dist
615
+
616
+ device = input_images.device
617
+ dtype = input_images.dtype
618
+ np_dtype = np.float32
619
+
620
+ original_input = input_images.clone()
621
+ n_img = input_images.shape[0]
622
+ ori_shape = input_images.shape
623
+
624
+ if max_res is not None:
625
+ scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
626
+ if scale_factor < 1:
627
+ downscaler = torch.nn.Upsample(
628
+ scale_factor=scale_factor, mode="nearest"
629
+ )
630
+ input_images = downscaler(torch.from_numpy(input_images)).numpy()
631
+
632
+ # init guess
633
+ _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
634
+ _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
635
+ s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
636
+ t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
637
+ x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
638
+
639
+ input_images = input_images.to(device)
640
+
641
+ # objective function
642
+ def closure(x):
643
+ l = len(x)
644
+ s = x[: int(l / 2)]
645
+ t = x[int(l / 2) :]
646
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
647
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
648
+
649
+ transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
650
+ dists = inter_distances(transformed_arrays)
651
+ sqrt_dist = torch.sqrt(torch.mean(dists**2))
652
+
653
+ if "mean" == reduction:
654
+ pred = torch.mean(transformed_arrays, dim=0)
655
+ elif "median" == reduction:
656
+ pred = torch.median(transformed_arrays, dim=0).values
657
+ else:
658
+ raise ValueError
659
+
660
+ near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
661
+ far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
662
+
663
+ err = sqrt_dist + (near_err + far_err) * regularizer_strength
664
+ err = err.detach().cpu().numpy().astype(np_dtype)
665
+ return err
666
+
667
+ res = minimize(
668
+ closure,
669
+ x,
670
+ method="BFGS",
671
+ tol=tol,
672
+ options={"maxiter": max_iter, "disp": False},
673
+ )
674
+ x = res.x
675
+ l = len(x)
676
+ s = x[: int(l / 2)]
677
+ t = x[int(l / 2) :]
678
+
679
+ # Prediction
680
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
681
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
682
+ transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
683
+ if "mean" == reduction:
684
+ aligned_images = torch.mean(transformed_arrays, dim=0)
685
+ std = torch.std(transformed_arrays, dim=0)
686
+ uncertainty = std
687
+ elif "median" == reduction:
688
+ aligned_images = torch.median(transformed_arrays, dim=0).values
689
+ # MAD (median absolute deviation) as uncertainty indicator
690
+ abs_dev = torch.abs(transformed_arrays - aligned_images)
691
+ mad = torch.median(abs_dev, dim=0).values
692
+ uncertainty = mad
693
+ else:
694
+ raise ValueError(f"Unknown reduction method: {reduction}")
695
+
696
+ # Scale and shift to [0, 1]
697
+ _min = torch.min(aligned_images)
698
+ _max = torch.max(aligned_images)
699
+ aligned_images = (aligned_images - _min) / (_max - _min)
700
+ uncertainty /= _max - _min
701
+
702
+ return aligned_images, uncertainty
marigold_logo_square.jpg ADDED

Git LFS Details

  • SHA256: bd5f1e527678fc913aee17ab69831551cfdb2934f673e9e97a7f011103b63c9e
  • Pointer size: 130 Bytes
  • Size of remote file: 76 kB
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.22.0
2
+ gradio-imageslider==0.0.16
3
+ pygltflib==1.16.1
4
+ trimesh==4.0.5
5
+ imageio
6
+ imageio-ffmpeg
7
+ Pillow
8
+
9
+ accelerate>=0.22.0
10
+ diffusers==0.27.2
11
+ matplotlib==3.8.2
12
+ scipy==1.11.4
13
+ torch==2.0.1
14
+ transformers>=4.32.1
15
+ xformers>=0.0.21