|
import cv2 |
|
import numpy as np |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
import torch |
|
from pathlib import Path |
|
import torch.nn.functional as F |
|
from typing import Dict, Any, List, Union, Tuple |
|
from torchvision.transforms.functional import normalize |
|
|
|
# Model input resolution as [height, width]. preprocess_input resizes every
# image to this size, and keep_large_components uses its pixel count
# (height * width) as the reference resolution for its area threshold.
INPUT_SIZE = [1200, 1800]
|
|
|
def keep_large_components(
    a: np.ndarray,
    *,
    threshold: int = 25,
    min_area: float = 50000,
    dilate_size: int = 9,
    dilate_iterations: int = 2,
) -> np.ndarray:
    """Remove small connected components from a mask, keeping only large regions.

    Pixels with value > ``threshold`` are grouped into 4-connected components.
    Components whose pixel area exceeds ``min_area`` (given at INPUT_SIZE
    resolution and rescaled to this mask's size) are kept. The union of the kept
    components is dilated and used to zero out everything else in ``a``;
    original pixel values inside the kept region are preserved.

    If no component exceeds the area limit, ``a`` is returned unfiltered (only
    reshaped to (H, W, 1)).

    Args:
        a: uint8 mask of shape (H, W) or (H, W, 1).
        threshold: Pixels with value strictly greater than this are foreground.
        min_area: Minimum component area in pixels at INPUT_SIZE resolution.
        dilate_size: Diameter of the elliptical dilation kernel.
        dilate_iterations: Number of dilation passes applied to the kept region.

    Returns:
        Filtered uint8 mask of shape (H, W, 1).
    """
    h, w = a.shape[:2]
    # Work on a 2-D view so every OpenCV call sees a plain single-channel image
    # and the result is reshaped to (H, W, 1) on every path.
    a2d = a.reshape(h, w)

    binary = (a2d > threshold).astype(np.uint8) * 255
    # Label 0 is the background; stats[i, CC_STAT_AREA] is the pixel count of label i.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, 4, cv2.CV_32S)

    # min_area is specified at INPUT_SIZE resolution; scale it to this mask's pixel count.
    area_limit = min_area * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])
    keep = stats[:, cv2.CC_STAT_AREA] > area_limit
    keep[0] = False  # never keep the background label

    if keep.any():
        # Lookup-table indexing builds the union of all kept components in one
        # pass instead of one full-image comparison per component.
        region = keep[labels].astype(np.uint8) * 255
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size))
        # Grow the kept region so pixels at or below the threshold that border a
        # large component are not cut off by the masking below.
        region = cv2.dilate(region, kernel, iterations=dilate_iterations)
        a2d = cv2.bitwise_and(a2d, region)

    return a2d.reshape(h, w, 1)
