# ty-bg-remover-test / utils.py
# haruntrkmn's picture
# Upload 2 files
# 5c7578b verified
import cv2
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import torch
from pathlib import Path
import torch.nn.functional as F
from typing import Dict, Any, List, Union, Tuple
from torchvision.transforms.functional import normalize
# Target resolution passed as the output size of the resize in
# preprocess_input, i.e. (height, width). Its area is also the reference
# against which keep_large_components scales its minimum-component threshold.
INPUT_SIZE = [1200, 1800]
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Suppress small connected blobs in an alpha mask, keeping the large ones.

    Pixels with a value above 25 count as foreground when components are
    labelled (4-connectivity). A component is "large" when its area exceeds
    50000 pixels, scaled by the ratio of this mask's area to the
    ``INPUT_SIZE`` area. If no component is large enough, the mask values are
    left untouched.

    Args:
        a: Input mask as numpy array of shape (H,W) or (H,W,1)
    Returns:
        Mask with everything outside the (dilated) large components zeroed,
        shape (H,W,1)
    """
    height, width = a.shape[:2]
    # Scale the area threshold so it is relative to the model input size.
    min_area = 50000 * (height * width) / (INPUT_SIZE[1] * INPUT_SIZE[0])

    binary = (a > 25).astype(np.uint8) * 255
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        binary, 4, cv2.CV_32S
    )

    # Label 0 is skipped (OpenCV assigns it to the background).
    large_labels = [
        label
        for label in range(1, num_labels)
        if stats[label, cv2.CC_STAT_AREA] > min_area
    ]

    if large_labels:
        # Union of all components that are large enough.
        keep_mask = np.zeros(labels.shape, dtype=np.uint8)
        for label in large_labels:
            keep_mask[labels == label] = 255

        # Grow the kept region so soft edges around it survive the masking.
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        keep_mask = cv2.dilate(keep_mask, ellipse, iterations=2)

        # Zero out everything outside the kept region.
        a = cv2.bitwise_and(a, keep_mask)

    return a.reshape((a.shape[0], a.shape[1], 1))
def read_img(img: Union[str, Path]) -> np.ndarray:
    """Read an image from a URL or local path.

    Args:
        img: ``http://`` / ``https://`` URL, or a local file path given as
            ``str`` or ``Path``.
    Returns:
        Image as numpy array in RGB format with shape (H,W,3)
    Raises:
        requests.HTTPError: The URL answered with an error status.
        FileNotFoundError: The local path does not point to an existing file.
        ValueError: The local file exists but could not be decoded.
    """
    # Normalise to str so Path inputs work (Path objects cannot be sliced).
    src = str(img)

    # Require a real scheme so local paths that merely start with "http"
    # (e.g. "http_cache/a.png") are not mistaken for URLs.
    if src.lower().startswith(("http://", "https://")):
        # Without a timeout a stalled server would block forever.
        response = requests.get(src, timeout=30)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as pil_im:
            # Force 3-channel RGB so RGBA / grayscale / palette images match
            # the documented return shape.
            return np.array(pil_im.convert("RGB"))

    if not Path(src).is_file():
        raise FileNotFoundError(f"Image file not found: {src}")

    # cv2.imread signals failure by returning None instead of raising.
    im = cv2.imread(src)
    if im is None:
        raise ValueError(f"Could not decode image file: {src}")
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
def preprocess_input(
    im: np.ndarray,
    size: Union[List[int], Tuple[int, int], None] = None,
) -> torch.Tensor:
    """Preprocess image for model input.

    The image is reduced/expanded to 3 channels, resized with bilinear
    interpolation, scaled to [0, 1] and shifted by -0.5.

    Args:
        im: Input image as numpy array of shape (H,W), (H,W,1), (H,W,3) or
            (H,W,4). Grayscale is replicated to 3 channels; a 4th (alpha)
            channel is dropped.
        size: Target (height, width) of the output. Defaults to INPUT_SIZE.
    Returns:
        Preprocessed image as torch tensor of shape (1,3,H,W) with values in
        [-0.5, 0.5]. The tensor is moved to the GPU when CUDA is available.
    """
    target_size = list(INPUT_SIZE if size is None else size)

    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        # Replicate grayscale so the output really has 3 channels.
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        # Drop the alpha channel.
        im = im[:, :, :3]

    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    # F.interpolate replaces the deprecated F.upsample. The cast back to
    # uint8 truncates the interpolated values to integer pixel levels.
    im_tensor = F.interpolate(im_tensor, size=target_size, mode="bilinear").type(torch.uint8)

    # Scale to [0, 1], then subtract 0.5. This equals normalising every
    # channel with mean 0.5 and std 1.0.
    image = im_tensor / 255.0 - 0.5

    if torch.cuda.is_available():
        image = image.cuda()
    return image
def postprocess_output(result: np.ndarray, orig_im_shape: Tuple[int, int]) -> np.ndarray:
    """Postprocess ONNX model output.

    The prediction is resized to the original image size, min-max normalised
    to [0, 1], converted to an 8-bit alpha mask and cleaned with
    keep_large_components.

    Args:
        result: Model output for one image as numpy array of shape (1,H,W).
            A batch axis is added internally before the 2-D resize.
        orig_im_shape: Original image dimensions (height, width)
    Returns:
        Alpha mask as uint8 numpy array of shape (H,W,1); 255 means
        foreground, 0 means background.
    """
    prediction = torch.from_numpy(result).unsqueeze(0)
    # F.interpolate replaces the deprecated F.upsample.
    prediction = F.interpolate(prediction, size=orig_im_shape, mode="bilinear")
    prediction = torch.squeeze(prediction, 0)

    lowest = torch.min(prediction)
    span = torch.max(prediction) - lowest
    if span > 0:
        prediction = (prediction - lowest) / span
    else:
        # A constant prediction would divide by zero and yield NaNs;
        # fall back to an all-background mask instead.
        prediction = torch.zeros_like(prediction)

    # a is alpha channel. 255 means foreground, 0 means background.
    a = (prediction * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    # Remove small spurious regions.
    return keep_large_components(a)
def process_image(src: Union[str, Path], ort_session: Any, model_path: Union[str, Path], outname: str) -> None:
    """Process an image through ONNX model to generate alpha mask and save result.

    Args:
        src: Source image URL or path
        ort_session: ONNX runtime inference session
        model_path: Path to ONNX model file. Not used by this function; kept
            so existing callers keep working.
        outname: Output filename for saving result. The result has four
            channels, so use a format that can store alpha (e.g. ``.png``).
    Returns:
        None
    Raises:
        IOError: The output image could not be written.
    """
    # Load and preprocess image
    image_orig = read_img(src)
    image = preprocess_input(image_orig)

    # Prepare ONNX input. The tensor may live on the GPU after preprocessing,
    # and .numpy() only works on CPU tensors.
    input_name = ort_session.get_inputs()[0].name
    inputs: Dict[str, Any] = {input_name: image.cpu().numpy()}

    # Get ONNX output (first output, first batch item) and post-process
    result = ort_session.run(None, inputs)[0][0]
    alpha = postprocess_output(result, (image_orig.shape[0], image_orig.shape[1]))

    # image_orig is RGB; cv2.imwrite expects BGR(A) channel order.
    image_bgr = cv2.cvtColor(image_orig, cv2.COLOR_RGB2BGR)
    img_w_alpha = np.dstack((image_bgr, alpha))

    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(outname, img_w_alpha):
        raise IOError(f"Failed to write image: {outname}")
    print(f"Saved: {outname}")