Spaces: Running on Zero
Commits:
- add header link to self for embedding
- add badge for diffusers tutorial
- bump to the latest diffusers

Files changed:
- app.py +27 -44
- marigold_iid_appearance.py +0 -561
- marigold_iid_lighting.py +0 -576
- requirements.txt +1 -1
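
The last commit replaces the two bundled single-file pipelines with the Marigold intrinsics support that ships in current diffusers, so the dependency change in requirements.txt is presumably just the diffusers pin (the exact version string is not visible in this diff). A quick environment check, as a sketch:

```python
# Sanity check (sketch): the updated app.py imports MarigoldIntrinsicsPipeline
# directly from diffusers, so the installed release must expose it. The exact
# version pinned in requirements.txt is not shown in this excerpt.
import diffusers

print("diffusers", diffusers.__version__)
assert hasattr(diffusers, "MarigoldIntrinsicsPipeline"), (
    "Installed diffusers is too old for this Space; bump to the latest release."
)
```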
app.py
CHANGED
@@ -30,22 +30,16 @@
 # --------------------------------------------------------------------------
 
 import os
-
-import numpy as np
-
 os.system("pip freeze")
 import spaces
 
 import gradio as gr
 import torch as torch
-from diffusers import DDIMScheduler
+from diffusers import MarigoldIntrinsicsPipeline, DDIMScheduler
 from gradio_dualvision import DualVisionApp
 from huggingface_hub import login
 from PIL import Image
 
-from marigold_iid_appearance import MarigoldIIDAppearancePipeline
-from marigold_iid_lighting import MarigoldIIDLightingPipeline
-
 CHECKPOINT_APPEARANCE = "prs-eth/marigold-iid-appearance-v1-1"
 CHECKPOINT_LIGHTING = "prs-eth/marigold-iid-lighting-v1-1"
 
@@ -55,19 +49,11 @@ if "HF_TOKEN_LOGIN" in os.environ:
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-pipe_appearance = MarigoldIIDAppearancePipeline.from_pretrained(
-    CHECKPOINT_APPEARANCE
-)
-pipe_appearance.scheduler = DDIMScheduler.from_config(
-    pipe_appearance.scheduler.config, timestep_spacing="trailing"
-)
+pipe_appearance = MarigoldIntrinsicsPipeline.from_pretrained(CHECKPOINT_APPEARANCE)
+pipe_appearance.scheduler = DDIMScheduler.from_config(pipe_appearance.scheduler.config, timestep_spacing="trailing")
 pipe_appearance = pipe_appearance.to(device=device, dtype=dtype)
-pipe_lighting = MarigoldIIDLightingPipeline.from_pretrained(
-    CHECKPOINT_LIGHTING
-)
-pipe_lighting.scheduler = DDIMScheduler.from_config(
-    pipe_lighting.scheduler.config, timestep_spacing="trailing"
-)
+pipe_lighting = MarigoldIntrinsicsPipeline.from_pretrained(CHECKPOINT_LIGHTING)
+pipe_lighting.scheduler = DDIMScheduler.from_config(pipe_lighting.scheduler.config, timestep_spacing="trailing")
 pipe_lighting = pipe_lighting.to(device=device, dtype=dtype)
 try:
     import xformers
@@ -87,7 +73,7 @@ class MarigoldIIDApp(DualVisionApp):
     def make_header(self):
         gr.Markdown(
             """
-            ## Marigold Intrinsic Image Decomposition
+            ## [Marigold Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-intrinsics)
             """
         )
         with gr.Row(elem_classes="remove-elements"):
@@ -97,6 +83,9 @@ class MarigoldIIDApp(DualVisionApp):
                 <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                     <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
                 </a>
+                <a title="diffusers" href="https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+                    <img src="https://img.shields.io/badge/%F0%9F%A7%A8%20Read_diffusers-tutorial-yellow?labelColor=green">
+                </a>
                 <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                     <img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
                 </a>
@@ -156,7 +145,7 @@ class MarigoldIIDApp(DualVisionApp):
         ensemble_size = kwargs.get("ensemble_size", self.DEFAULT_ENSEMBLE_SIZE)
         denoise_steps = kwargs.get("denoise_steps", self.DEFAULT_DENOISE_STEPS)
         processing_res = kwargs.get("processing_res", self.DEFAULT_PROCESSING_RES)
-
+        generator = torch.Generator(device=device).manual_seed(self.DEFAULT_SEED)
 
         pipe_out_appearance = pipe_appearance(
             image_in,
@@ -165,19 +154,12 @@ class MarigoldIIDApp(DualVisionApp):
             processing_resolution=processing_res,
             batch_size=1 if processing_res == 0 else 2,
             output_uncertainty=ensemble_size >= 3,
-
-            seed=self.DEFAULT_SEED,
+            generator=generator,
         )
 
-
-
-
-        roughness = Image.fromarray(roughness, mode="I;16")
-
-        metallicity = pipe_out_appearance.material[1].clip(-1, 1)
-        metallicity = (metallicity + 1.0) * 0.5
-        metallicity = (metallicity * 65535).astype(np.uint16)
-        metallicity = Image.fromarray(metallicity, mode="I;16")
+        iid_appearance_vis = pipe_appearance.image_processor.visualize_intrinsics(
+            pipe_out_appearance.prediction, pipe_appearance.target_properties
+        )
 
         pipe_out_lighting = pipe_lighting(
             image_in,
@@ -186,22 +168,23 @@ class MarigoldIIDApp(DualVisionApp):
             processing_resolution=processing_res,
             batch_size=1 if processing_res == 0 else 2,
             output_uncertainty=ensemble_size >= 3,
-
-
+            generator=generator,
+        )
+
+        iid_lighting_vis = pipe_lighting.image_processor.visualize_intrinsics(
+            pipe_out_lighting.prediction, pipe_lighting.target_properties
         )
 
         out_modalities = {
-            "Albedo":
-            "Materials":
-            "Roughness": roughness,
-            "Metallicity": metallicity,
-            "Albedo (HyperSim)":
-            "Shading (HyperSim)":
-            "Residual (HyperSim)":
+            "Albedo": iid_appearance_vis[0]["albedo"],
+            "Materials": iid_appearance_vis[0]["material"],
+            "Roughness": iid_appearance_vis[0]["roughness"],
+            "Metallicity": iid_appearance_vis[0]["metallicity"],
+            "Albedo (HyperSim)": iid_lighting_vis[0]["albedo"],
+            "Shading (HyperSim)": iid_lighting_vis[0]["shading"],
+            "Residual (HyperSim)": iid_lighting_vis[0]["residual"],
        }
-        #
-        # uncertainty = pipe.image_processor.visualize_uncertainty(pipe_out.uncertainty)[0]
-        # out_modalities["Uncertainty"] = uncertainty
+        # Additionally, uncertainty can be computed on any of the output modalities; we skip it to keep the demo light
 
         out_settings = {
             "ensemble_size": ensemble_size,
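
For context, the block below is a condensed, standalone sketch of the inference path the updated app.py relies on. It assumes a recent diffusers release that ships MarigoldIntrinsicsPipeline and the visualize_intrinsics helper referenced in the diff; the sample image URL and the argument values (num_inference_steps, ensemble_size, processing_resolution) are illustrative placeholders rather than the Space's defaults.

```python
# Standalone sketch of the diffusers-based inference path used by app.py above.
# Assumes a recent diffusers with MarigoldIntrinsicsPipeline; argument values and
# the example image are illustrative, not the Space's defaults.
import torch
from diffusers import DDIMScheduler, MarigoldIntrinsicsPipeline
from diffusers.utils import load_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

pipe = MarigoldIntrinsicsPipeline.from_pretrained("prs-eth/marigold-iid-appearance-v1-1")
# Same override as app.py: few-step DDIM sampling with "trailing" timestep spacing.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe = pipe.to(device=device, dtype=dtype)

image = load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")  # hypothetical example input
generator = torch.Generator(device=device).manual_seed(2024)

out = pipe(
    image,
    num_inference_steps=4,
    ensemble_size=1,
    processing_resolution=768,
    generator=generator,
)

# visualize_intrinsics returns one dict of PIL images per input, keyed by modality
# (e.g. "albedo", "material"), which app.py feeds to the DualVision galleries.
vis = pipe.image_processor.visualize_intrinsics(out.prediction, pipe.target_properties)
vis[0]["albedo"].save("albedo.png")
```

Note that app.py reuses one generator for both pipeline calls, so the demo stays deterministic for a fixed DEFAULT_SEED.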
marigold_iid_appearance.py
DELETED
|
@@ -1,561 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 Anton Obukhov, Bingxin Ke, Bo Li & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
# --------------------------------------------------------------------------
|
| 15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 17 |
-
# More information about the method can be found at https://marigoldcomputervision.github.io
|
| 18 |
-
# --------------------------------------------------------------------------
|
| 19 |
-
import logging
|
| 20 |
-
import math
|
| 21 |
-
from typing import Optional, Tuple, Union, Dict, Any
|
| 22 |
-
|
| 23 |
-
import numpy as np
|
| 24 |
-
import torch
|
| 25 |
-
from diffusers import (
|
| 26 |
-
AutoencoderKL,
|
| 27 |
-
DDIMScheduler,
|
| 28 |
-
DiffusionPipeline,
|
| 29 |
-
UNet2DConditionModel,
|
| 30 |
-
)
|
| 31 |
-
from diffusers.utils import BaseOutput, check_min_version
|
| 32 |
-
from PIL import Image
|
| 33 |
-
from PIL.Image import Resampling
|
| 34 |
-
from torch.utils.data import DataLoader, TensorDataset
|
| 35 |
-
from tqdm.auto import tqdm
|
| 36 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
| 37 |
-
|
| 38 |
-
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 39 |
-
check_min_version("0.27.0.dev0")
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class MarigoldIIDAppearanceOutput(BaseOutput):
|
| 43 |
-
"""
|
| 44 |
-
Output class for Marigold IID Appearance pipeline.
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
albedo (`np.ndarray`):
|
| 48 |
-
Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1].
|
| 49 |
-
albedo_colored (`PIL.Image.Image`):
|
| 50 |
-
Colorized albedo map with the shape of [H, W, 3].
|
| 51 |
-
material (`np.ndarray`):
|
| 52 |
-
Predicted material map with the shape of [3, H, W] and values in [0, 1].
|
| 53 |
-
1st channel (Red) is roughness
|
| 54 |
-
2nd channel (Green) is metallicity
|
| 55 |
-
3rd channel (Blue) is empty (zero)
|
| 56 |
-
material_colored (`PIL.Image.Image`):
|
| 57 |
-
Colorized material map with the shape of [H, W, 3].
|
| 58 |
-
1st channel (Red) is roughness
|
| 59 |
-
2nd channel (Green) is metallicity
|
| 60 |
-
3rd channel (Blue) is empty (zero)
|
| 61 |
-
"""
|
| 62 |
-
|
| 63 |
-
albedo: np.ndarray
|
| 64 |
-
albedo_colored: Image.Image
|
| 65 |
-
material: np.ndarray
|
| 66 |
-
material_colored: Image.Image
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
class MarigoldIIDAppearancePipeline(DiffusionPipeline):
|
| 70 |
-
"""
|
| 71 |
-
Pipeline for Intrinsic Image Decomposition (Albedo and Material) using Marigold: https://marigoldcomputervision.github.io.
|
| 72 |
-
|
| 73 |
-
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 74 |
-
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 75 |
-
|
| 76 |
-
Args:
|
| 77 |
-
unet (`UNet2DConditionModel`):
|
| 78 |
-
Conditional U-Net to denoise the normals latent, conditioned on image latent.
|
| 79 |
-
vae (`AutoencoderKL`):
|
| 80 |
-
Variational Auto-Encoder (VAE) Model to encode and decode images and normals maps
|
| 81 |
-
to and from latent representations.
|
| 82 |
-
scheduler (`DDIMScheduler`):
|
| 83 |
-
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
| 84 |
-
text_encoder (`CLIPTextModel`):
|
| 85 |
-
Text-encoder, for empty text embedding.
|
| 86 |
-
tokenizer (`CLIPTokenizer`):
|
| 87 |
-
CLIP tokenizer.
|
| 88 |
-
"""
|
| 89 |
-
|
| 90 |
-
latent_scale_factor = 0.18215
|
| 91 |
-
|
| 92 |
-
def __init__(
|
| 93 |
-
self,
|
| 94 |
-
unet: UNet2DConditionModel,
|
| 95 |
-
vae: AutoencoderKL,
|
| 96 |
-
scheduler: DDIMScheduler,
|
| 97 |
-
text_encoder: CLIPTextModel,
|
| 98 |
-
tokenizer: CLIPTokenizer,
|
| 99 |
-
prediction_type: Optional[str] = None,
|
| 100 |
-
target_properties: Optional[Dict[str, Any]] = None,
|
| 101 |
-
default_denoising_steps: Optional[int] = None,
|
| 102 |
-
default_processing_resolution: Optional[int] = None,
|
| 103 |
-
):
|
| 104 |
-
super().__init__()
|
| 105 |
-
|
| 106 |
-
self.register_modules(
|
| 107 |
-
unet=unet,
|
| 108 |
-
vae=vae,
|
| 109 |
-
scheduler=scheduler,
|
| 110 |
-
text_encoder=text_encoder,
|
| 111 |
-
tokenizer=tokenizer,
|
| 112 |
-
)
|
| 113 |
-
self.register_to_config(
|
| 114 |
-
prediction_type=prediction_type,
|
| 115 |
-
target_properties=target_properties,
|
| 116 |
-
default_denoising_steps=default_denoising_steps,
|
| 117 |
-
default_processing_resolution=default_processing_resolution,
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
self.empty_text_embed = None
|
| 121 |
-
|
| 122 |
-
self.n_targets = 2 # Albedo and material
|
| 123 |
-
|
| 124 |
-
@torch.no_grad()
|
| 125 |
-
def __call__(
|
| 126 |
-
self,
|
| 127 |
-
input_image: Image,
|
| 128 |
-
denoising_steps: int = 4,
|
| 129 |
-
ensemble_size: int = 10,
|
| 130 |
-
processing_res: int = 768,
|
| 131 |
-
match_input_res: bool = True,
|
| 132 |
-
resample_method: str = "bilinear",
|
| 133 |
-
batch_size: int = 0,
|
| 134 |
-
save_memory: bool = False,
|
| 135 |
-
seed: Union[int, None] = None,
|
| 136 |
-
color_map: str = "Spectral", # TODO change colorization api based on modality
|
| 137 |
-
show_progress_bar: bool = True,
|
| 138 |
-
**kwargs,
|
| 139 |
-
) -> MarigoldIIDAppearanceOutput:
|
| 140 |
-
"""
|
| 141 |
-
Function invoked when calling the pipeline.
|
| 142 |
-
|
| 143 |
-
Args:
|
| 144 |
-
input_image (`Image`):
|
| 145 |
-
Input RGB (or gray-scale) image.
|
| 146 |
-
denoising_steps (`int`, *optional*, defaults to `10`):
|
| 147 |
-
Number of diffusion denoising steps (DDIM) during inference.
|
| 148 |
-
ensemble_size (`int`, *optional*, defaults to `10`):
|
| 149 |
-
Number of predictions to be ensembled.
|
| 150 |
-
processing_res (`int`, *optional*, defaults to `768`):
|
| 151 |
-
Maximum resolution of processing.
|
| 152 |
-
If set to 0: will not resize at all.
|
| 153 |
-
match_input_res (`bool`, *optional*, defaults to `True`):
|
| 154 |
-
Resize normals prediction to match input resolution.
|
| 155 |
-
Only valid if `limit_input_res` is not None.
|
| 156 |
-
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
| 157 |
-
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
| 158 |
-
batch_size (`int`, *optional*, defaults to `0`):
|
| 159 |
-
Inference batch size, no bigger than `num_ensemble`.
|
| 160 |
-
If set to 0, the script will automatically decide the proper batch size.
|
| 161 |
-
save_memory (`bool`, defaults to `False`):
|
| 162 |
-
Extra steps to save memory at the cost of perforance.
|
| 163 |
-
seed (`int`, *optional*, defaults to `None`)
|
| 164 |
-
Reproducibility seed.
|
| 165 |
-
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized normals map generation):
|
| 166 |
-
Colormap used to colorize the normals map.
|
| 167 |
-
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
| 168 |
-
Display a progress bar of diffusion denoising.
|
| 169 |
-
Returns:
|
| 170 |
-
`MarigoldIIDAppearanceOutput`: Output class for Marigold monocular intrinsic image decomposition (appearance) prediction pipeline, including:
|
| 171 |
-
- **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
| 172 |
-
- **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
| 173 |
-
- **material** (`np.ndarray`) Predicted material map with the shape of [3, H, W] and values in [0, 1]
|
| 174 |
-
- **material_colored** (`PIL.Image.Image`) Colorized material map with the shape of [3, H, W] and values in [0, 1]
|
| 175 |
-
"""
|
| 176 |
-
|
| 177 |
-
if not match_input_res:
|
| 178 |
-
assert processing_res is not None
|
| 179 |
-
assert processing_res >= 0
|
| 180 |
-
assert denoising_steps >= 1
|
| 181 |
-
assert ensemble_size >= 1
|
| 182 |
-
|
| 183 |
-
# Check if denoising step is reasonable
|
| 184 |
-
self.check_inference_step(denoising_steps)
|
| 185 |
-
|
| 186 |
-
resample_method: Resampling = self.get_pil_resample_method(resample_method)
|
| 187 |
-
|
| 188 |
-
W, H = input_image.size
|
| 189 |
-
|
| 190 |
-
if processing_res > 0:
|
| 191 |
-
input_image = self.resize_max_res(
|
| 192 |
-
input_image,
|
| 193 |
-
max_edge_resolution=processing_res,
|
| 194 |
-
resample_method=resample_method,
|
| 195 |
-
)
|
| 196 |
-
input_image = input_image.convert("RGB")
|
| 197 |
-
image = np.asarray(input_image)
|
| 198 |
-
|
| 199 |
-
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
| 200 |
-
rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
|
| 201 |
-
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
| 202 |
-
rgb_norm = rgb_norm.to(self.device)
|
| 203 |
-
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
|
| 204 |
-
|
| 205 |
-
def ensemble(
|
| 206 |
-
targets: torch.Tensor,
|
| 207 |
-
return_uncertainty: bool = False,
|
| 208 |
-
reduction="median",
|
| 209 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 210 |
-
uncertainty = None
|
| 211 |
-
if reduction == "mean":
|
| 212 |
-
prediction = torch.mean(targets, dim=0, keepdim=True)
|
| 213 |
-
if return_uncertainty:
|
| 214 |
-
uncertainty = torch.std(targets, dim=0, keepdim=True)
|
| 215 |
-
elif reduction == "median":
|
| 216 |
-
prediction = torch.median(targets, dim=0, keepdim=True).values
|
| 217 |
-
if return_uncertainty:
|
| 218 |
-
uncertainty = torch.median(
|
| 219 |
-
torch.abs(targets - prediction), dim=0, keepdim=True
|
| 220 |
-
).values
|
| 221 |
-
else:
|
| 222 |
-
raise ValueError(f"Unrecognized reduction method: {reduction}.")
|
| 223 |
-
return prediction, uncertainty
|
| 224 |
-
|
| 225 |
-
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
| 226 |
-
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
| 227 |
-
|
| 228 |
-
if batch_size <= 0:
|
| 229 |
-
batch_size = self.find_batch_size(
|
| 230 |
-
ensemble_size=ensemble_size,
|
| 231 |
-
input_res=max(rgb_norm.shape[1:]),
|
| 232 |
-
dtype=self.dtype,
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
single_rgb_loader = DataLoader(
|
| 236 |
-
single_rgb_dataset, batch_size=batch_size, shuffle=False
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
target_pred_ls = []
|
| 240 |
-
iterable = single_rgb_loader
|
| 241 |
-
if show_progress_bar:
|
| 242 |
-
iterable = tqdm(
|
| 243 |
-
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
|
| 244 |
-
)
|
| 245 |
-
|
| 246 |
-
for batch in iterable:
|
| 247 |
-
(batched_img,) = batch
|
| 248 |
-
target_pred = self.single_infer(
|
| 249 |
-
rgb_in=batched_img,
|
| 250 |
-
num_inference_steps=denoising_steps,
|
| 251 |
-
seed=seed,
|
| 252 |
-
show_pbar=show_progress_bar,
|
| 253 |
-
)
|
| 254 |
-
target_pred = target_pred.detach()
|
| 255 |
-
if save_memory:
|
| 256 |
-
target_pred = target_pred.cpu()
|
| 257 |
-
target_pred_ls.append(target_pred.detach())
|
| 258 |
-
|
| 259 |
-
target_preds = torch.concat(target_pred_ls, dim=0)
|
| 260 |
-
pred_uncert = None
|
| 261 |
-
|
| 262 |
-
if save_memory:
|
| 263 |
-
torch.cuda.empty_cache()
|
| 264 |
-
|
| 265 |
-
if ensemble_size > 1:
|
| 266 |
-
final_pred, pred_uncert = ensemble(
|
| 267 |
-
target_preds, reduction="median", return_uncertainty=False
|
| 268 |
-
)
|
| 269 |
-
else:
|
| 270 |
-
final_pred = target_preds
|
| 271 |
-
pred_uncert = None
|
| 272 |
-
|
| 273 |
-
if match_input_res:
|
| 274 |
-
final_pred = torch.nn.functional.interpolate(
|
| 275 |
-
final_pred, (H, W), mode="bilinear" # TODO: parameterize this method
|
| 276 |
-
) # [1,3,H,W]
|
| 277 |
-
|
| 278 |
-
if pred_uncert is not None:
|
| 279 |
-
pred_uncert = torch.nn.functional.interpolate(
|
| 280 |
-
pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
|
| 281 |
-
).squeeze(
|
| 282 |
-
1
|
| 283 |
-
) # [1,H,W]
|
| 284 |
-
|
| 285 |
-
# Convert to numpy
|
| 286 |
-
final_pred = final_pred.squeeze()
|
| 287 |
-
final_pred = final_pred.cpu().float().numpy()
|
| 288 |
-
|
| 289 |
-
albedo = final_pred[0:3, :, :]
|
| 290 |
-
material = np.stack(
|
| 291 |
-
(final_pred[3, :, :], final_pred[4, :, :], final_pred[5, :, :]), axis=0
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
albedo_colored = (albedo + 1.0) * 0.5
|
| 295 |
-
albedo_colored = (albedo_colored * 255).astype(np.uint8)
|
| 296 |
-
albedo_colored = self.chw2hwc(albedo_colored)
|
| 297 |
-
albedo_colored_img = Image.fromarray(albedo_colored)
|
| 298 |
-
|
| 299 |
-
material_colored = (material + 1.0) * 0.5
|
| 300 |
-
material_colored = (material_colored * 255).astype(np.uint8)
|
| 301 |
-
material_colored = self.chw2hwc(material_colored)
|
| 302 |
-
material_colored_img = Image.fromarray(material_colored)
|
| 303 |
-
|
| 304 |
-
out = MarigoldIIDAppearanceOutput(
|
| 305 |
-
albedo=albedo,
|
| 306 |
-
albedo_colored=albedo_colored_img,
|
| 307 |
-
material=material,
|
| 308 |
-
material_colored=material_colored_img,
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
return out
|
| 312 |
-
|
| 313 |
-
def check_inference_step(self, n_step: int):
|
| 314 |
-
"""
|
| 315 |
-
Check if denoising step is reasonable
|
| 316 |
-
Args:
|
| 317 |
-
n_step (`int`): denoising steps
|
| 318 |
-
"""
|
| 319 |
-
assert n_step >= 1
|
| 320 |
-
|
| 321 |
-
if isinstance(self.scheduler, DDIMScheduler):
|
| 322 |
-
pass
|
| 323 |
-
else:
|
| 324 |
-
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
| 325 |
-
|
| 326 |
-
def encode_empty_text(self):
|
| 327 |
-
"""
|
| 328 |
-
Encode text embedding for empty prompt.
|
| 329 |
-
"""
|
| 330 |
-
prompt = ""
|
| 331 |
-
text_inputs = self.tokenizer(
|
| 332 |
-
prompt,
|
| 333 |
-
padding="do_not_pad",
|
| 334 |
-
max_length=self.tokenizer.model_max_length,
|
| 335 |
-
truncation=True,
|
| 336 |
-
return_tensors="pt",
|
| 337 |
-
)
|
| 338 |
-
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
| 339 |
-
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
| 340 |
-
|
| 341 |
-
@torch.no_grad()
|
| 342 |
-
def single_infer(
|
| 343 |
-
self,
|
| 344 |
-
rgb_in: torch.Tensor,
|
| 345 |
-
num_inference_steps: int,
|
| 346 |
-
seed: Union[int, None],
|
| 347 |
-
show_pbar: bool,
|
| 348 |
-
) -> torch.Tensor:
|
| 349 |
-
"""
|
| 350 |
-
Perform an individual iid prediction without ensembling.
|
| 351 |
-
"""
|
| 352 |
-
device = rgb_in.device
|
| 353 |
-
|
| 354 |
-
# Set timesteps
|
| 355 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 356 |
-
timesteps = self.scheduler.timesteps # [T]
|
| 357 |
-
|
| 358 |
-
# Encode image
|
| 359 |
-
rgb_latent = self.encode_rgb(rgb_in)
|
| 360 |
-
|
| 361 |
-
target_latent_shape = list(rgb_latent.shape)
|
| 362 |
-
target_latent_shape[
|
| 363 |
-
1
|
| 364 |
-
] *= 2 # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
|
| 365 |
-
|
| 366 |
-
# Initialize prediction latent with noise
|
| 367 |
-
if seed is None:
|
| 368 |
-
rand_num_generator = None
|
| 369 |
-
else:
|
| 370 |
-
rand_num_generator = torch.Generator(device=device)
|
| 371 |
-
rand_num_generator.manual_seed(seed)
|
| 372 |
-
target_latents = torch.randn(
|
| 373 |
-
target_latent_shape,
|
| 374 |
-
device=device,
|
| 375 |
-
dtype=self.dtype,
|
| 376 |
-
generator=rand_num_generator,
|
| 377 |
-
) # [B, 4, h, w]
|
| 378 |
-
|
| 379 |
-
# Batched empty text embedding
|
| 380 |
-
if self.empty_text_embed is None:
|
| 381 |
-
self.encode_empty_text()
|
| 382 |
-
batch_empty_text_embed = self.empty_text_embed.repeat(
|
| 383 |
-
(rgb_latent.shape[0], 1, 1)
|
| 384 |
-
) # [B, 2, 1024]
|
| 385 |
-
|
| 386 |
-
# Denoising loop
|
| 387 |
-
if show_pbar:
|
| 388 |
-
iterable = tqdm(
|
| 389 |
-
enumerate(timesteps),
|
| 390 |
-
total=len(timesteps),
|
| 391 |
-
leave=False,
|
| 392 |
-
desc=" " * 4 + "Diffusion denoising",
|
| 393 |
-
)
|
| 394 |
-
else:
|
| 395 |
-
iterable = enumerate(timesteps)
|
| 396 |
-
|
| 397 |
-
for i, t in iterable:
|
| 398 |
-
unet_input = torch.cat(
|
| 399 |
-
[rgb_latent, target_latents], dim=1
|
| 400 |
-
) # this order is important
|
| 401 |
-
|
| 402 |
-
# predict the noise residual
|
| 403 |
-
noise_pred = self.unet(
|
| 404 |
-
unet_input, t, encoder_hidden_states=batch_empty_text_embed
|
| 405 |
-
).sample # [B, 4, h, w]
|
| 406 |
-
|
| 407 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 408 |
-
target_latents = self.scheduler.step(
|
| 409 |
-
noise_pred, t, target_latents, generator=rand_num_generator
|
| 410 |
-
).prev_sample
|
| 411 |
-
|
| 412 |
-
# torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
|
| 413 |
-
|
| 414 |
-
targets = self.decode_targets(target_latents) # [B, 3, H, W]
|
| 415 |
-
targets = torch.clip(targets, -1.0, 1.0)
|
| 416 |
-
|
| 417 |
-
return targets
|
| 418 |
-
|
| 419 |
-
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
| 420 |
-
"""
|
| 421 |
-
Encode RGB image into latent.
|
| 422 |
-
|
| 423 |
-
Args:
|
| 424 |
-
rgb_in (`torch.Tensor`):
|
| 425 |
-
Input RGB image to be encoded.
|
| 426 |
-
|
| 427 |
-
Returns:
|
| 428 |
-
`torch.Tensor`: Image latent.
|
| 429 |
-
"""
|
| 430 |
-
# encode
|
| 431 |
-
h = self.vae.encoder(rgb_in)
|
| 432 |
-
moments = self.vae.quant_conv(h)
|
| 433 |
-
mean, logvar = torch.chunk(moments, 2, dim=1)
|
| 434 |
-
# scale latent
|
| 435 |
-
rgb_latent = mean * self.latent_scale_factor
|
| 436 |
-
return rgb_latent
|
| 437 |
-
|
| 438 |
-
def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
|
| 439 |
-
"""
|
| 440 |
-
Decode target latent into target map.
|
| 441 |
-
|
| 442 |
-
Args:
|
| 443 |
-
target_latents (`torch.Tensor`):
|
| 444 |
-
Target latent to be decoded.
|
| 445 |
-
|
| 446 |
-
Returns:
|
| 447 |
-
`torch.Tensor`: Decoded target map.
|
| 448 |
-
"""
|
| 449 |
-
|
| 450 |
-
assert target_latents.shape[1] == 8 # self.n_targets * 4
|
| 451 |
-
|
| 452 |
-
# scale latent
|
| 453 |
-
target_latents = target_latents / self.latent_scale_factor
|
| 454 |
-
# decode
|
| 455 |
-
targets = []
|
| 456 |
-
for i in range(self.n_targets):
|
| 457 |
-
latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
|
| 458 |
-
z = self.vae.post_quant_conv(latent)
|
| 459 |
-
stacked = self.vae.decoder(z)
|
| 460 |
-
|
| 461 |
-
targets.append(stacked)
|
| 462 |
-
|
| 463 |
-
return torch.cat(targets, dim=1)
|
| 464 |
-
|
| 465 |
-
@staticmethod
|
| 466 |
-
def get_pil_resample_method(method_str: str) -> Resampling:
|
| 467 |
-
resample_method_dic = {
|
| 468 |
-
"bilinear": Resampling.BILINEAR,
|
| 469 |
-
"bicubic": Resampling.BICUBIC,
|
| 470 |
-
"nearest": Resampling.NEAREST,
|
| 471 |
-
}
|
| 472 |
-
resample_method = resample_method_dic.get(method_str, None)
|
| 473 |
-
if resample_method is None:
|
| 474 |
-
raise ValueError(f"Unknown resampling method: {resample_method}")
|
| 475 |
-
else:
|
| 476 |
-
return resample_method
|
| 477 |
-
|
| 478 |
-
@staticmethod
|
| 479 |
-
def resize_max_res(
|
| 480 |
-
img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
|
| 481 |
-
) -> Image.Image:
|
| 482 |
-
"""
|
| 483 |
-
Resize image to limit maximum edge length while keeping aspect ratio.
|
| 484 |
-
"""
|
| 485 |
-
original_width, original_height = img.size
|
| 486 |
-
downscale_factor = min(
|
| 487 |
-
max_edge_resolution / original_width, max_edge_resolution / original_height
|
| 488 |
-
)
|
| 489 |
-
|
| 490 |
-
new_width = int(original_width * downscale_factor)
|
| 491 |
-
new_height = int(original_height * downscale_factor)
|
| 492 |
-
|
| 493 |
-
resized_img = img.resize((new_width, new_height), resample=resample_method)
|
| 494 |
-
return resized_img
|
| 495 |
-
|
| 496 |
-
@staticmethod
|
| 497 |
-
def chw2hwc(chw):
|
| 498 |
-
assert 3 == len(chw.shape)
|
| 499 |
-
if isinstance(chw, torch.Tensor):
|
| 500 |
-
hwc = torch.permute(chw, (1, 2, 0))
|
| 501 |
-
elif isinstance(chw, np.ndarray):
|
| 502 |
-
hwc = np.moveaxis(chw, 0, -1)
|
| 503 |
-
return hwc
|
| 504 |
-
|
| 505 |
-
@staticmethod
|
| 506 |
-
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
| 507 |
-
"""
|
| 508 |
-
Automatically search for suitable operating batch size.
|
| 509 |
-
|
| 510 |
-
Args:
|
| 511 |
-
ensemble_size (`int`):
|
| 512 |
-
Number of predictions to be ensembled.
|
| 513 |
-
input_res (`int`):
|
| 514 |
-
Operating resolution of the input image.
|
| 515 |
-
|
| 516 |
-
Returns:
|
| 517 |
-
`int`: Operating batch size.
|
| 518 |
-
"""
|
| 519 |
-
# Search table for suggested max. inference batch size
|
| 520 |
-
bs_search_table = [
|
| 521 |
-
# tested on A100-PCIE-80GB
|
| 522 |
-
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
| 523 |
-
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
| 524 |
-
# tested on A100-PCIE-40GB
|
| 525 |
-
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
| 526 |
-
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
| 527 |
-
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
| 528 |
-
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
| 529 |
-
# tested on RTX3090, RTX4090
|
| 530 |
-
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
| 531 |
-
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
| 532 |
-
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
| 533 |
-
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
| 534 |
-
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
| 535 |
-
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
| 536 |
-
# tested on GTX1080Ti
|
| 537 |
-
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
| 538 |
-
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
| 539 |
-
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
| 540 |
-
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
| 541 |
-
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
| 542 |
-
]
|
| 543 |
-
|
| 544 |
-
if not torch.cuda.is_available():
|
| 545 |
-
return 1
|
| 546 |
-
|
| 547 |
-
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
| 548 |
-
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
| 549 |
-
for settings in sorted(
|
| 550 |
-
filtered_bs_search_table,
|
| 551 |
-
key=lambda k: (k["res"], -k["total_vram"]),
|
| 552 |
-
):
|
| 553 |
-
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
| 554 |
-
bs = settings["bs"]
|
| 555 |
-
if bs > ensemble_size:
|
| 556 |
-
bs = ensemble_size
|
| 557 |
-
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
| 558 |
-
bs = math.ceil(ensemble_size / 2)
|
| 559 |
-
return bs
|
| 560 |
-
|
| 561 |
-
return 1
|
marigold_iid_lighting.py
DELETED
|
@@ -1,576 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
# --------------------------------------------------------------------------
|
| 15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 17 |
-
# More information about the method can be found at https://marigoldcomputervision.github.io
|
| 18 |
-
# --------------------------------------------------------------------------
|
| 19 |
-
import logging
|
| 20 |
-
import math
|
| 21 |
-
from typing import Optional, Tuple, Union, Dict, Any
|
| 22 |
-
|
| 23 |
-
import numpy as np
|
| 24 |
-
import torch
|
| 25 |
-
from diffusers import (
|
| 26 |
-
AutoencoderKL,
|
| 27 |
-
DDIMScheduler,
|
| 28 |
-
DiffusionPipeline,
|
| 29 |
-
UNet2DConditionModel,
|
| 30 |
-
)
|
| 31 |
-
from diffusers.utils import BaseOutput, check_min_version
|
| 32 |
-
from PIL import Image
|
| 33 |
-
from PIL.Image import Resampling
|
| 34 |
-
from torch.utils.data import DataLoader, TensorDataset
|
| 35 |
-
from tqdm.auto import tqdm
|
| 36 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
| 37 |
-
|
| 38 |
-
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 39 |
-
check_min_version("0.27.0.dev0")
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class MarigoldIIDLightingOutput(BaseOutput):
|
| 43 |
-
"""
|
| 44 |
-
Output class for Marigold-IID-Lighting pipeline.
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
albedo (`np.ndarray`):
|
| 48 |
-
Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1].
|
| 49 |
-
albedo_colored (`PIL.Image.Image`):
|
| 50 |
-
Colorized albedo map with the shape of [H, W, 3].
|
| 51 |
-
shading (`np.ndarray`):
|
| 52 |
-
Predicted diffuse shading map with the shape of [3, H, W] values in the range of [0, 1].
|
| 53 |
-
shading_colored (`PIL.Image.Image`):
|
| 54 |
-
Colorized diffuse shading map with the shape of [H, W, 3].
|
| 55 |
-
residual (`np.ndarray`):
|
| 56 |
-
Predicted non-diffuse residual map with the shape of [3, H, W] values in the range of [0, 1].
|
| 57 |
-
residual_colored (`PIL.Image.Image`):
|
| 58 |
-
Colorized non-diffuse residual map with the shape of [H, W, 3].
|
| 59 |
-
|
| 60 |
-
"""
|
| 61 |
-
|
| 62 |
-
albedo: np.ndarray
|
| 63 |
-
albedo_colored: Image.Image
|
| 64 |
-
shading: np.ndarray
|
| 65 |
-
shading_colored: Image.Image
|
| 66 |
-
residual: np.ndarray
|
| 67 |
-
residual_colored: Image.Image
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class MarigoldIIDLightingPipeline(DiffusionPipeline):
|
| 71 |
-
"""
|
| 72 |
-
Pipeline for Intrinsic Image Decomposition (Albedo, diffuse shading and non-diffuse residual) using Marigold: https://marigoldcomputervision.github.io.
|
| 73 |
-
|
| 74 |
-
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 75 |
-
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
unet (`UNet2DConditionModel`):
|
| 79 |
-
Conditional U-Net to denoise the normals latent, conditioned on image latent.
|
| 80 |
-
vae (`AutoencoderKL`):
|
| 81 |
-
Variational Auto-Encoder (VAE) Model to encode and decode images and normals maps
|
| 82 |
-
to and from latent representations.
|
| 83 |
-
scheduler (`DDIMScheduler`):
|
| 84 |
-
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
| 85 |
-
text_encoder (`CLIPTextModel`):
|
| 86 |
-
Text-encoder, for empty text embedding.
|
| 87 |
-
tokenizer (`CLIPTokenizer`):
|
| 88 |
-
CLIP tokenizer.
|
| 89 |
-
"""
|
| 90 |
-
|
| 91 |
-
latent_scale_factor = 0.18215
|
| 92 |
-
|
| 93 |
-
def __init__(
|
| 94 |
-
self,
|
| 95 |
-
unet: UNet2DConditionModel,
|
| 96 |
-
vae: AutoencoderKL,
|
| 97 |
-
scheduler: DDIMScheduler,
|
| 98 |
-
text_encoder: CLIPTextModel,
|
| 99 |
-
tokenizer: CLIPTokenizer,
|
| 100 |
-
prediction_type: Optional[str] = None,
|
| 101 |
-
target_properties: Optional[Dict[str, Any]] = None,
|
| 102 |
-
default_denoising_steps: Optional[int] = None,
|
| 103 |
-
default_processing_resolution: Optional[int] = None,
|
| 104 |
-
):
|
| 105 |
-
super().__init__()
|
| 106 |
-
|
| 107 |
-
self.register_modules(
|
| 108 |
-
unet=unet,
|
| 109 |
-
vae=vae,
|
| 110 |
-
scheduler=scheduler,
|
| 111 |
-
text_encoder=text_encoder,
|
| 112 |
-
tokenizer=tokenizer,
|
| 113 |
-
)
|
| 114 |
-
self.register_to_config(
|
| 115 |
-
prediction_type=prediction_type,
|
| 116 |
-
target_properties=target_properties,
|
| 117 |
-
default_denoising_steps=default_denoising_steps,
|
| 118 |
-
default_processing_resolution=default_processing_resolution,
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
self.empty_text_embed = None
|
| 122 |
-
self.n_targets = 3 # Albedo, shading, residual
|
| 123 |
-
|
| 124 |
-
@torch.no_grad()
|
| 125 |
-
def __call__(
|
| 126 |
-
self,
|
| 127 |
-
input_image: Image,
|
| 128 |
-
denoising_steps: int = 4,
|
| 129 |
-
ensemble_size: int = 10,
|
| 130 |
-
processing_res: int = 768,
|
| 131 |
-
match_input_res: bool = True,
|
| 132 |
-
resample_method: str = "bilinear",
|
| 133 |
-
batch_size: int = 0,
|
| 134 |
-
save_memory: bool = False,
|
| 135 |
-
seed: Union[int, None] = None,
|
| 136 |
-
color_map: str = "Spectral", # TODO change colorization api based on modality
|
| 137 |
-
show_progress_bar: bool = True,
|
| 138 |
-
**kwargs,
|
| 139 |
-
) -> MarigoldIIDLightingOutput:
|
| 140 |
-
"""
|
| 141 |
-
Function invoked when calling the pipeline.
|
| 142 |
-
|
| 143 |
-
Args:
|
| 144 |
-
input_image (`Image`):
|
| 145 |
-
Input RGB (or gray-scale) image.
|
| 146 |
-
denoising_steps (`int`, *optional*, defaults to `10`):
|
| 147 |
-
Number of diffusion denoising steps (DDIM) during inference.
|
| 148 |
-
ensemble_size (`int`, *optional*, defaults to `10`):
|
| 149 |
-
Number of predictions to be ensembled.
|
| 150 |
-
processing_res (`int`, *optional*, defaults to `768`):
|
| 151 |
-
Maximum resolution of processing.
|
| 152 |
-
If set to 0: will not resize at all.
|
| 153 |
-
match_input_res (`bool`, *optional*, defaults to `True`):
|
| 154 |
-
Resize normals prediction to match input resolution.
|
| 155 |
-
Only valid if `limit_input_res` is not None.
|
| 156 |
-
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
| 157 |
-
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
| 158 |
-
batch_size (`int`, *optional*, defaults to `0`):
|
| 159 |
-
Inference batch size, no bigger than `num_ensemble`.
|
| 160 |
-
If set to 0, the script will automatically decide the proper batch size.
|
| 161 |
-
save_memory (`bool`, defaults to `False`):
|
| 162 |
-
Extra steps to save memory at the cost of perforance.
|
| 163 |
-
seed (`int`, *optional*, defaults to `None`)
|
| 164 |
-
Reproducibility seed.
|
| 165 |
-
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized normals map generation):
|
| 166 |
-
Colormap used to colorize the normals map.
|
| 167 |
-
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
| 168 |
-
Display a progress bar of diffusion denoising.
|
| 169 |
-
Returns:
|
| 170 |
-
`MarigoldIIDLightingOutput`: Output class for Marigold monocular intrinsic image decomposition (lighting) prediction pipeline, including:
|
| 171 |
-
- **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
| 172 |
-
- **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
| 173 |
-
- **material** (`np.ndarray`) Predicted material map with the shape of [3, H, W] and values in [0, 1]
|
| 174 |
-
- **material_colored** (`PIL.Image.Image`) Colorized material map with the shape of [3, H, W] and values in [0, 1]
|
| 175 |
-
"""
|
| 176 |
-
|
| 177 |
-
if not match_input_res:
|
| 178 |
-
assert processing_res is not None
|
| 179 |
-
assert processing_res >= 0
|
| 180 |
-
assert denoising_steps >= 1
|
| 181 |
-
assert ensemble_size >= 1
|
| 182 |
-
|
| 183 |
-
# Check if denoising step is reasonable
|
| 184 |
-
self.check_inference_step(denoising_steps)
|
| 185 |
-
|
| 186 |
-
resample_method: Resampling = self.get_pil_resample_method(resample_method)
|
| 187 |
-
|
| 188 |
-
W, H = input_image.size
|
| 189 |
-
|
| 190 |
-
if processing_res > 0:
|
| 191 |
-
input_image = self.resize_max_res(
|
| 192 |
-
input_image,
|
| 193 |
-
max_edge_resolution=processing_res,
|
| 194 |
-
resample_method=resample_method,
|
| 195 |
-
)
|
| 196 |
-
input_image = input_image.convert("RGB")
|
| 197 |
-
image = np.asarray(input_image)
|
| 198 |
-
|
| 199 |
-
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
| 200 |
-
rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
|
| 201 |
-
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
| 202 |
-
rgb_norm = rgb_norm.to(self.device)
|
| 203 |
-
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
|
| 204 |
-
|
| 205 |
-
def ensemble(
|
| 206 |
-
targets: torch.Tensor,
|
| 207 |
-
return_uncertainty: bool = False,
|
| 208 |
-
reduction="median",
|
| 209 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 210 |
-
uncertainty = None
|
| 211 |
-
if reduction == "mean":
|
| 212 |
-
prediction = torch.mean(targets, dim=0, keepdim=True)
|
| 213 |
-
if return_uncertainty:
|
| 214 |
-
uncertainty = torch.std(targets, dim=0, keepdim=True)
|
| 215 |
-
elif reduction == "median":
|
| 216 |
-
prediction = torch.median(targets, dim=0, keepdim=True).values
|
| 217 |
-
if return_uncertainty:
|
| 218 |
-
uncertainty = torch.median(
|
| 219 |
-
torch.abs(targets - prediction), dim=0, keepdim=True
|
| 220 |
-
).values
|
| 221 |
-
else:
|
| 222 |
-
raise ValueError(f"Unrecognized reduction method: {reduction}.")
|
| 223 |
-
return prediction, uncertainty
|
| 224 |
-
|
| 225 |
-
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
| 226 |
-
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
| 227 |
-
|
| 228 |
-
if batch_size <= 0:
|
| 229 |
-
batch_size = self.find_batch_size(
|
| 230 |
-
ensemble_size=ensemble_size,
|
| 231 |
-
input_res=max(rgb_norm.shape[1:]),
|
| 232 |
-
dtype=self.dtype,
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
single_rgb_loader = DataLoader(
|
| 236 |
-
single_rgb_dataset, batch_size=batch_size, shuffle=False
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
target_pred_ls = []
|
| 240 |
-
iterable = single_rgb_loader
|
| 241 |
-
if show_progress_bar:
|
| 242 |
-
iterable = tqdm(
|
| 243 |
-
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
|
| 244 |
-
)
|
| 245 |
-
|
| 246 |
-
for batch in iterable:
|
| 247 |
-
(batched_img,) = batch
|
| 248 |
-
target_pred = self.single_infer(
|
| 249 |
-
rgb_in=batched_img,
|
| 250 |
-
num_inference_steps=denoising_steps,
|
| 251 |
-
seed=seed,
|
| 252 |
-
show_pbar=show_progress_bar,
|
| 253 |
-
)
|
| 254 |
-
target_pred = target_pred.detach()
|
| 255 |
-
if save_memory:
|
| 256 |
-
target_pred = target_pred.cpu()
|
| 257 |
-
target_pred_ls.append(target_pred.detach())
|
| 258 |
-
|
| 259 |
-
target_preds = torch.concat(target_pred_ls, dim=0)
|
| 260 |
-
pred_uncert = None
|
| 261 |
-
|
| 262 |
-
if save_memory:
|
| 263 |
-
torch.cuda.empty_cache()
|
| 264 |
-
|
| 265 |
-
if ensemble_size > 1:
|
| 266 |
-
final_pred, pred_uncert = ensemble(
|
| 267 |
-
target_preds, reduction="median", return_uncertainty=False
|
| 268 |
-
)
|
| 269 |
-
else:
|
| 270 |
-
final_pred = target_preds
|
| 271 |
-
pred_uncert = None
|
| 272 |
-
|
| 273 |
-
if match_input_res:
|
| 274 |
-
final_pred = torch.nn.functional.interpolate(
|
| 275 |
-
final_pred, (H, W), mode="bilinear" # TODO: parameterize this method
|
| 276 |
-
) # [1,3,H,W]
|
| 277 |
-
|
| 278 |
-
if pred_uncert is not None:
|
| 279 |
-
pred_uncert = torch.nn.functional.interpolate(
|
| 280 |
-
pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
|
| 281 |
-
).squeeze(
|
| 282 |
-
1
|
| 283 |
-
) # [1,H,W]
|
| 284 |
-
|
| 285 |
-
# Convert to numpy
|
| 286 |
-
final_pred = final_pred.squeeze()
|
| 287 |
-
final_pred = final_pred.cpu().float().numpy()
|
| 288 |
-
|
| 289 |
-
albedo = final_pred[0:3, :, :]
|
| 290 |
-
shading = final_pred[3:6, :, :]
|
| 291 |
-
residual = final_pred[6:, :, :]
|
| 292 |
-
|
| 293 |
-
albedo_colored = (albedo + 1.0) * 0.5 # [-1,1] -> [0,1]
|
| 294 |
-
albedo_colored = albedo_colored ** (
|
| 295 |
-
1 / 2.2
|
| 296 |
-
) # from linear to sRGB (to be consistent with IID-Appearance model)
|
| 297 |
-
albedo_colored = (albedo_colored * 255).astype(np.uint8)
|
| 298 |
-
albedo_colored = self.chw2hwc(albedo_colored)
|
| 299 |
-
albedo_colored_img = Image.fromarray(albedo_colored)
|
| 300 |
-
|
| 301 |
-
shading_colored = (shading + 1.0) * 0.5
|
| 302 |
-
shading_colored = (
|
| 303 |
-
shading_colored / shading_colored.max()
|
| 304 |
-
) # rescale for better visualization
|
| 305 |
-
shading_colored = (shading_colored * 255).astype(np.uint8)
|
| 306 |
-
shading_colored = self.chw2hwc(shading_colored)
|
| 307 |
-
shading_colored_img = Image.fromarray(shading_colored)
|
| 308 |
-
|
| 309 |
-
residual_colored = (residual + 1.0) * 0.5
|
| 310 |
-
residual_colored = (
|
| 311 |
-
residual_colored / residual_colored.max()
|
| 312 |
-
) # rescale for better visualization
|
| 313 |
-
residual_colored = (residual_colored * 255).astype(np.uint8)
|
| 314 |
-
residual_colored = self.chw2hwc(residual_colored)
|
| 315 |
-
residual_colored_img = Image.fromarray(residual_colored)
|
| 316 |
-
|
| 317 |
-
out = MarigoldIIDLightingOutput(
|
| 318 |
-
albedo=albedo,
|
| 319 |
-
albedo_colored=albedo_colored_img,
|
| 320 |
-
shading=shading,
|
| 321 |
-
shading_colored=shading_colored_img,
|
| 322 |
-
residual=residual,
|
| 323 |
-
residual_colored=residual_colored_img,
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
return out
|
| 327 |
-
|
| 328 |
-
def check_inference_step(self, n_step: int):
|
| 329 |
-
"""
|
| 330 |
-
Check if denoising step is reasonable
|
| 331 |
-
Args:
|
| 332 |
-
n_step (`int`): denoising steps
|
| 333 |
-
"""
|
| 334 |
-
assert n_step >= 1
|
| 335 |
-
|
| 336 |
-
if isinstance(self.scheduler, DDIMScheduler):
|
| 337 |
-
pass
|
| 338 |
-
else:
|
| 339 |
-
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
| 340 |
-
|
| 341 |
-
def encode_empty_text(self):
|
| 342 |
-
"""
|
| 343 |
-
Encode text embedding for empty prompt.
|
| 344 |
-
"""
|
| 345 |
-
prompt = ""
|
| 346 |
-
text_inputs = self.tokenizer(
|
| 347 |
-
prompt,
|
| 348 |
-
padding="do_not_pad",
|
| 349 |
-
max_length=self.tokenizer.model_max_length,
|
| 350 |
-
truncation=True,
|
| 351 |
-
return_tensors="pt",
|
| 352 |
-
)
|
| 353 |
-
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
| 354 |
-
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
| 355 |
-
|
| 356 |
-
@torch.no_grad()
|
| 357 |
-
def single_infer(
|
| 358 |
-
self,
|
| 359 |
-
rgb_in: torch.Tensor,
|
| 360 |
-
num_inference_steps: int,
|
| 361 |
-
seed: Union[int, None],
|
| 362 |
-
show_pbar: bool,
|
| 363 |
-
) -> torch.Tensor:
|
| 364 |
-
"""
|
| 365 |
-
Perform an individual iid prediction without ensembling.
|
| 366 |
-
"""
|
| 367 |
-
device = rgb_in.device
|
| 368 |
-
|
| 369 |
-
# Set timesteps
|
| 370 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 371 |
-
timesteps = self.scheduler.timesteps # [T]
|
| 372 |
-
|
| 373 |
-
# Encode image
|
| 374 |
-
rgb_latent = self.encode_rgb(rgb_in)
|
| 375 |
-
|
| 376 |
-
target_latent_shape = list(rgb_latent.shape)
|
| 377 |
-
target_latent_shape[
|
| 378 |
-
1
|
| 379 |
-
] *= 3 # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
|
| 380 |
-
|
| 381 |
-
# Initialize prediction latent with noise
|
| 382 |
-
if seed is None:
|
| 383 |
-
rand_num_generator = None
|
| 384 |
-
else:
|
| 385 |
-
rand_num_generator = torch.Generator(device=device)
|
| 386 |
-
rand_num_generator.manual_seed(seed)
|
| 387 |
-
target_latents = torch.randn(
|
| 388 |
-
target_latent_shape,
|
| 389 |
-
device=device,
|
| 390 |
-
dtype=self.dtype,
|
| 391 |
-
generator=rand_num_generator,
|
| 392 |
-
) # [B, 4, h, w]
|
| 393 |
-
|
| 394 |
-
# Batched empty text embedding
|
| 395 |
-
if self.empty_text_embed is None:
|
| 396 |
-
self.encode_empty_text()
|
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latents], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            target_latents = self.scheduler.step(
                noise_pred, t, target_latents, generator=rand_num_generator
            ).prev_sample

        # torch.cuda.empty_cache()  # TODO: is it really needed here, even for memory saving?

        targets = self.decode_targets(target_latents)  # [B, 3, H, W]
        targets = torch.clip(targets, -1.0, 1.0)

        return targets

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
        """
        Decode target latent into target map.

        Args:
            target_latents (`torch.Tensor`):
                Target latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded target map.
        """

        assert target_latents.shape[1] == 12  # self.n_targets * 4

        # scale latent
        target_latents = target_latents / self.latent_scale_factor
        # decode each 4-channel slice separately
        targets = []
        for i in range(self.n_targets):
            latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
            z = self.vae.post_quant_conv(latent)
            stacked = self.vae.decoder(z)

            targets.append(stacked)

        return torch.cat(targets, dim=1)

    @staticmethod
    def get_pil_resample_method(method_str: str) -> Resampling:
        resample_method_dic = {
            "bilinear": Resampling.BILINEAR,
            "bicubic": Resampling.BICUBIC,
            "nearest": Resampling.NEAREST,
        }
        resample_method = resample_method_dic.get(method_str, None)
        if resample_method is None:
            raise ValueError(f"Unknown resampling method: {method_str}")
        else:
            return resample_method

    @staticmethod
    def resize_max_res(
        img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
    ) -> Image.Image:
        """
        Resize image to limit maximum edge length while keeping aspect ratio.
        """
        original_width, original_height = img.size
        downscale_factor = min(
            max_edge_resolution / original_width, max_edge_resolution / original_height
        )

        new_width = int(original_width * downscale_factor)
        new_height = int(original_height * downscale_factor)

        resized_img = img.resize((new_width, new_height), resample=resample_method)
        return resized_img

    @staticmethod
    def chw2hwc(chw):
        assert 3 == len(chw.shape)
        if isinstance(chw, torch.Tensor):
            hwc = torch.permute(chw, (1, 2, 0))
        elif isinstance(chw, np.ndarray):
            hwc = np.moveaxis(chw, 0, -1)
        return hwc

    @staticmethod
    def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
        """
        Automatically search for a suitable operating batch size.

        Args:
            ensemble_size (`int`):
                Number of predictions to be ensembled.
            input_res (`int`):
                Operating resolution of the input image.

        Returns:
            `int`: Operating batch size.
        """
        # Search table for suggested max. inference batch size
        bs_search_table = [
            # tested on A100-PCIE-80GB
            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
            # tested on A100-PCIE-40GB
            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
            # tested on RTX3090, RTX4090
            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
            # tested on GTX1080Ti
            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
        ]

        if not torch.cuda.is_available():
            return 1

        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
        for settings in sorted(
            filtered_bs_search_table,
            key=lambda k: (k["res"], -k["total_vram"]),
        ):
            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
                bs = settings["bs"]
                if bs > ensemble_size:
                    bs = ensemble_size
                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
                    bs = math.ceil(ensemble_size / 2)
                return bs

        return 1
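The removed decode_targets() above relies on a stacked latent of n_targets * 4 VAE channels that is decoded slice by slice. The following standalone sketch only illustrates that tensor layout; stub_decode and the 0.18215 scale factor are illustrative assumptions standing in for the SD VAE decoder and latent_scale_factor, not part of the deleted file.

import torch

# Sketch of the latent layout assumed by the removed decode_targets():
# n_targets * 4 channels stacked along dim=1, decoded one 4-channel slice at a time.
# stub_decode stands in for vae.post_quant_conv + vae.decoder (assumption);
# 0.18215 is the usual SD VAE scale factor, assumed equal to latent_scale_factor.
n_targets = 3
latent_scale_factor = 0.18215

def stub_decode(latent_4ch: torch.Tensor) -> torch.Tensor:
    # Placeholder decoder: [B, 4, h, w] -> [B, 3, 8h, 8w], like the SD VAE.
    b, _, h, w = latent_4ch.shape
    return torch.zeros(b, 3, h * 8, w * 8)

target_latents = torch.randn(1, n_targets * 4, 32, 32)       # [B, 12, h, w]
target_latents = target_latents / latent_scale_factor        # undo latent scaling
targets = [
    stub_decode(target_latents[:, i * 4 : (i + 1) * 4])      # one 4-channel slice per target
    for i in range(n_targets)
]
print(torch.cat(targets, dim=1).shape)                       # torch.Size([1, 9, 256, 256])

The find_batch_size() helper above follows a similar pragmatic spirit: pick the largest batch known to fit the detected VRAM at the chosen precision and resolution, then cap it by the ensemble size.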
requirements.txt
CHANGED

@@ -1,4 +1,4 @@
-diffusers>=0.
+diffusers>=0.33.0
 git+https://github.com/toshas/gradio-dualvision.git@21346a4
 accelerate
 huggingface_hub
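With the requirement bumped to diffusers>=0.33.0, the single-file pipeline implementations removed in this commit can presumably be replaced by the upstream Marigold intrinsics support. A minimal sketch, assuming the class is exposed as diffusers.MarigoldIntrinsicsPipeline and that the checkpoint layout is unchanged; the call signature hint is an assumption, not taken from this commit:

import torch
from diffusers import MarigoldIntrinsicsPipeline  # assumed upstream class name in diffusers>=0.33.0

# Hedged sketch: load one of the Marigold IID checkpoints through the upstream
# pipeline instead of the removed single-file implementations.
pipe = MarigoldIntrinsicsPipeline.from_pretrained(
    "prs-eth/marigold-iid-lighting-v1-1",
    torch_dtype=torch.bfloat16,
).to("cuda")
# Assumed to mirror other Marigold pipelines, e.g.:
# out = pipe(image, num_inference_steps=4, ensemble_size=1)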