toshas committed
Commit e926577 · 1 Parent(s): 558cb45

add header link to self for embedding


add badge for diffusers tutorial
bump to the latest diffusers

Files changed (4)
  1. app.py +27 -44
  2. marigold_iid_appearance.py +0 -561
  3. marigold_iid_lighting.py +0 -576
  4. requirements.txt +1 -1
app.py CHANGED
@@ -30,22 +30,16 @@
 # --------------------------------------------------------------------------
 
 import os
-
-import numpy as np
-
 os.system("pip freeze")
 import spaces
 
 import gradio as gr
 import torch as torch
-from diffusers import DDIMScheduler
+from diffusers import MarigoldIntrinsicsPipeline, DDIMScheduler
 from gradio_dualvision import DualVisionApp
 from huggingface_hub import login
 from PIL import Image
 
-from marigold_iid_appearance import MarigoldIIDAppearancePipeline
-from marigold_iid_lighting import MarigoldIIDLightingPipeline
-
 CHECKPOINT_APPEARANCE = "prs-eth/marigold-iid-appearance-v1-1"
 CHECKPOINT_LIGHTING = "prs-eth/marigold-iid-lighting-v1-1"
 
@@ -55,19 +49,11 @@ if "HF_TOKEN_LOGIN" in os.environ:
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-pipe_appearance = MarigoldIIDAppearancePipeline.from_pretrained(
-    "prs-eth/marigold-iid-appearance-v1-1"
-)
-pipe_appearance.scheduler = DDIMScheduler.from_config(
-    pipe_appearance.scheduler.config, timestep_spacing="trailing"
-)
+pipe_appearance = MarigoldIntrinsicsPipeline.from_pretrained(CHECKPOINT_APPEARANCE)
+pipe_appearance.scheduler = DDIMScheduler.from_config(pipe_appearance.scheduler.config, timestep_spacing="trailing")
 pipe_appearance = pipe_appearance.to(device=device, dtype=dtype)
-pipe_lighting = MarigoldIIDLightingPipeline.from_pretrained(
-    "prs-eth/marigold-iid-lighting-v1-1"
-)
-pipe_lighting.scheduler = DDIMScheduler.from_config(
-    pipe_lighting.scheduler.config, timestep_spacing="trailing"
-)
+pipe_lighting = MarigoldIntrinsicsPipeline.from_pretrained(CHECKPOINT_LIGHTING)
+pipe_lighting.scheduler = DDIMScheduler.from_config(pipe_lighting.scheduler.config, timestep_spacing="trailing")
 pipe_lighting = pipe_lighting.to(device=device, dtype=dtype)
 try:
     import xformers
@@ -87,7 +73,7 @@ class MarigoldIIDApp(DualVisionApp):
     def make_header(self):
        gr.Markdown(
            """
-            ## Marigold Intrinsic Image Decomposition
+            ## [Marigold Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-intrinsics)
            """
        )
        with gr.Row(elem_classes="remove-elements"):
@@ -97,6 +83,9 @@
                <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
                </a>
+                <a title="diffusers" href="https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://img.shields.io/badge/%F0%9F%A7%A8%20Read_diffusers-tutorial-yellow?labelColor=green">
+                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
                </a>
@@ -156,7 +145,7 @@ class MarigoldIIDApp(DualVisionApp):
         ensemble_size = kwargs.get("ensemble_size", self.DEFAULT_ENSEMBLE_SIZE)
         denoise_steps = kwargs.get("denoise_steps", self.DEFAULT_DENOISE_STEPS)
         processing_res = kwargs.get("processing_res", self.DEFAULT_PROCESSING_RES)
-        # generator = torch.Generator(device=device).manual_seed(self.DEFAULT_SEED)
+        generator = torch.Generator(device=device).manual_seed(self.DEFAULT_SEED)
 
         pipe_out_appearance = pipe_appearance(
             image_in,
@@ -165,19 +154,12 @@
             processing_resolution=processing_res,
             batch_size=1 if processing_res == 0 else 2,
             output_uncertainty=ensemble_size >= 3,
-            # generator=generator,
-            seed=self.DEFAULT_SEED,
+            generator=generator,
         )
 
-        roughness = pipe_out_appearance.material[0].clip(-1, 1)
-        roughness = (roughness + 1.0) * 0.5
-        roughness = (roughness * 65535).astype(np.uint16)
-        roughness = Image.fromarray(roughness, mode="I;16")
-
-        metallicity = pipe_out_appearance.material[1].clip(-1, 1)
-        metallicity = (metallicity + 1.0) * 0.5
-        metallicity = (metallicity * 65535).astype(np.uint16)
-        metallicity = Image.fromarray(metallicity, mode="I;16")
+        iid_appearance_vis = pipe_appearance.image_processor.visualize_intrinsics(
+            pipe_out_appearance.prediction, pipe_appearance.target_properties
+        )
 
         pipe_out_lighting = pipe_lighting(
             image_in,
@@ -186,22 +168,23 @@
             processing_resolution=processing_res,
             batch_size=1 if processing_res == 0 else 2,
             output_uncertainty=ensemble_size >= 3,
-            # generator=generator,
-            seed=self.DEFAULT_SEED,
+            generator=generator,
+        )
+
+        iid_lighting_vis = pipe_lighting.image_processor.visualize_intrinsics(
+            pipe_out_lighting.prediction, pipe_lighting.target_properties
         )
 
         out_modalities = {
-            "Albedo": pipe_out_appearance.albedo_colored,
-            "Materials": pipe_out_appearance.material_colored,
-            "Roughness": roughness,
-            "Metallicity": metallicity,
-            "Albedo (HyperSim)": pipe_out_lighting.albedo_colored,
-            "Shading (HyperSim)": pipe_out_lighting.shading_colored,
-            "Residual (HyperSim)": pipe_out_lighting.residual_colored,
+            "Albedo": iid_appearance_vis[0]["albedo"],
+            "Materials": iid_appearance_vis[0]["material"],
+            "Roughness": iid_appearance_vis[0]["roughness"],
+            "Metallicity": iid_appearance_vis[0]["metallicity"],
+            "Albedo (HyperSim)": iid_lighting_vis[0]["albedo"],
+            "Shading (HyperSim)": iid_lighting_vis[0]["shading"],
+            "Residual (HyperSim)": iid_lighting_vis[0]["residual"],
         }
-        # if ensemble_size >= 3:
-        #     uncertainty = pipe.image_processor.visualize_uncertainty(pipe_out.uncertainty)[0]
-        #     out_modalities["Uncertainty"] = uncertainty
+        # Additionally, uncertainty can be computed on any of the output modalities; we skip it to keep the demo light
 
         out_settings = {
             "ensemble_size": ensemble_size,
marigold_iid_appearance.py DELETED
@@ -1,561 +0,0 @@
-# Copyright 2024 Anton Obukhov, Bingxin Ke, Bo Li & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# --------------------------------------------------------------------------
-# If you find this code useful, we kindly ask you to cite our paper in your work.
-# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
-# More information about the method can be found at https://marigoldcomputervision.github.io
-# --------------------------------------------------------------------------
-import logging
-import math
-from typing import Optional, Tuple, Union, Dict, Any
-
-import numpy as np
-import torch
-from diffusers import (
-    AutoencoderKL,
-    DDIMScheduler,
-    DiffusionPipeline,
-    UNet2DConditionModel,
-)
-from diffusers.utils import BaseOutput, check_min_version
-from PIL import Image
-from PIL.Image import Resampling
-from torch.utils.data import DataLoader, TensorDataset
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.27.0.dev0")
-
-
-class MarigoldIIDAppearanceOutput(BaseOutput):
-    """
-    Output class for Marigold IID Appearance pipeline.
-
-    Args:
-        albedo (`np.ndarray`):
-            Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1].
-        albedo_colored (`PIL.Image.Image`):
-            Colorized albedo map with the shape of [H, W, 3].
-        material (`np.ndarray`):
-            Predicted material map with the shape of [3, H, W] and values in [0, 1].
-            1st channel (Red) is roughness
-            2nd channel (Green) is metallicity
-            3rd channel (Blue) is empty (zero)
-        material_colored (`PIL.Image.Image`):
-            Colorized material map with the shape of [H, W, 3].
-            1st channel (Red) is roughness
-            2nd channel (Green) is metallicity
-            3rd channel (Blue) is empty (zero)
-    """
-
-    albedo: np.ndarray
-    albedo_colored: Image.Image
-    material: np.ndarray
-    material_colored: Image.Image
-
-
-class MarigoldIIDAppearancePipeline(DiffusionPipeline):
-    """
-    Pipeline for Intrinsic Image Decomposition (Albedo and Material) using Marigold: https://marigoldcomputervision.github.io.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Args:
-        unet (`UNet2DConditionModel`):
-            Conditional U-Net to denoise the normals latent, conditioned on image latent.
-        vae (`AutoencoderKL`):
-            Variational Auto-Encoder (VAE) Model to encode and decode images and normals maps
-            to and from latent representations.
-        scheduler (`DDIMScheduler`):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
-        text_encoder (`CLIPTextModel`):
-            Text-encoder, for empty text embedding.
-        tokenizer (`CLIPTokenizer`):
-            CLIP tokenizer.
-    """
-
-    latent_scale_factor = 0.18215
-
-    def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        vae: AutoencoderKL,
-        scheduler: DDIMScheduler,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        prediction_type: Optional[str] = None,
-        target_properties: Optional[Dict[str, Any]] = None,
-        default_denoising_steps: Optional[int] = None,
-        default_processing_resolution: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.register_modules(
-            unet=unet,
-            vae=vae,
-            scheduler=scheduler,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-        )
-        self.register_to_config(
-            prediction_type=prediction_type,
-            target_properties=target_properties,
-            default_denoising_steps=default_denoising_steps,
-            default_processing_resolution=default_processing_resolution,
-        )
-
-        self.empty_text_embed = None
-
-        self.n_targets = 2  # Albedo and material
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        input_image: Image,
-        denoising_steps: int = 4,
-        ensemble_size: int = 10,
-        processing_res: int = 768,
-        match_input_res: bool = True,
-        resample_method: str = "bilinear",
-        batch_size: int = 0,
-        save_memory: bool = False,
-        seed: Union[int, None] = None,
-        color_map: str = "Spectral",  # TODO change colorization api based on modality
-        show_progress_bar: bool = True,
-        **kwargs,
-    ) -> MarigoldIIDAppearanceOutput:
-        """
-        Function invoked when calling the pipeline.
-
-        Args:
-            input_image (`Image`):
-                Input RGB (or gray-scale) image.
-            denoising_steps (`int`, *optional*, defaults to `4`):
-                Number of diffusion denoising steps (DDIM) during inference.
-            ensemble_size (`int`, *optional*, defaults to `10`):
-                Number of predictions to be ensembled.
-            processing_res (`int`, *optional*, defaults to `768`):
-                Maximum resolution of processing.
-                If set to 0: will not resize at all.
-            match_input_res (`bool`, *optional*, defaults to `True`):
-                Resize normals prediction to match input resolution.
-                Only valid if `limit_input_res` is not None.
-            resample_method: (`str`, *optional*, defaults to `bilinear`):
-                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
-            batch_size (`int`, *optional*, defaults to `0`):
-                Inference batch size, no bigger than `num_ensemble`.
-                If set to 0, the script will automatically decide the proper batch size.
-            save_memory (`bool`, defaults to `False`):
-                Extra steps to save memory at the cost of performance.
-            seed (`int`, *optional*, defaults to `None`)
-                Reproducibility seed.
-            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized normals map generation):
-                Colormap used to colorize the normals map.
-            show_progress_bar (`bool`, *optional*, defaults to `True`):
-                Display a progress bar of diffusion denoising.
-        Returns:
-            `MarigoldIIDAppearanceOutput`: Output class for Marigold monocular intrinsic image decomposition (appearance) prediction pipeline, including:
-            - **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
-            - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [3, H, W] values in the range of [0, 1]
-            - **material** (`np.ndarray`) Predicted material map with the shape of [3, H, W] and values in [0, 1]
-            - **material_colored** (`PIL.Image.Image`) Colorized material map with the shape of [3, H, W] and values in [0, 1]
-        """
-
-        if not match_input_res:
-            assert processing_res is not None
-        assert processing_res >= 0
-        assert denoising_steps >= 1
-        assert ensemble_size >= 1
-
-        # Check if denoising step is reasonable
-        self.check_inference_step(denoising_steps)
-
-        resample_method: Resampling = self.get_pil_resample_method(resample_method)
-
-        W, H = input_image.size
-
-        if processing_res > 0:
-            input_image = self.resize_max_res(
-                input_image,
-                max_edge_resolution=processing_res,
-                resample_method=resample_method,
-            )
-        input_image = input_image.convert("RGB")
-        image = np.asarray(input_image)
-
-        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
-        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
-        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
-        rgb_norm = rgb_norm.to(self.device)
-        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0  # TODO remove this
-
-        def ensemble(
-            targets: torch.Tensor,
-            return_uncertainty: bool = False,
-            reduction="median",
-        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-            uncertainty = None
-            if reduction == "mean":
-                prediction = torch.mean(targets, dim=0, keepdim=True)
-                if return_uncertainty:
-                    uncertainty = torch.std(targets, dim=0, keepdim=True)
-            elif reduction == "median":
-                prediction = torch.median(targets, dim=0, keepdim=True).values
-                if return_uncertainty:
-                    uncertainty = torch.median(
-                        torch.abs(targets - prediction), dim=0, keepdim=True
-                    ).values
-            else:
-                raise ValueError(f"Unrecognized reduction method: {reduction}.")
-            return prediction, uncertainty
-
-        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
-        single_rgb_dataset = TensorDataset(duplicated_rgb)
-
-        if batch_size <= 0:
-            batch_size = self.find_batch_size(
-                ensemble_size=ensemble_size,
-                input_res=max(rgb_norm.shape[1:]),
-                dtype=self.dtype,
-            )
-
-        single_rgb_loader = DataLoader(
-            single_rgb_dataset, batch_size=batch_size, shuffle=False
-        )
-
-        target_pred_ls = []
-        iterable = single_rgb_loader
-        if show_progress_bar:
-            iterable = tqdm(
-                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
-            )
-
-        for batch in iterable:
-            (batched_img,) = batch
-            target_pred = self.single_infer(
-                rgb_in=batched_img,
-                num_inference_steps=denoising_steps,
-                seed=seed,
-                show_pbar=show_progress_bar,
-            )
-            target_pred = target_pred.detach()
-            if save_memory:
-                target_pred = target_pred.cpu()
-            target_pred_ls.append(target_pred.detach())
-
-        target_preds = torch.concat(target_pred_ls, dim=0)
-        pred_uncert = None
-
-        if save_memory:
-            torch.cuda.empty_cache()
-
-        if ensemble_size > 1:
-            final_pred, pred_uncert = ensemble(
-                target_preds, reduction="median", return_uncertainty=False
-            )
-        else:
-            final_pred = target_preds
-            pred_uncert = None
-
-        if match_input_res:
-            final_pred = torch.nn.functional.interpolate(
-                final_pred, (H, W), mode="bilinear"  # TODO: parameterize this method
-            )  # [1,3,H,W]
-
-            if pred_uncert is not None:
-                pred_uncert = torch.nn.functional.interpolate(
-                    pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
-                ).squeeze(
-                    1
-                )  # [1,H,W]
-
-        # Convert to numpy
-        final_pred = final_pred.squeeze()
-        final_pred = final_pred.cpu().float().numpy()
-
-        albedo = final_pred[0:3, :, :]
-        material = np.stack(
-            (final_pred[3, :, :], final_pred[4, :, :], final_pred[5, :, :]), axis=0
-        )
-
-        albedo_colored = (albedo + 1.0) * 0.5
-        albedo_colored = (albedo_colored * 255).astype(np.uint8)
-        albedo_colored = self.chw2hwc(albedo_colored)
-        albedo_colored_img = Image.fromarray(albedo_colored)
-
-        material_colored = (material + 1.0) * 0.5
-        material_colored = (material_colored * 255).astype(np.uint8)
-        material_colored = self.chw2hwc(material_colored)
-        material_colored_img = Image.fromarray(material_colored)
-
-        out = MarigoldIIDAppearanceOutput(
-            albedo=albedo,
-            albedo_colored=albedo_colored_img,
-            material=material,
-            material_colored=material_colored_img,
-        )
-
-        return out
-
-    def check_inference_step(self, n_step: int):
-        """
-        Check if denoising step is reasonable
-        Args:
-            n_step (`int`): denoising steps
-        """
-        assert n_step >= 1
-
-        if isinstance(self.scheduler, DDIMScheduler):
-            pass
-        else:
-            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
-
-    def encode_empty_text(self):
-        """
-        Encode text embedding for empty prompt.
-        """
-        prompt = ""
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="do_not_pad",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
-        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
-
-    @torch.no_grad()
-    def single_infer(
-        self,
-        rgb_in: torch.Tensor,
-        num_inference_steps: int,
-        seed: Union[int, None],
-        show_pbar: bool,
-    ) -> torch.Tensor:
-        """
-        Perform an individual iid prediction without ensembling.
-        """
-        device = rgb_in.device
-
-        # Set timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps  # [T]
-
-        # Encode image
-        rgb_latent = self.encode_rgb(rgb_in)
-
-        target_latent_shape = list(rgb_latent.shape)
-        target_latent_shape[
-            1
-        ] *= 2  # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
-
-        # Initialize prediction latent with noise
-        if seed is None:
-            rand_num_generator = None
-        else:
-            rand_num_generator = torch.Generator(device=device)
-            rand_num_generator.manual_seed(seed)
-        target_latents = torch.randn(
-            target_latent_shape,
-            device=device,
-            dtype=self.dtype,
-            generator=rand_num_generator,
-        )  # [B, 4, h, w]
-
-        # Batched empty text embedding
-        if self.empty_text_embed is None:
-            self.encode_empty_text()
-        batch_empty_text_embed = self.empty_text_embed.repeat(
-            (rgb_latent.shape[0], 1, 1)
-        )  # [B, 2, 1024]
-
-        # Denoising loop
-        if show_pbar:
-            iterable = tqdm(
-                enumerate(timesteps),
-                total=len(timesteps),
-                leave=False,
-                desc=" " * 4 + "Diffusion denoising",
-            )
-        else:
-            iterable = enumerate(timesteps)
-
-        for i, t in iterable:
-            unet_input = torch.cat(
-                [rgb_latent, target_latents], dim=1
-            )  # this order is important
-
-            # predict the noise residual
-            noise_pred = self.unet(
-                unet_input, t, encoder_hidden_states=batch_empty_text_embed
-            ).sample  # [B, 4, h, w]
-
-            # compute the previous noisy sample x_t -> x_t-1
-            target_latents = self.scheduler.step(
-                noise_pred, t, target_latents, generator=rand_num_generator
-            ).prev_sample
-
-            # torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
-
-        targets = self.decode_targets(target_latents)  # [B, 3, H, W]
-        targets = torch.clip(targets, -1.0, 1.0)
-
-        return targets
-
-    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
-        """
-        Encode RGB image into latent.
-
-        Args:
-            rgb_in (`torch.Tensor`):
-                Input RGB image to be encoded.
-
-        Returns:
-            `torch.Tensor`: Image latent.
-        """
-        # encode
-        h = self.vae.encoder(rgb_in)
-        moments = self.vae.quant_conv(h)
-        mean, logvar = torch.chunk(moments, 2, dim=1)
-        # scale latent
-        rgb_latent = mean * self.latent_scale_factor
-        return rgb_latent
-
-    def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
-        """
-        Decode target latent into target map.
-
-        Args:
-            target_latents (`torch.Tensor`):
-                Target latent to be decoded.
-
-        Returns:
-            `torch.Tensor`: Decoded target map.
-        """
-
-        assert target_latents.shape[1] == 8  # self.n_targets * 4
-
-        # scale latent
-        target_latents = target_latents / self.latent_scale_factor
-        # decode
-        targets = []
-        for i in range(self.n_targets):
-            latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
-            z = self.vae.post_quant_conv(latent)
-            stacked = self.vae.decoder(z)
-
-            targets.append(stacked)
-
-        return torch.cat(targets, dim=1)
-
-    @staticmethod
-    def get_pil_resample_method(method_str: str) -> Resampling:
-        resample_method_dic = {
-            "bilinear": Resampling.BILINEAR,
-            "bicubic": Resampling.BICUBIC,
-            "nearest": Resampling.NEAREST,
-        }
-        resample_method = resample_method_dic.get(method_str, None)
-        if resample_method is None:
-            raise ValueError(f"Unknown resampling method: {resample_method}")
-        else:
-            return resample_method
-
-    @staticmethod
-    def resize_max_res(
-        img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
-    ) -> Image.Image:
-        """
-        Resize image to limit maximum edge length while keeping aspect ratio.
-        """
-        original_width, original_height = img.size
-        downscale_factor = min(
-            max_edge_resolution / original_width, max_edge_resolution / original_height
-        )
-
-        new_width = int(original_width * downscale_factor)
-        new_height = int(original_height * downscale_factor)
-
-        resized_img = img.resize((new_width, new_height), resample=resample_method)
-        return resized_img
-
-    @staticmethod
-    def chw2hwc(chw):
-        assert 3 == len(chw.shape)
-        if isinstance(chw, torch.Tensor):
-            hwc = torch.permute(chw, (1, 2, 0))
-        elif isinstance(chw, np.ndarray):
-            hwc = np.moveaxis(chw, 0, -1)
-        return hwc
-
-    @staticmethod
-    def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
-        """
-        Automatically search for suitable operating batch size.
-
-        Args:
-            ensemble_size (`int`):
-                Number of predictions to be ensembled.
-            input_res (`int`):
-                Operating resolution of the input image.
-
-        Returns:
-            `int`: Operating batch size.
-        """
-        # Search table for suggested max. inference batch size
-        bs_search_table = [
-            # tested on A100-PCIE-80GB
-            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
-            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
-            # tested on A100-PCIE-40GB
-            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
-            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
-            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
-            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
-            # tested on RTX3090, RTX4090
-            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
-            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
-            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
-            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
-            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
-            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
-            # tested on GTX1080Ti
-            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
-            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
-            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
-            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
-            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
-        ]
-
-        if not torch.cuda.is_available():
-            return 1
-
-        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
-        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
-        for settings in sorted(
-            filtered_bs_search_table,
-            key=lambda k: (k["res"], -k["total_vram"]),
-        ):
-            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
-                bs = settings["bs"]
-                if bs > ensemble_size:
-                    bs = ensemble_size
-                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
-                    bs = math.ceil(ensemble_size / 2)
-                return bs
-
-        return 1
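
For readers skimming the deleted code: its core trick was running the diffusion model several times and reducing the stack of predictions per pixel. A toy, self-contained illustration of the median reduction and the median-absolute-deviation uncertainty used by the removed ensemble helper (random tensors stand in for real predictions):

import torch

preds = torch.randn(10, 6, 4, 4)  # 10 ensemble members, 6 target channels, toy 4x4 maps
prediction = torch.median(preds, dim=0, keepdim=True).values  # [1, 6, 4, 4]
# When requested, uncertainty was the median absolute deviation around that median:
uncertainty = torch.median(torch.abs(preds - prediction), dim=0, keepdim=True).values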
marigold_iid_lighting.py DELETED
@@ -1,576 +0,0 @@
-# Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# --------------------------------------------------------------------------
-# If you find this code useful, we kindly ask you to cite our paper in your work.
-# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
-# More information about the method can be found at https://marigoldcomputervision.github.io
-# --------------------------------------------------------------------------
-import logging
-import math
-from typing import Optional, Tuple, Union, Dict, Any
-
-import numpy as np
-import torch
-from diffusers import (
-    AutoencoderKL,
-    DDIMScheduler,
-    DiffusionPipeline,
-    UNet2DConditionModel,
-)
-from diffusers.utils import BaseOutput, check_min_version
-from PIL import Image
-from PIL.Image import Resampling
-from torch.utils.data import DataLoader, TensorDataset
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.27.0.dev0")
-
-
-class MarigoldIIDLightingOutput(BaseOutput):
-    """
-    Output class for Marigold-IID-Lighting pipeline.
-
-    Args:
-        albedo (`np.ndarray`):
-            Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1].
-        albedo_colored (`PIL.Image.Image`):
-            Colorized albedo map with the shape of [H, W, 3].
-        shading (`np.ndarray`):
-            Predicted diffuse shading map with the shape of [3, H, W] values in the range of [0, 1].
-        shading_colored (`PIL.Image.Image`):
-            Colorized diffuse shading map with the shape of [H, W, 3].
-        residual (`np.ndarray`):
-            Predicted non-diffuse residual map with the shape of [3, H, W] values in the range of [0, 1].
-        residual_colored (`PIL.Image.Image`):
-            Colorized non-diffuse residual map with the shape of [H, W, 3].
-
-    """
-
-    albedo: np.ndarray
-    albedo_colored: Image.Image
-    shading: np.ndarray
-    shading_colored: Image.Image
-    residual: np.ndarray
-    residual_colored: Image.Image
-
-
-class MarigoldIIDLightingPipeline(DiffusionPipeline):
-    """
-    Pipeline for Intrinsic Image Decomposition (Albedo, diffuse shading and non-diffuse residual) using Marigold: https://marigoldcomputervision.github.io.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Args:
-        unet (`UNet2DConditionModel`):
-            Conditional U-Net to denoise the normals latent, conditioned on image latent.
-        vae (`AutoencoderKL`):
-            Variational Auto-Encoder (VAE) Model to encode and decode images and normals maps
-            to and from latent representations.
-        scheduler (`DDIMScheduler`):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
-        text_encoder (`CLIPTextModel`):
-            Text-encoder, for empty text embedding.
-        tokenizer (`CLIPTokenizer`):
-            CLIP tokenizer.
-    """
-
-    latent_scale_factor = 0.18215
-
-    def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        vae: AutoencoderKL,
-        scheduler: DDIMScheduler,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        prediction_type: Optional[str] = None,
-        target_properties: Optional[Dict[str, Any]] = None,
-        default_denoising_steps: Optional[int] = None,
-        default_processing_resolution: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.register_modules(
-            unet=unet,
-            vae=vae,
-            scheduler=scheduler,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-        )
-        self.register_to_config(
-            prediction_type=prediction_type,
-            target_properties=target_properties,
-            default_denoising_steps=default_denoising_steps,
-            default_processing_resolution=default_processing_resolution,
-        )
-
-        self.empty_text_embed = None
-        self.n_targets = 3  # Albedo, shading, residual
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        input_image: Image,
-        denoising_steps: int = 4,
-        ensemble_size: int = 10,
-        processing_res: int = 768,
-        match_input_res: bool = True,
-        resample_method: str = "bilinear",
-        batch_size: int = 0,
-        save_memory: bool = False,
-        seed: Union[int, None] = None,
-        color_map: str = "Spectral",  # TODO change colorization api based on modality
-        show_progress_bar: bool = True,
-        **kwargs,
-    ) -> MarigoldIIDLightingOutput:
-        """
-        Function invoked when calling the pipeline.
-
-        Args:
-            input_image (`Image`):
-                Input RGB (or gray-scale) image.
-            denoising_steps (`int`, *optional*, defaults to `4`):
-                Number of diffusion denoising steps (DDIM) during inference.
-            ensemble_size (`int`, *optional*, defaults to `10`):
-                Number of predictions to be ensembled.
-            processing_res (`int`, *optional*, defaults to `768`):
-                Maximum resolution of processing.
-                If set to 0: will not resize at all.
-            match_input_res (`bool`, *optional*, defaults to `True`):
-                Resize normals prediction to match input resolution.
-                Only valid if `limit_input_res` is not None.
-            resample_method: (`str`, *optional*, defaults to `bilinear`):
-                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
-            batch_size (`int`, *optional*, defaults to `0`):
-                Inference batch size, no bigger than `num_ensemble`.
-                If set to 0, the script will automatically decide the proper batch size.
-            save_memory (`bool`, defaults to `False`):
-                Extra steps to save memory at the cost of performance.
-            seed (`int`, *optional*, defaults to `None`)
-                Reproducibility seed.
-            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized normals map generation):
-                Colormap used to colorize the normals map.
-            show_progress_bar (`bool`, *optional*, defaults to `True`):
-                Display a progress bar of diffusion denoising.
-        Returns:
-            `MarigoldIIDLightingOutput`: Output class for Marigold monocular intrinsic image decomposition (lighting) prediction pipeline, including:
-            - **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
-            - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map
-            - **shading** (`np.ndarray`) Predicted diffuse shading map with the shape of [3, H, W] values in the range of [0, 1]
-            - **shading_colored** (`PIL.Image.Image`) Colorized diffuse shading map
-            - **residual** (`np.ndarray`) Predicted non-diffuse residual map with the shape of [3, H, W] values in the range of [0, 1]
-            - **residual_colored** (`PIL.Image.Image`) Colorized non-diffuse residual map
-        """
-
-        if not match_input_res:
-            assert processing_res is not None
-        assert processing_res >= 0
-        assert denoising_steps >= 1
-        assert ensemble_size >= 1
-
-        # Check if denoising step is reasonable
-        self.check_inference_step(denoising_steps)
-
-        resample_method: Resampling = self.get_pil_resample_method(resample_method)
-
-        W, H = input_image.size
-
-        if processing_res > 0:
-            input_image = self.resize_max_res(
-                input_image,
-                max_edge_resolution=processing_res,
-                resample_method=resample_method,
-            )
-        input_image = input_image.convert("RGB")
-        image = np.asarray(input_image)
-
-        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
-        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
-        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
-        rgb_norm = rgb_norm.to(self.device)
-        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0  # TODO remove this
-
-        def ensemble(
-            targets: torch.Tensor,
-            return_uncertainty: bool = False,
-            reduction="median",
-        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-            uncertainty = None
-            if reduction == "mean":
-                prediction = torch.mean(targets, dim=0, keepdim=True)
-                if return_uncertainty:
-                    uncertainty = torch.std(targets, dim=0, keepdim=True)
-            elif reduction == "median":
-                prediction = torch.median(targets, dim=0, keepdim=True).values
-                if return_uncertainty:
-                    uncertainty = torch.median(
-                        torch.abs(targets - prediction), dim=0, keepdim=True
-                    ).values
-            else:
-                raise ValueError(f"Unrecognized reduction method: {reduction}.")
-            return prediction, uncertainty
-
-        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
-        single_rgb_dataset = TensorDataset(duplicated_rgb)
-
-        if batch_size <= 0:
-            batch_size = self.find_batch_size(
-                ensemble_size=ensemble_size,
-                input_res=max(rgb_norm.shape[1:]),
-                dtype=self.dtype,
-            )
-
-        single_rgb_loader = DataLoader(
-            single_rgb_dataset, batch_size=batch_size, shuffle=False
-        )
-
-        target_pred_ls = []
-        iterable = single_rgb_loader
-        if show_progress_bar:
-            iterable = tqdm(
-                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
-            )
-
-        for batch in iterable:
-            (batched_img,) = batch
-            target_pred = self.single_infer(
-                rgb_in=batched_img,
-                num_inference_steps=denoising_steps,
-                seed=seed,
-                show_pbar=show_progress_bar,
-            )
-            target_pred = target_pred.detach()
-            if save_memory:
-                target_pred = target_pred.cpu()
-            target_pred_ls.append(target_pred.detach())
-
-        target_preds = torch.concat(target_pred_ls, dim=0)
-        pred_uncert = None
-
-        if save_memory:
-            torch.cuda.empty_cache()
-
-        if ensemble_size > 1:
-            final_pred, pred_uncert = ensemble(
-                target_preds, reduction="median", return_uncertainty=False
-            )
-        else:
-            final_pred = target_preds
-            pred_uncert = None
-
-        if match_input_res:
-            final_pred = torch.nn.functional.interpolate(
-                final_pred, (H, W), mode="bilinear"  # TODO: parameterize this method
-            )  # [1,3,H,W]
-
-            if pred_uncert is not None:
-                pred_uncert = torch.nn.functional.interpolate(
-                    pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
-                ).squeeze(
-                    1
-                )  # [1,H,W]
-
-        # Convert to numpy
-        final_pred = final_pred.squeeze()
-        final_pred = final_pred.cpu().float().numpy()
-
-        albedo = final_pred[0:3, :, :]
-        shading = final_pred[3:6, :, :]
-        residual = final_pred[6:, :, :]
-
-        albedo_colored = (albedo + 1.0) * 0.5  # [-1,1] -> [0,1]
-        albedo_colored = albedo_colored ** (
-            1 / 2.2
-        )  # from linear to sRGB (to be consistent with IID-Appearance model)
-        albedo_colored = (albedo_colored * 255).astype(np.uint8)
-        albedo_colored = self.chw2hwc(albedo_colored)
-        albedo_colored_img = Image.fromarray(albedo_colored)
-
-        shading_colored = (shading + 1.0) * 0.5
-        shading_colored = (
-            shading_colored / shading_colored.max()
-        )  # rescale for better visualization
-        shading_colored = (shading_colored * 255).astype(np.uint8)
-        shading_colored = self.chw2hwc(shading_colored)
-        shading_colored_img = Image.fromarray(shading_colored)
-
-        residual_colored = (residual + 1.0) * 0.5
-        residual_colored = (
-            residual_colored / residual_colored.max()
-        )  # rescale for better visualization
-        residual_colored = (residual_colored * 255).astype(np.uint8)
-        residual_colored = self.chw2hwc(residual_colored)
-        residual_colored_img = Image.fromarray(residual_colored)
-
-        out = MarigoldIIDLightingOutput(
-            albedo=albedo,
-            albedo_colored=albedo_colored_img,
-            shading=shading,
-            shading_colored=shading_colored_img,
-            residual=residual,
-            residual_colored=residual_colored_img,
-        )
-
-        return out
-
-    def check_inference_step(self, n_step: int):
-        """
-        Check if denoising step is reasonable
-        Args:
-            n_step (`int`): denoising steps
-        """
-        assert n_step >= 1
-
-        if isinstance(self.scheduler, DDIMScheduler):
-            pass
-        else:
-            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
-
-    def encode_empty_text(self):
-        """
-        Encode text embedding for empty prompt.
-        """
-        prompt = ""
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="do_not_pad",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
-        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
-
-    @torch.no_grad()
-    def single_infer(
-        self,
-        rgb_in: torch.Tensor,
-        num_inference_steps: int,
-        seed: Union[int, None],
-        show_pbar: bool,
-    ) -> torch.Tensor:
-        """
-        Perform an individual iid prediction without ensembling.
-        """
-        device = rgb_in.device
-
-        # Set timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps  # [T]
-
-        # Encode image
-        rgb_latent = self.encode_rgb(rgb_in)
-
-        target_latent_shape = list(rgb_latent.shape)
-        target_latent_shape[
-            1
-        ] *= 3  # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
-
-        # Initialize prediction latent with noise
-        if seed is None:
-            rand_num_generator = None
-        else:
-            rand_num_generator = torch.Generator(device=device)
-            rand_num_generator.manual_seed(seed)
-        target_latents = torch.randn(
-            target_latent_shape,
-            device=device,
-            dtype=self.dtype,
-            generator=rand_num_generator,
-        )  # [B, 4, h, w]
-
-        # Batched empty text embedding
-        if self.empty_text_embed is None:
-            self.encode_empty_text()
-        batch_empty_text_embed = self.empty_text_embed.repeat(
-            (rgb_latent.shape[0], 1, 1)
-        )  # [B, 2, 1024]
-
-        # Denoising loop
-        if show_pbar:
-            iterable = tqdm(
-                enumerate(timesteps),
-                total=len(timesteps),
-                leave=False,
-                desc=" " * 4 + "Diffusion denoising",
-            )
-        else:
-            iterable = enumerate(timesteps)
-
-        for i, t in iterable:
-            unet_input = torch.cat(
-                [rgb_latent, target_latents], dim=1
-            )  # this order is important
-
-            # predict the noise residual
-            noise_pred = self.unet(
-                unet_input, t, encoder_hidden_states=batch_empty_text_embed
-            ).sample  # [B, 4, h, w]
-
-            # compute the previous noisy sample x_t -> x_t-1
-            target_latents = self.scheduler.step(
-                noise_pred, t, target_latents, generator=rand_num_generator
-            ).prev_sample
-
-            # torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
-
-        targets = self.decode_targets(target_latents)  # [B, 3, H, W]
-        targets = torch.clip(targets, -1.0, 1.0)
-
-        return targets
-
-    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
-        """
-        Encode RGB image into latent.
-
-        Args:
-            rgb_in (`torch.Tensor`):
-                Input RGB image to be encoded.
-
-        Returns:
-            `torch.Tensor`: Image latent.
-        """
-        # encode
-        h = self.vae.encoder(rgb_in)
-        moments = self.vae.quant_conv(h)
-        mean, logvar = torch.chunk(moments, 2, dim=1)
-        # scale latent
-        rgb_latent = mean * self.latent_scale_factor
-        return rgb_latent
-
-    def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
-        """
-        Decode target latent into target map.
-
-        Args:
-            target_latents (`torch.Tensor`):
-                Target latent to be decoded.
-
-        Returns:
-            `torch.Tensor`: Decoded target map.
-        """
-
-        assert target_latents.shape[1] == 12  # self.n_targets * 4
-
-        # scale latent
-        target_latents = target_latents / self.latent_scale_factor
-        # decode
-        targets = []
-        for i in range(self.n_targets):
-            latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
-            z = self.vae.post_quant_conv(latent)
-            stacked = self.vae.decoder(z)
-
-            targets.append(stacked)
-
-        return torch.cat(targets, dim=1)
-
-    @staticmethod
-    def get_pil_resample_method(method_str: str) -> Resampling:
-        resample_method_dic = {
-            "bilinear": Resampling.BILINEAR,
-            "bicubic": Resampling.BICUBIC,
-            "nearest": Resampling.NEAREST,
-        }
-        resample_method = resample_method_dic.get(method_str, None)
-        if resample_method is None:
-            raise ValueError(f"Unknown resampling method: {resample_method}")
-        else:
-            return resample_method
-
-    @staticmethod
-    def resize_max_res(
-        img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
-    ) -> Image.Image:
-        """
-        Resize image to limit maximum edge length while keeping aspect ratio.
-        """
-        original_width, original_height = img.size
-        downscale_factor = min(
-            max_edge_resolution / original_width, max_edge_resolution / original_height
-        )
-
-        new_width = int(original_width * downscale_factor)
-        new_height = int(original_height * downscale_factor)
-
-        resized_img = img.resize((new_width, new_height), resample=resample_method)
-        return resized_img
-
-    @staticmethod
-    def chw2hwc(chw):
-        assert 3 == len(chw.shape)
-        if isinstance(chw, torch.Tensor):
-            hwc = torch.permute(chw, (1, 2, 0))
-        elif isinstance(chw, np.ndarray):
-            hwc = np.moveaxis(chw, 0, -1)
-        return hwc
-
-    @staticmethod
-    def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
-        """
-        Automatically search for suitable operating batch size.
-
-        Args:
-            ensemble_size (`int`):
-                Number of predictions to be ensembled.
-            input_res (`int`):
-                Operating resolution of the input image.
-
-        Returns:
-            `int`: Operating batch size.
-        """
-        # Search table for suggested max. inference batch size
-        bs_search_table = [
-            # tested on A100-PCIE-80GB
-            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
-            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
-            # tested on A100-PCIE-40GB
-            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
-            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
-            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
-            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
-            # tested on RTX3090, RTX4090
-            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
-            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
-            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
-            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
-            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
-            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
-            # tested on GTX1080Ti
-            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
-            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
-            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
-            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
-            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
-        ]
-
-        if not torch.cuda.is_available():
-            return 1
-
-        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
-        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
-        for settings in sorted(
-            filtered_bs_search_table,
-            key=lambda k: (k["res"], -k["total_vram"]),
-        ):
-            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
-                bs = settings["bs"]
-                if bs > ensemble_size:
-                    bs = ensemble_size
-                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
-                    bs = math.ceil(ensemble_size / 2)
-                return bs
-
-        return 1
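
The lighting pipeline is structurally identical to the appearance one; the difference is n_targets = 3, so the prediction latent stacks three 4-channel latents that decode_targets splits and decodes one by one. A toy sketch of that channel bookkeeping (a random tensor stands in for a real latent):

import torch

n_targets = 3
target_latents = torch.randn(1, 4 * n_targets, 8, 8)  # 12 latent channels total
per_target = [target_latents[:, i * 4 : (i + 1) * 4] for i in range(n_targets)]
assert all(t.shape[1] == 4 for t in per_target)  # one 4-channel latent per target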
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-diffusers>=0.32.2
+diffusers>=0.33.0
 git+https://github.com/toshas/gradio-dualvision.git@21346a4
 accelerate
 huggingface_hub
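
The requirement bump is what makes the two file deletions possible: this commit imports MarigoldIntrinsicsPipeline from diffusers, which (by the assumption encoded in the new pin) ships starting with the 0.33 release. A minimal sanity check, assuming the packaging module is available in the environment:

import diffusers
from packaging.version import Version  # assumption: packaging is installed (it usually is)

assert Version(diffusers.__version__) >= Version("0.33.0"), diffusers.__version__
from diffusers import MarigoldIntrinsicsPipeline  # fails to import on older releases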