Spaces:
Sleeping
Sleeping
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. | |
# This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License. | |
# See https://creativecommons.org/licenses/by-sa/4.0/ for details. | |
# -------------------------------------------------------------------------- | |
# DualVision is a Gradio template app for image processing. It was developed | |
# to support the Marigold project. If you find this code useful, we kindly | |
# ask you to cite our most relevant papers. | |
# More information about Marigold: | |
# https://marigoldmonodepth.github.io | |
# https://marigoldcomputervision.github.io | |
# Efficient inference pipelines are now part of diffusers: | |
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage | |
# https://huggingface.co/docs/diffusers/api/pipelines/marigold | |
# Examples of trained models and live demos: | |
# https://huggingface.co/prs-eth | |
# Related projects: | |
# https://marigolddepthcompletion.github.io/ | |
# https://rollingdepth.github.io/ | |
# Citation (BibTeX): | |
# https://github.com/prs-eth/Marigold#-citation | |
# https://github.com/prs-eth/Marigold-DC#-citation | |
# https://github.com/prs-eth/rollingdepth#-citation | |
# -------------------------------------------------------------------------- | |
import glob | |
import json | |
import os | |
import re | |
import gradio as gr | |
import spaces | |
from PIL import Image | |
from gradio.components.base import Component | |
from .gradio_patches.examples import Examples | |
from .gradio_patches.imagesliderplus import ImageSliderPlus | |
from .gradio_patches.radio import Radio | |
class DualVisionApp(gr.Blocks): | |
def __init__( | |
self, | |
title, | |
examples_path="examples", | |
examples_per_page=12, | |
examples_cache="lazy", | |
squeeze_canvas=True, | |
squeeze_viewport_height_pct=75, | |
left_selector_visible=False, | |
advanced_settings_can_be_half_width=True, | |
key_original_image="Original", | |
spaces_zero_gpu_enabled=False, | |
spaces_zero_gpu_duration=None, | |
slider_position=0.5, | |
slider_line_color="#FFF", | |
slider_line_width="4px", | |
slider_arrows_color="#FFF", | |
slider_arrows_width="2px", | |
gallery_thumb_min_size="96px", | |
**kwargs, | |
): | |
""" | |
A wrapper around Gradio Blocks class that implements an image processing demo template. All the user has to do | |
is to subclass this class and implement two methods: `process` implementing the image processing, and | |
`build_user_components` implementing Gradio components reading the additional processing arguments. | |
Args: | |
title: Title of the application (str, required). | |
examples_path: Base path where examples will be searched (Default: `"examples"`). | |
examples_per_page: How many examples to show at the bottom of the app (Default: `12`). | |
examples_cache: Examples caching policy, corresponding to `cache_examples` argument of gradio.Examples (Default: `"lazy"`). | |
squeeze_canvas: When True, the image is fit to the browser viewport. When False, the image is fit to width (Default: `True`). | |
squeeze_viewport_height_pct: Percentage of the browser viewport height (Default: `75`). | |
left_selector_visible: Whether controls for changing modalities in the left part of the slider are visible (Default: `False`). | |
key_original_image: Name of the key under which the input image is shown in the modality selectors (Default: `"Original"`). | |
advanced_settings_can_be_half_width: Whether allow placing advanced settings dropdown in half-column space whenever possible (Default: `True`). | |
spaces_zero_gpu_enabled: When True, the app wraps the processing function with the ZeroGPU decorator. | |
spaces_zero_gpu_duration: Defines an integer duration in seconds passed into the ZeroGPU decorator. | |
slider_position: Position of the slider between 0 and 1 (Default: `0.5`). | |
slider_line_color: Color of the slider line (Default: `"#FFF"`). | |
slider_line_width: Width of the slider line (Default: `"4px"`). | |
slider_arrows_color: Color of the slider arrows (Default: `"#FFF"`). | |
slider_arrows_width: Width of the slider arrows (Default: `2px`). | |
gallery_thumb_min_size: Min size of the gallery thumbnail (Default: `96px`). | |
**kwargs: Any other arguments that Gradio Blocks class can take. | |
""" | |
squeeze_viewport_height_pct = int(squeeze_viewport_height_pct) | |
if not 50 <= squeeze_viewport_height_pct <= 100: | |
raise gr.Error( | |
"`squeeze_viewport_height_pct` should be an integer between 50 and 100." | |
) | |
if not os.path.isdir(examples_path): | |
raise gr.Error("`examples_path` should be a directory.") | |
if not 0 <= slider_position <= 1: | |
raise gr.Error("`slider_position` should be between 0 and 1.") | |
kwargs = {k: v for k, v in kwargs.items()} | |
kwargs["title"] = title | |
self.examples_path = examples_path | |
self.examples_per_page = examples_per_page | |
self.examples_cache = examples_cache | |
self.key_original_image = key_original_image | |
self.slider_position = slider_position | |
self.input_keys = None | |
self.left_selector_visible = left_selector_visible | |
self.advanced_settings_can_be_half_width = advanced_settings_can_be_half_width | |
if spaces_zero_gpu_enabled: | |
self.process_components = spaces.GPU( | |
self.process_components, duration=spaces_zero_gpu_duration | |
) | |
self.head = "" | |
self.head += """ | |
<script> | |
let observerFooterButtons = new MutationObserver((mutationsList, observer) => { | |
const oldButtonLeft = document.querySelector(".show-api"); | |
const oldButtonRight = document.querySelector(".built-with"); | |
if (!oldButtonRight || !oldButtonLeft) { | |
return; | |
} | |
observer.disconnect(); | |
const parentDiv = oldButtonLeft.parentNode; | |
if (!parentDiv) return; | |
const createButton = (referenceButton, text, href) => { | |
let newButton = referenceButton.cloneNode(true); | |
newButton.href = href; | |
newButton.textContent = text; | |
newButton.className = referenceButton.className; | |
newButton.style.textDecoration = "none"; | |
newButton.style.display = "inline-block"; | |
newButton.style.cursor = "pointer"; | |
return newButton; | |
}; | |
const newButton0 = createButton(oldButtonRight, "Built with Gradio DualVision", "https://github.com/toshas/gradio-dualvision"); | |
const newButton1 = createButton(oldButtonRight, "Template by Anton Obukhov", "https://www.obukhov.ai"); | |
const newButton2 = createButton(oldButtonRight, "Licensed under CC BY-SA 4.0", "http://creativecommons.org/licenses/by-sa/4.0/"); | |
const separatorDiv = document.createElement("div"); | |
separatorDiv.className = "svelte-1rjryqp"; | |
separatorDiv.textContent = "·"; | |
parentDiv.replaceChild(newButton0, oldButtonLeft); | |
parentDiv.replaceChild(newButton1, oldButtonRight); | |
parentDiv.appendChild(separatorDiv); | |
parentDiv.appendChild(newButton2); | |
}); | |
observerFooterButtons.observe(document.body, { childList: true, subtree: true }); | |
</script> | |
""" | |
if kwargs.get("analytics_enabled") is not False: | |
self.head += f""" | |
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script> | |
<script> | |
window.dataLayer = window.dataLayer || []; | |
function gtag() {{dataLayer.push(arguments);}} | |
gtag('js', new Date()); | |
gtag('config', 'G-1FWSVCGZTG'); | |
</script> | |
""" | |
self.css = f""" | |
body {{ /* tighten the layout */ | |
flex-grow: 0 !important; | |
}} | |
.sliderrow {{ /* center the slider */ | |
display: flex; | |
justify-content: center; | |
}} | |
.slider {{ /* center the slider */ | |
display: flex; | |
justify-content: center; | |
width: 100%; | |
}} | |
.slider .disabled {{ /* hide the main slider before image load */ | |
visibility: hidden; | |
}} | |
.slider .svelte-9gxdi0 {{ /* hide the component label in the top-left corner before image load */ | |
visibility: hidden; | |
}} | |
.slider .svelte-kzcjhc .icon-wrap {{ | |
height: 0px; /* remove unnecessary spaces in captions before image load */ | |
}} | |
.slider .svelte-kzcjhc.wrap {{ | |
padding-top: 0px; /* remove unnecessary spaces in captions before image load */ | |
}} | |
.slider .svelte-3w3rth {{ /* hide the dummy icon from the right pane before image load */ | |
visibility: hidden; | |
}} | |
.slider .svelte-106mu0e a {{ /* hide the download button */ | |
visibility: hidden; | |
}} | |
.slider .fixed {{ /* fix the opacity of the right pane image */ | |
background-color: var(--anim-block-background-fill); | |
}} | |
.slider .inner {{ /* style slider line */ | |
width: {slider_line_width}; | |
background: {slider_line_color}; | |
}} | |
.slider .icon-wrap svg {{ /* style slider arrows */ | |
stroke: {slider_arrows_color}; | |
stroke-width: {slider_arrows_width}; | |
}} | |
.slider .icon-wrap path {{ /* style slider arrows */ | |
fill: {slider_arrows_color}; | |
}} | |
.row_reverse {{ | |
flex-direction: row-reverse; | |
}} | |
.gallery.svelte-l4wpk0 {{ /* make examples gallery tiles square */ | |
width: max({gallery_thumb_min_size}, calc(100vw / 8)); | |
height: max({gallery_thumb_min_size}, calc(100vw / 8)); | |
}} | |
.gallery.svelte-l4wpk0 img {{ /* make examples gallery tiles square */ | |
width: max({gallery_thumb_min_size}, calc(100vw / 8)); | |
height: max({gallery_thumb_min_size}, calc(100vw / 8)); | |
}} | |
.gallery.svelte-l4wpk0 img {{ /* remove slider line from previews */ | |
clip-path: inset(0 0 0 0); | |
}} | |
.gallery.svelte-l4wpk0 span {{ /* remove slider line from previews */ | |
visibility: hidden; | |
}} | |
h1, h2, h3 {{ /* center markdown headings */ | |
text-align: center; | |
display: block; | |
}} | |
""" | |
if squeeze_canvas: | |
self.head += f""" | |
<script> | |
// fixes vertical size of the component when used inside of iframeResizer (on spaces) | |
function squeezeViewport() {{ | |
if (typeof window.parentIFrame === "undefined") return; | |
const images = document.querySelectorAll('.slider img'); | |
window.parentIFrame.getPageInfo((info) => {{ | |
images.forEach((img) => {{ | |
const imgMaxHeightNew = (info.clientHeight * {squeeze_viewport_height_pct}) / 100; | |
img.style.maxHeight = `${{imgMaxHeightNew}}px`; | |
// window.parentIFrame.size(0, null); // tighten the layout; body's flex-grow: 0 is less intrusive | |
}}); | |
}}); | |
}} | |
window.addEventListener('resize', squeezeViewport); | |
// fixes gradio-imageslider wrong position behavior when using fitting to content by triggering resize | |
let observer = new MutationObserver((mutationsList) => {{ | |
const images = document.querySelectorAll('.slider img'); | |
images.forEach((img) => {{ | |
if (img.complete) {{ | |
window.dispatchEvent(new Event('resize')); | |
}} else {{ | |
img.onload = () => {{ | |
window.dispatchEvent(new Event('resize')); | |
}} | |
}} | |
}}); | |
}}); | |
observer.observe(document.body, {{ childList: true, subtree: true }}); | |
</script> | |
""" | |
self.css += f""" | |
.slider {{ /* make the slider dimensions fit to the uploaded content dimensions */ | |
max-width: fit-content; | |
}} | |
.slider .half-wrap {{ /* make the empty component width almost full before image load */ | |
width: 70vw; | |
}} | |
.slider img {{ /* Ensures image fits inside the viewport */ | |
max-height: {squeeze_viewport_height_pct}vh; | |
}} | |
""" | |
else: | |
self.css += f""" | |
.slider .half-wrap {{ /* make the upload area full width */ | |
width: 100%; | |
}} | |
""" | |
kwargs["css"] = kwargs.get("css", "") + self.css | |
kwargs["head"] = kwargs.get("head", "") + self.head | |
super().__init__(**kwargs) | |
with self: | |
self.make_interface() | |
def process(self, image_in: Image.Image, **kwargs): | |
""" | |
Process an input image into multiple modalities using the provided arguments or default settings. | |
Returns two dictionaries: one containing the modalities and another with the actual settings. | |
Override this method in a subclass. | |
""" | |
raise NotImplementedError("Please override the `process` method.") | |
def build_user_components(self): | |
""" | |
Create gradio components for the Advanced Settings dropdown, that will be passed into the `process` method. | |
Use gr.Row(), gr.Column(), and other context managers to arrange the components. Return them as a flat dict. | |
Override this method in a subclass. | |
""" | |
raise NotImplementedError("Please override the `build_user_components` method.") | |
def discover_examples(self): | |
""" | |
Looks for valid image filenames. | |
""" | |
pattern = re.compile(r".*\.(jpg|JPG|jpeg|JPEG|png|PNG)$") | |
paths = glob.glob(f"{self.examples_path}/*") | |
out = list(sorted(filter(pattern.match, paths))) | |
return out | |
def process_components( | |
self, image_in, modality_selector_left, modality_selector_right, **kwargs | |
): | |
""" | |
Wraps the call to `process`. Returns results in a structure used by the gallery, slider, radio components. | |
""" | |
if image_in is None: | |
raise gr.Error("Input image is required") | |
image_settings = {} | |
if isinstance(image_in, str): | |
image_settings_path = image_in + ".settings.json" | |
if os.path.isfile(image_settings_path): | |
with open(image_settings_path, "r") as f: | |
image_settings = json.load(f) | |
image_in = Image.open(image_in).convert("RGB") | |
else: | |
if not isinstance(image_in, Image.Image): | |
raise gr.Error(f"Input must be a PIL image, got {type(image_in)}") | |
image_in = image_in.convert("RGB") | |
image_settings.update(kwargs) | |
results_dict, results_settings = self.process(image_in, **image_settings) | |
if not isinstance(results_dict, dict): | |
raise gr.Error( | |
f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}" | |
) | |
if len(results_dict) == 0: | |
raise gr.Error("`process` did not return any modalities") | |
for k, v in results_dict.items(): | |
if not isinstance(k, str): | |
raise gr.Error( | |
f"Output dict must have string keys. Found key of type {type(k)}: {repr(k)}" | |
) | |
if k == self.key_original_image: | |
raise gr.Error( | |
f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input" | |
) | |
if not isinstance(v, Image.Image): | |
raise gr.Error( | |
f"Value for key '{k}' must be a PIL Image, got type {type(v)}" | |
) | |
if len(results_settings) != len(self.input_keys): | |
raise gr.Error( | |
f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})" | |
) | |
if any(k not in results_settings for k in self.input_keys): | |
raise gr.Error(f"Mismatching setgings keys") | |
results_settings = {k: results_settings[k] for k in self.input_keys} | |
results_dict = { | |
self.key_original_image: image_in, | |
**results_dict, | |
} | |
results_state = [[v, k] for k, v in results_dict.items()] | |
modalities = list(results_dict.keys()) | |
modality_left = ( | |
modality_selector_left | |
if modality_selector_left in modalities | |
else modalities[0] | |
) | |
modality_right = ( | |
modality_selector_right | |
if modality_selector_right in modalities | |
else modalities[1] | |
) | |
return [ | |
results_state, # goes to a gr.Gallery | |
[ | |
results_dict[modality_left], | |
results_dict[modality_right], | |
], # ImageSliderPlus | |
Radio( | |
choices=modalities, | |
value=modality_left, | |
label="Left", | |
key="Left", | |
), | |
Radio( | |
choices=modalities if self.left_selector_visible else modalities[1:], | |
value=modality_right, | |
label="Right", | |
key="Right", | |
), | |
*results_settings.values(), | |
] | |
def on_process_first( | |
self, | |
image_slider, | |
modality_selector_left=None, | |
modality_selector_right=None, | |
*args, | |
): | |
image_in = image_slider[0] | |
input_dict = {} | |
if len(args) > 0: | |
input_dict = {k: v for k, v in zip(self.input_keys, args)} | |
return self.process_components( | |
image_in, modality_selector_left, modality_selector_right, **input_dict | |
) | |
def on_process_subsequent( | |
self, results_state, modality_selector_left, modality_selector_right, *args | |
): | |
if results_state is None: | |
raise gr.Error("Upload an image first or use an example below.") | |
results_state = {k: v for v, k in results_state} | |
image_in = results_state[self.key_original_image] | |
input_dict = {k: v for k, v in zip(self.input_keys, args)} | |
return self.process_components( | |
image_in, modality_selector_left, modality_selector_right, **input_dict | |
) | |
def on_selector_change_left( | |
self, results_state, image_slider, modality_selector_left | |
): | |
results_state = {k: v for v, k in results_state} | |
return [results_state[modality_selector_left], image_slider[1]] | |
def on_selector_change_right( | |
self, results_state, image_slider, modality_selector_right | |
): | |
results_state = {k: v for v, k in results_state} | |
return [image_slider[0], results_state[modality_selector_right]] | |
def make_interface(self): | |
""" | |
Constructs the entire Gradio Blocks interface. | |
""" | |
self.make_header() | |
results_state = gr.Gallery(visible=False, format="png") | |
image_slider = self.make_slider() | |
if self.left_selector_visible or not self.advanced_settings_can_be_half_width: | |
with gr.Row(): | |
modality_selector_left, modality_selector_right = ( | |
self.make_modality_selectors(reverse_visual_order=False) | |
) | |
user_components, btn_clear, btn_submit = self.make_advanced_settings() | |
else: | |
with gr.Row(equal_height=False, elem_classes="row_reverse"): | |
with gr.Column(): | |
modality_selector_left, modality_selector_right = ( | |
self.make_modality_selectors(reverse_visual_order=True) | |
) | |
with gr.Column(): | |
user_components, btn_clear, btn_submit = ( | |
self.make_advanced_settings() | |
) | |
self.make_examples( | |
image_slider, | |
[ | |
results_state, | |
image_slider, | |
modality_selector_left, | |
modality_selector_right, | |
*user_components.values(), | |
], | |
) | |
image_slider.upload( | |
fn=self.on_process_first, | |
inputs=[ | |
image_slider, | |
modality_selector_left, | |
modality_selector_right, | |
*user_components.values(), | |
], | |
outputs=[ | |
results_state, | |
image_slider, | |
modality_selector_left, | |
modality_selector_right, | |
*user_components.values(), | |
], | |
) | |
btn_submit.click( | |
fn=self.on_process_subsequent, | |
inputs=[ | |
results_state, | |
modality_selector_left, | |
modality_selector_right, | |
*user_components.values(), | |
], | |
outputs=[ | |
results_state, | |
image_slider, | |
modality_selector_left, | |
modality_selector_right, | |
*user_components.values(), | |
], | |
) | |
btn_clear.click( | |
fn=lambda: (None, None), | |
inputs=[], | |
outputs=[image_slider, results_state], | |
) | |
modality_selector_left.input( | |
fn=self.on_selector_change_left, | |
inputs=[results_state, image_slider, modality_selector_left], | |
outputs=image_slider, | |
) | |
modality_selector_right.input( | |
fn=self.on_selector_change_right, | |
inputs=[results_state, image_slider, modality_selector_right], | |
outputs=image_slider, | |
) | |
def make_header(self): | |
""" | |
Create a header section with Markdown and HTML. | |
Default: just the project title. | |
""" | |
gr.Markdown(f"# {self.title}") | |
def make_slider(self): | |
with gr.Row(elem_classes="sliderrow"): | |
return ImageSliderPlus( | |
label=self.title, | |
type="filepath", | |
elem_classes="slider", | |
position=self.slider_position, | |
) | |
def make_modality_selectors(self, reverse_visual_order=False): | |
modality_selector_left = Radio( | |
choices=None, | |
value=None, | |
label="Left", | |
key="Left", | |
show_label=False, | |
container=False, | |
visible=self.left_selector_visible, | |
render=not reverse_visual_order, | |
) | |
modality_selector_right = Radio( | |
choices=None, | |
value=None, | |
label="Right", | |
key="Right", | |
show_label=False, | |
container=False, | |
elem_id="selector_right", | |
render=not reverse_visual_order, | |
) | |
if reverse_visual_order: | |
modality_selector_right.render() | |
modality_selector_left.render() | |
return modality_selector_left, modality_selector_right | |
def make_examples(self, inputs, outputs): | |
examples = self.discover_examples() | |
if not isinstance(examples, list): | |
raise gr.Error("`discover_examples` must return a list of paths") | |
if any(not os.path.isfile(path) for path in examples): | |
raise gr.Error("Not all example paths are valid files") | |
examples_dirname = os.path.basename(os.path.normpath(self.examples_path)) | |
return Examples( | |
examples=[ | |
(e, e) for e in examples | |
], # tuples like this seem to work better with the gallery | |
inputs=inputs, | |
outputs=outputs, | |
examples_per_page=self.examples_per_page, | |
cache_examples=self.examples_cache, | |
fn=self.on_process_first, | |
directory_name=examples_dirname, | |
) | |
def make_advanced_settings(self): | |
with gr.Accordion("Advanced Settings", open=False): | |
user_components = self.build_user_components() | |
if not isinstance(user_components, dict) or any( | |
not isinstance(k, str) or not isinstance(v, Component) | |
for k, v in user_components.items() | |
): | |
raise gr.Error( | |
"`build_user_components` must return a dict of Gradio components with string keys. A dict of the " | |
"same structure will be passed into the `process` function." | |
) | |
with gr.Row(): | |
btn_clear, btn_submit = self.make_buttons() | |
self.input_keys = list(user_components.keys()) | |
return user_components, btn_clear, btn_submit | |
def make_buttons(self): | |
btn_clear = gr.Button("Clear") | |
btn_submit = gr.Button("Apply", variant="primary") | |
return btn_clear, btn_submit | |