detecting_dress / app.py
gaur3009's picture
Update app.py
6e5e70e verified
raw
history blame
5.77 kB
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Load U²-Net once at import time (CPU weights) and switch it to inference mode.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
# Strip the "module." prefix from checkpoint keys so they match a bare U2NET.
# NOTE(review): presumably the checkpoint was saved from a DataParallel wrapper — confirm.
state_dict = {
    name.replace('module.', ''): tensor
    for name, tensor in torch.load(model_path, map_location=torch.device('cpu')).items()
}
model.load_state_dict(state_dict)
model.eval()
def refine_mask(mask):
    """Clean up a mask with morphology, then soften its edges.

    Steps, in order:
      1. morphological close with a 5x5 kernel (fills small holes),
      2. one erosion pass with a 3x3 kernel (trims thin protrusions),
      3. a second 5x5 close (re-smooths the eroded outline),
      4. a 5x5 Gaussian blur with sigma 1.5 (feathers the boundary).

    Callers in this file pass uint8 masks with 0/255 values. The returned mask
    is blurred, so edge pixels take intermediate values.
    """
    square5 = np.ones((5, 5), dtype=np.uint8)
    square3 = np.ones((3, 3), dtype=np.uint8)

    cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, square5)
    cleaned = cv2.erode(cleaned, square3, iterations=1)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, square5)
    return cv2.GaussianBlur(cleaned, (5, 5), 1.5)
def segment_dress(image_np):
    """Predict a dress mask for ``image_np`` using the module-level U²-Net ``model``.

    Steps:
      1. Convert the image to RGB and resize it to 320x320.
      2. Run the network and take its first output map.
      3. Min-max normalise the map to [0, 1].
      4. Mark pixels above ``mean + 0.2`` as foreground (255).
      5. Resize the binary mask back to the input's height/width with
         nearest-neighbour interpolation, so edges stay hard.
      6. Clean the mask with ``refine_mask``.

    NOTE(review): the map is normalised to [0, 1] and the comparison is strict.
    A map whose mean is >= 0.8 therefore gives a threshold >= 1.0 and an
    all-zero mask.
    """
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((320, 320))])
    rgb = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(rgb).unsqueeze(0)

    with torch.no_grad():
        saliency = model(batch)[0][0].squeeze().cpu().numpy()

    lo, hi = saliency.min(), saliency.max()
    saliency = (saliency - lo) / (hi - lo + 1e-8)

    cutoff = saliency.mean() + 0.2  # offset above the mean keeps the mask tight
    binary = np.where(saliency > cutoff, 255, 0).astype(np.uint8)

    height, width = image_np.shape[:2]
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
    return refine_mask(binary)
def apply_grabcut(image_np, dress_mask, iterations=3):
    """Refine ``dress_mask`` with OpenCV GrabCut, then clean it with ``refine_mask``.

    Seeding:
      * Pixels where ``dress_mask > 0`` start as "probable foreground".
      * Every other pixel is marked "definite background".

    GrabCut can therefore only drop pixels from the seed mask, never add new ones.

    Args:
        image_np: Image handed to ``cv2.grabCut``. OpenCV expects an 8-bit,
            3-channel array here.
        dress_mask: Single-channel mask with the same height/width as ``image_np``.
        iterations: Number of GrabCut iterations (default 3).

    Returns:
        A uint8 mask (0/255 before ``refine_mask`` smoothing). If the seed is
        empty, the result is all zeros. If the seed covers the whole image,
        GrabCut is skipped and the seed is kept as-is.
    """
    is_seed = dress_mask > 0
    gc_mask = np.where(is_seed, cv2.GC_PR_FGD, cv2.GC_BGD).astype(np.uint8)

    # Mask-initialised GrabCut needs samples of both classes to fit its colour
    # models. With an empty or full seed it would fail instead of refining.
    if is_seed.any() and not is_seed.all():
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        # The rect argument is only used with GC_INIT_WITH_RECT, so pass None.
        cv2.grabCut(image_np, gc_mask, None, bgd_model, fgd_model,
                    iterations, cv2.GC_INIT_WITH_MASK)

    keep = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    refined_mask = np.where(keep, 255, 0).astype(np.uint8)
    return refine_mask(refined_mask)
