import os
from math import ceil, floor
from typing import Callable, List, Tuple

import gradio as gr
import numpy as np
import onnxruntime as ort
import scipy.signal
from numpy import ndarray
from PIL import Image
from tqdm import tqdm

# needed to run locally
os.environ["GRADIO_TEMP_DIR"] = ".tmp"

WINDOW_CACHE = dict()

def _spline_window(window_size: int, power: int = 2) -> np.ndarray:
    """Generates a 1-dimensional spline of order 'power' (typically 2) over the given
    window.

    Args:
        window_size (int): size of the window of interest
        power (int, optional): order of the spline. Defaults to 2.

    Returns:
        np.ndarray: 1D spline
    """
    intersection = int(window_size / 4)
    wind_outer = (abs(2 * (scipy.signal.windows.triang(window_size))) ** power) / 2
    wind_outer[intersection:-intersection] = 0
    wind_inner = 1 - (abs(2 * (scipy.signal.windows.triang(window_size) - 1)) ** power) / 2
    wind_inner[:intersection] = 0
    wind_inner[-intersection:] = 0
    wind = wind_inner + wind_outer
    # normalize so the spline averages to 1
    wind = wind / np.average(wind)
    return wind
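
# Example (illustrative): _spline_window produces a taper that is ~0 at the
# window edges and is normalized to average 1:
#   w = _spline_window(512)
#   # w.shape == (512,), abs(w.mean() - 1.0) < 1e-6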

def _spline_2d(window_size: int, power: int = 2) -> ndarray:
    """Makes a 1D window spline function, then combines it to return a 2D window function.
    The 2D window is useful to smoothly interpolate between patches.

    Args:
        window_size (int): size of the window (patch)
        power (int, optional): order of the spline. Defaults to 2.

    Returns:
        np.ndarray: numpy array containing a 2D spline function
    """
    wind = _spline_window(window_size, power)
    # outer product of the 1D spline with itself, normalized to a maximum of 1
    wind2 = wind[:, None] * wind[None, :]
    wind2 = wind2 / np.max(wind2)
    return wind2
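
# Example (illustrative): the 2D window from _spline_2d peaks near the tile
# center and fades towards the borders:
#   w2 = _spline_2d(512)
#   # w2.shape == (512, 512), w2.max() == 1.0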

def _spline_4d(
    window_size: int,
    power: int = 2,
    batch_size: int = 1,
    channels: int = 1
) -> ndarray:
    """Makes a 4D window spline function.
    Same as the 2D version, but repeated across all channels and batch."""
    global WINDOW_CACHE
    # memoize the 2D window to avoid recomputing it on every call,
    # since the same one is needed for every batch
    key = f"{window_size}_{power}"
    if key in WINDOW_CACHE:
        wind2 = WINDOW_CACHE[key]
    else:
        wind2 = _spline_2d(window_size, power)
        WINDOW_CACHE[key] = wind2
    # broadcast to (batch, channels, window, window)
    wind4 = wind2[None, None, :, :] * np.ones((batch_size, channels, 1, 1))
    return wind4
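
# Example (illustrative): _spline_4d returns a blending window shaped
# (batch, channels, window, window), ready to weight a prediction batch:
#   w = _spline_4d(window_size=512, power=2)
#   # w.shape == (1, 1, 512, 512), w.max() == 1.0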

def pad_image(image: ndarray, tile_size: int, subdivisions: int) -> Tuple[ndarray, List[int]]:
    """Adds borders to the given image for a "valid" border pattern according to "tile_size"
    and "subdivisions". The image is expected as a numpy array with shape
    (channels, height, width).

    Args:
        image (np.ndarray): input image, 3D channels-first array
        tile_size (int): size of a single patch, useful to compute padding
        subdivisions (int): amount of overlap, useful for padding

    Returns:
        Tuple[np.ndarray, List[int]]: the same image, reflect-padded on every side, plus
        the paddings applied as [top, bottom, left, right]
    """
    step = tile_size // subdivisions
    _, in_h, in_w = image.shape
    # pad the spatial dims up to the next multiple of the step
    pad_h = (step - in_h % step) % step
    pad_w = (step - in_w % step) % step
    pad_h_t = pad_h // 2
    pad_h_b = (pad_h // 2) + (pad_h % 2)
    pad_w_l = pad_w // 2
    pad_w_r = (pad_w // 2) + (pad_w % 2)
    # extra border so that border pixels are also covered by overlapping tiles
    pad = int(round(tile_size * (1 - 1.0 / subdivisions)))
    image = np.pad(
        image,
        ((0, 0), (pad + pad_h_t, pad + pad_h_b), (pad + pad_w_l, pad + pad_w_r)),
        mode="reflect",
    )
    return image, [pad + pad_h_t, pad + pad_h_b, pad + pad_w_l, pad + pad_w_r]

def unpad_image(padded_image: ndarray, pads: List[int]) -> ndarray:
    """Reverts the changes made by 'pad_image', removing the same paddings.

    Args:
        padded_image (np.ndarray): image with padding still applied
        pads (List[int]): paddings to remove, as [top, bottom, left, right]

    Returns:
        np.ndarray: image without padding (2D or 3D, matching the input)
    """
    pad_top, pad_bottom, pad_left, pad_right = pads
    # crop the image top, bottom, left and right
    n_dims = len(padded_image.shape)
    if n_dims == 2:
        result = padded_image[pad_top:-pad_bottom, pad_left:-pad_right]
    elif n_dims == 3:
        result = padded_image[:, pad_top:-pad_bottom, pad_left:-pad_right]
    else:
        raise ValueError(
            f"padded_image has {n_dims} dimensions, expected 2 or 3.")
    return result
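
# Usage sketch (illustrative): pad_image and unpad_image are inverse
# operations when given the same paddings:
#   img = np.random.rand(1, 1024, 1024).astype(np.float32)
#   padded, pads = pad_image(img, tile_size=512, subdivisions=2)
#   restored = unpad_image(padded, pads)
#   # restored.shape == img.shape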

def windowed_generator(
    padded_image: ndarray, window_size: int, subdivisions: int, batch_size: int = None
):
    """Generator that yields tiles grouped by batch size.

    Args:
        padded_image (np.ndarray): input image to be processed (already padded), channels-first
        window_size (int): size of a single patch
        subdivisions (int): subdivision count on each patch to compute the step
        batch_size (int, optional): amount of patches in each batch. Defaults to None.

    Yields:
        Tuple[List[tuple], np.ndarray]: list of coordinates and respective patches as a single batch array
    """
    step = window_size // subdivisions
    channels, height, width = padded_image.shape
    batch_size = batch_size or 1
    batch = []
    coords = []
    # slide the window over the image with the given step
    for x in range(0, height - window_size + 1, step):
        for y in range(0, width - window_size + 1, step):
            coords.append((x, y))
            # extract the tile, keeping channels first for the batch
            tile = padded_image[:, x: x + window_size, y: y + window_size]
            batch.append(tile)
            # yield the batch once full, then start fresh lists
            # (the yielded ones may still be referenced by the consumer)
            if len(batch) == batch_size:
                yield coords, np.stack(batch)
                coords = []
                batch = []
    # handle the last (possibly incomplete) batch
    if len(batch) > 0:
        yield coords, np.stack(batch)
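
# Usage sketch (illustrative): iterate 512x512 tiles over an already padded
# channels-first image:
#   for coords, batch in windowed_generator(padded, 512, 2, batch_size=4):
#       ...  # batch has shape (<=4, channels, 512, 512)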

def reconstruct(
    canvas: ndarray, tile_size: int, coords: List[tuple], predictions: ndarray
) -> ndarray:
    """Helper function that iterates the result batch onto the given canvas to reconstruct
    the final result, batch after batch.

    Args:
        canvas (np.ndarray): container for the final image
        tile_size (int): size of a single patch
        coords (List[tuple]): list of pixel coordinates corresponding to the batch items
        predictions (np.ndarray): array of patch predictions, shape (batch, num_classes, tile_size, tile_size)

    Returns:
        np.ndarray: the updated canvas, shape (num_classes, padded_h, padded_w)
    """
    n_dims = len(canvas.shape)
    for (x, y), patch in zip(coords, predictions):
        # accumulate each weighted patch at its original position
        if n_dims == 2:
            canvas[x: x + tile_size, y: y + tile_size] += patch
        elif n_dims == 3:
            canvas[:, x: x + tile_size, y: y + tile_size] += patch
        else:
            raise ValueError(
                f"Canvas has {n_dims} dimensions, expected 2 or 3.")
    return canvas
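
# Example (illustrative): reconstruct accumulates two overlapping tiles onto
# an empty canvas:
#   canvas = np.zeros((1, 1024, 1024))
#   preds = np.ones((2, 1, 512, 512))
#   canvas = reconstruct(canvas, 512, [(0, 0), (256, 256)], preds)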

def predict_smooth_windowing(
    image: ndarray,
    tile_size: int,
    subdivisions: int,
    prediction_fn: Callable,
    batch_size: int = 1,
    out_dim: int = 1,
) -> np.ndarray:
    """Predicts a large image in one go, dividing it into squared, fixed-size tiles and
    smoothly interpolating over them to produce a single, coherent output with the same
    spatial dimensions.

    Args:
        image (np.ndarray): input image, 3D channels-first array
        tile_size (int): size of each squared tile
        subdivisions (int): number of subdivisions over the single tile for overlaps
        prediction_fn (Callable): callback that takes the input batch and returns an output batch
        batch_size (int, optional): size of each batch. Defaults to 1.
        out_dim (int, optional): number of output channels. Defaults to 1.

    Returns:
        np.ndarray: array with dimensions (out_dim, h, w), containing smooth predictions
    """
    img, pads = pad_image(image=image, tile_size=tile_size,
                          subdivisions=subdivisions)
    spline = _spline_4d(window_size=tile_size, power=2)
    canvas = np.zeros((out_dim, img.shape[1], img.shape[2]))
    loop = tqdm(windowed_generator(
        padded_image=img,
        window_size=tile_size,
        subdivisions=subdivisions,
        batch_size=batch_size,
    ))
    for coords, batch in loop:
        pred_batch = prediction_fn(batch)
        # weight each prediction with the spline window before accumulating
        pred_batch = pred_batch * spline
        canvas = reconstruct(
            canvas, tile_size=tile_size, coords=coords, predictions=pred_batch
        )
    prediction = unpad_image(canvas, pads=pads)
    return prediction
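
# Usage sketch (illustrative): run a dummy identity "model" over a large
# image; the output keeps the input's spatial size:
#   img = np.random.rand(1, 1500, 2000).astype(np.float32)
#   out = predict_smooth_windowing(img, tile_size=512, subdivisions=2,
#                                  prediction_fn=lambda batch: batch,
#                                  batch_size=2)
#   # out.shape == (1, 1500, 2000)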

def center_pad(x, padding, div_factor=32, mode="reflect"):
    """Center-pads a (batch, channels, height, width) array with the same padding on all
    sides, so that each spatial dimension is at least its size + 2 * padding and divisible
    by div_factor."""
    size_x = x.shape[3]
    size_y = x.shape[2]
    # minimum size including the requested border
    min_size_x = size_x + 2 * padding
    min_size_y = size_y + 2 * padding
    # round up to the next multiple of div_factor
    new_size_x = int(ceil(min_size_x / div_factor) * div_factor)
    new_size_y = int(ceil(min_size_y / div_factor) * div_factor)
    # split the total padding between the two sides of each dimension
    pad_x = new_size_x - size_x
    pad_y = new_size_y - size_y
    pad_left = int(floor(pad_x / 2))
    pad_right = int(ceil(pad_x / 2))
    pad_top = int(floor(pad_y / 2))
    pad_bottom = int(ceil(pad_y / 2))
    if pad_x > size_x or pad_y > size_y:
        # reflect padding cannot exceed the current size, so first double the
        # image by padding each dimension with half of its own size...
        padded = np.pad(
            x,
            (
                (0, 0),
                (0, 0),
                (int(floor(size_y / 2)), int(ceil(size_y / 2))),
                (int(floor(size_x / 2)), int(ceil(size_x / 2))),
            ),
            mode=mode,
        )
        # ...then pad the enlarged image up to the target size
        # (clamped at zero in case the doubling already overshot the target)
        rem_y = max(new_size_y - padded.shape[2], 0)
        rem_x = max(new_size_x - padded.shape[3], 0)
        padded = np.pad(
            padded,
            (
                (0, 0),
                (0, 0),
                (int(floor(rem_y / 2)), int(ceil(rem_y / 2))),
                (int(floor(rem_x / 2)), int(ceil(rem_x / 2))),
            ),
            mode=mode,
        )
    else:
        padded = np.pad(
            x,
            (
                (0, 0),
                (0, 0),
                (pad_top, pad_bottom),
                (pad_left, pad_right),
            ),
            mode=mode,
        )
    paddings = (pad_top, pad_bottom, pad_left, pad_right)
    return padded, paddings
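
# Example (illustrative): center_pad makes a (1, 1, 500, 600) array's spatial
# dims multiples of 32, with at least a 16-pixel border:
#   y, pads = center_pad(np.zeros((1, 1, 500, 600)), padding=16)
#   # y.shape == (1, 1, 544, 640), pads == (22, 22, 20, 20)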

class Model:
    def __init__(self):
        path = "assets/models/model.onnx"
        self.model = ort.InferenceSession(path)
        self.size = 512
        self.subdivisions = 2
        self.batch_size = 2
        self.out_dim = 1

    def forward(self, x):
        assert x.ndim == 3, "Expected 3D tensor"
        # normalize to [0, 1] and cast to fp32
        x = x / 255
        x = x.astype(np.float32)
        pred = predict_smooth_windowing(
            image=x,
            tile_size=self.size,
            subdivisions=self.subdivisions,
            prediction_fn=self.callback,
            batch_size=self.batch_size,
            out_dim=self.out_dim
        )
        # threshold the accumulated logits to get a binary mask
        pred = pred > 0
        return pred

    def callback(self, x: ndarray) -> ndarray:
        # run ONNX inference on a single batch
        out = self.model.run(None, {"input": x})[0]
        return out
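
# Usage sketch (illustrative, requires assets/models/model.onnx):
#   model = Model()
#   sar = np.random.randint(0, 256, (1, 1024, 1024)).astype(np.uint8)
#   mask = model.forward(sar)  # boolean mask, shape (1, 1024, 1024)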

def infer(image):
    print("Running inference")
    model = Model()
    # keep a single channel from the RGB input
    image = np.array(image)[:, :, 0]
    # add the channel dimension expected by the model
    image = image[None, :, :]
    output_image = model.forward(image)
    output_image = output_image[0]
    # map the binary mask to a black and white RGB image
    output_image_color = np.zeros((output_image.shape[0], output_image.shape[1], 3))
    output_image_color[output_image == 0] = [0, 0, 0]
    output_image_color[output_image == 1] = [255, 255, 255]
    output_image = Image.fromarray(output_image_color.astype(np.uint8))
    return output_image

sample_images = [
    "assets/data/sample1.png",
    "assets/data/sample2.png"
]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Oil Spill Detection Demo")
    gr.Markdown(
        "This app allows you to detect oil spills in Synthetic Aperture Radar (SAR) images. "
        "Upload a SAR image or use the sample images provided below to detect oil spills."
    )
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="pil")
        output_image = gr.Image(label="Model Output", type="pil")
    submit_button = gr.Button("Run Inference")
    examples = gr.Examples(
        examples=[[img] for img in sample_images],
        inputs=[input_image]
    )
    submit_button.click(fn=infer, inputs=input_image, outputs=output_image)

demo.launch()