|
|
|
def read_img(img: Union[str, Path], timeout: float = 30.0) -> np.ndarray:
    """Read an image from an HTTP(S) URL or a local path.

    Args:
        img: ``http://`` / ``https://`` URL, or a local file path (str or Path).
        timeout: Seconds to wait for the HTTP server; only used for URLs.

    Returns:
        uint8 image as a numpy array in RGB order with shape (H, W, 3).

    Raises:
        requests.HTTPError: if the URL request returns an error status.
        ValueError: if a local file is missing or cannot be decoded by OpenCV.
    """
    # Path objects cannot be sliced/compared like strings, so normalise first.
    src = str(img)

    # Require a real scheme so local paths that merely start with "http"
    # (e.g. "http_images/x.png") are not sent to requests.
    if src.lower().startswith(("http://", "https://")):
        # Without a timeout a stalled server would block the caller forever.
        response = requests.get(src, timeout=timeout)
        response.raise_for_status()
        # convert("RGB") maps grayscale, palette and RGBA images to the same
        # 3-channel RGB layout the local-file branch returns.
        return np.asarray(Image.open(BytesIO(response.content)).convert("RGB"))

    im = cv2.imread(src)  # BGR; returns None instead of raising on failure
    if im is None:
        raise ValueError(f"Could not read image file (missing or unreadable): {src}")
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
|
|
def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Preprocess an image for model input.

    The image is resized (bilinear) to INPUT_SIZE, truncated back to uint8,
    scaled to [0, 1] and shifted by -0.5 per channel (mean 0.5, std 1.0).

    Args:
        im: Image as a numpy array of shape (H, W) or (H, W, C). Single-channel
            input (C == 1, or C == 2 with alpha) is replicated to 3 channels;
            channels beyond the third (e.g. alpha) are dropped.

    Returns:
        float32 tensor of shape (1, 3, INPUT_SIZE[0], INPUT_SIZE[1]); moved to
        the GPU when CUDA is available.
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] in (1, 2):
        # Grayscale (optionally + alpha): normalize() below uses 3 per-channel
        # means, so the tensor must have exactly 3 channels.
        im = np.repeat(im[:, :, :1], 3, axis=2)
    im = im[:, :, :3]  # drop alpha / extra channels

    chw = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.upsample is deprecated; interpolate with align_corners=False is the
    # same computation without the deprecation warning.
    resized = F.interpolate(chw.unsqueeze(0), INPUT_SIZE, mode="bilinear", align_corners=False)
    # The truncating uint8 cast after resizing is kept unchanged; presumably it
    # mirrors the preprocessing the model was trained with - TODO confirm.
    resized = resized.type(torch.uint8)
    image = torch.divide(resized, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    if torch.cuda.is_available():
        image = image.cuda()

    return image
|
|
|
def postprocess_output(result: np.ndarray, orig_im_shape: Tuple[int, int]) -> np.ndarray:
    """Postprocess a single-channel ONNX model prediction into an alpha mask.

    The prediction is resized (bilinear) to the original image size, min-max
    scaled to [0, 255], cast to uint8 and filtered with keep_large_components.

    Args:
        result: Single-channel prediction of shape (H, W), (1, H, W) or
            (1, 1, H, W).
        orig_im_shape: Original image dimensions (height, width).

    Returns:
        uint8 mask as a numpy array of shape (height, width, 1).
    """
    pred = torch.from_numpy(np.ascontiguousarray(result)).float()
    # interpolate needs a 4-D (N, C, H, W) tensor; collapse any leading
    # singleton dimensions so all accepted input ranks work.
    pred = pred.reshape(1, 1, pred.shape[-2], pred.shape[-1])
    pred = F.interpolate(pred, size=tuple(orig_im_shape), mode="bilinear", align_corners=False)
    pred = pred[0]  # (1, height, width)

    lo = torch.min(pred)
    hi = torch.max(pred)
    span = hi - lo
    if span > 0:
        pred = (pred - lo) / span
    else:
        # Constant prediction: min-max scaling would divide by zero and produce
        # NaNs, whose uint8 cast is undefined. Emit an all-zero mask instead.
        pred = torch.zeros_like(pred)

    a = (pred * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    return keep_large_components(a)
|
|
|
def process_image(src: Union[str, Path], ort_session: Any, model_path: Union[str, Path], outname: str) -> None:
    """Segment an image with an ONNX model and save it with the mask as alpha.

    Args:
        src: Source image URL or path (anything read_img accepts).
        ort_session: ONNX Runtime inference session; this function uses its
            ``get_inputs()`` and ``run()`` methods.
        model_path: Unused; kept for call-site compatibility.
        outname: Output filename. Use a format that stores an alpha channel
            (e.g. ``.png``) to keep the mask.

    Returns:
        None

    Raises:
        OSError: if OpenCV reports that writing ``outname`` failed.
    """
    image_orig = read_img(src)
    image = preprocess_input(image_orig)

    # ONNX Runtime consumes host numpy arrays, but preprocess_input moves the
    # tensor to the GPU when CUDA is available, so copy it back first.
    inputs: Dict[str, Any] = {ort_session.get_inputs()[0].name: image.cpu().numpy()}

    # First model output, first item of the batch.
    result = ort_session.run(None, inputs)[0][0]
    alpha = postprocess_output(result, (image_orig.shape[0], image_orig.shape[1]))

    rgb = image_orig
    if rgb.ndim == 2:
        # Grayscale input: replicate to 3 channels so the colour conversion works.
        rgb = np.repeat(rgb[:, :, np.newaxis], 3, axis=2)
    # image_orig is RGB but cv2.imwrite expects BGR(A); any alpha channel is
    # replaced by the predicted mask.
    bgr = cv2.cvtColor(np.ascontiguousarray(rgb[:, :, :3]), cv2.COLOR_RGB2BGR)
    img_w_alpha = np.dstack((bgr, alpha))

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(str(outname), img_w_alpha):
        raise OSError(f"Failed to write image: {outname}")
    print(f"Saved: {outname}")