def recolor_dress(image_np, dress_mask, target_color):
    """Shift the chroma of the masked region toward ``target_color`` and blend it in.

    Args:
        image_np: RGB image array (uint8 assumed — TODO confirm for non-Gradio callers).
        dress_mask: Mask with the image's height/width. Its 0-255 values act as
            per-pixel weights.
        target_color: BGR triple.

    How it works:
      * Only the LAB a/b channels move; L is unchanged.
      * Each channel shifts by (target a/b) minus (mean a/b over pixels where
        ``dress_mask > 0``), weighted by ``dress_mask / 255``.
      * The recolored image is alpha-blended over the original. The alpha is a
        21x21 Gaussian-feathered mask multiplied by ``(L / 255) ** 0.7``, so
        darker pixels receive less of the new colour.

    Returns ``image_np`` itself when the mask has no nonzero pixels.
    """
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)

    garment = lab[dress_mask > 0]
    if garment.shape[0] == 0:
        return image_np

    _, avg_a, avg_b = garment.mean(axis=0)
    weight = dress_mask / 255.0
    for channel, target_value, avg_value in ((1, target_lab[1], avg_a),
                                             (2, target_lab[2], avg_b)):
        shift = target_value - avg_value
        lab[..., channel] = np.clip(lab[..., channel] + weight * shift, 0, 255)

    recolored = cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2RGB)

    soft_edge = cv2.GaussianBlur(dress_mask, (21, 21), 7)
    luminance = (lab[..., 0] / 255.0) ** 0.7
    alpha = (soft_edge * luminance).astype(np.uint8)[..., None] / 255

    blended = image_np * (1 - alpha) + recolored * alpha
    return blended.astype(np.uint8)
def change_dress_color(img, color):
    """Recolor the dress in ``img`` to the named ``color`` (Gradio click handler).

    Args:
        img: Image from the ``gr.Image(type="pil")`` input, or None when nothing
            was uploaded.
        color: Colour name; names missing from ``color_map`` fall back to red.

    Returns:
        None if ``img`` is None.
        The original ``img`` if the detected mask covers fewer than
        ``min_mask_area`` pixels, or if any processing step raises.
        Otherwise a new PIL image with the dress recolored.
    """
    if img is None:
        return None
    # BGR triples: recolor_dress converts the target with COLOR_BGR2LAB.
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0),
        "Yellow": (0, 255, 255), "Purple": (128, 0, 128), "Orange": (0, 165, 255),
        "Cyan": (255, 255, 0), "Magenta": (255, 0, 255), "White": (255, 255, 255),
        "Black": (0, 0, 0)
    }
    new_color_bgr = color_map.get(color, (0, 0, 255))
    # Minimum number of mask pixels required before attempting a recolor.
    min_mask_area = 1000
    try:
        # RGBA, palette or grayscale uploads would give 4-, 1- or 2-channel
        # arrays. GrabCut and the RGB<->LAB conversions reject those, so
        # normalise to 3-channel RGB first.
        rgb = img.convert("RGB") if isinstance(img, Image.Image) else img
        img_np = np.array(rgb)
        dress_mask = segment_dress(img_np)
        # Count mask pixels rather than summing their 0-255 intensities.
        # A sum threshold of 1000 would correspond to only ~4 full pixels.
        if np.count_nonzero(dress_mask) < min_mask_area:
            return img
        dress_mask = apply_grabcut(img_np, dress_mask)
        img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
        return Image.fromarray(img_recolored)
    except Exception as e:
        # Top-level UI boundary: report and show the unmodified input.
        print(f"Error processing image: {str(e)}")
        return img
# Gradio Interface
# Two-column layout: the upload, colour dropdown and button sit on the left;
# the recolored result is on the right. Component creation order defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# AI Dress Color Changer")
    gr.Markdown("Upload a dress image and select a new color for realistic recoloring")
    with gr.Row():
        with gr.Column():
            # type="pil" makes change_dress_color receive a PIL image (or None).
            input_image = gr.Image(type="pil", label="Input Image")
            # Keep these names in sync with color_map in change_dress_color;
            # unknown names fall back to red there.
            color_choice = gr.Dropdown(
                choices=["Red", "Blue", "Green", "Yellow", "Purple",
                         "Orange", "Cyan", "Magenta", "White", "Black"],
                value="Red",
                label="Select New Color"
            )
            process_btn = gr.Button("Recolor Dress")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Result")
    # Clicking the button calls change_dress_color(image, colour name) and
    # shows the returned image in the "Result" component.
    process_btn.click(
        fn=change_dress_color,
        inputs=[input_image, color_choice],
        outputs=output_image
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()