detecting_dress / app.py
gaur3009's picture
Update app.py
71096d7 verified
raw
history blame
5.04 kB
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Load the U²-Net weights once at import time; everything below runs on the CPU.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
# NOTE: torch.load unpickles the checkpoint -- only point model_path at trusted files.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints saved through torch.nn.DataParallel prefix every key with
# "module.". Strip only that leading prefix: str.replace would also mangle any
# key that merely contains "module." somewhere in its middle.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before serving requests
def refine_mask(mask):
    """Clean up a segmentation mask.

    First a morphological close with a 7x7 square kernel fills small holes
    and gaps inside the mask. Then a 7x7 Gaussian blur (sigma 1.5) softens
    its edges.

    Args:
        mask: single-channel mask array (presumably uint8 0-255 -- callers
            in this file pass uint8 masks).

    Returns:
        The closed and blurred mask, with the same shape as ``mask``.
    """
    closing_kernel = np.ones((7, 7), dtype=np.uint8)
    holes_filled = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, closing_kernel)
    return cv2.GaussianBlur(holes_filled, (7, 7), 1.5)
def segment_dress(image_np):
    """Predict a dress mask for an image array with the global U²-Net ``model``.

    Steps:
    1. Convert the image to RGB and resize it to 320x320 for the network.
    2. Take the first output map (presumably the fused side output -- confirm
       against U2NET).
    3. Min-max normalise the map to [0, 1] and threshold it at 0.5.
    4. Resize the resulting 0/255 mask back to the input's width and height.
    5. Smooth it with ``refine_mask``.

    Args:
        image_np: image array accepted by ``PIL.Image.fromarray``.

    Returns:
        A uint8 mask with the same height and width as ``image_np``. When no
        pixel passes the threshold, the mask is all zeros; None is never
        returned.
    """
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((320, 320))])
    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)
    with torch.no_grad():
        prediction = model(batch)[0][0].squeeze().cpu().numpy()
    lowest, highest = prediction.min(), prediction.max()
    prediction = (prediction - lowest) / (highest - lowest + 1e-8)  # rescale to [0, 1]
    binary = (prediction > 0.5).astype(np.uint8) * 255
    height, width = image_np.shape[:2]
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_LINEAR)
    return refine_mask(binary)
def apply_grabcut(image_np, dress_mask):
    """Tighten a dress mask with OpenCV GrabCut to avoid colour bleeding.

    Pixels where ``dress_mask > 0`` seed GrabCut as "probable foreground".
    Every other pixel is fixed as "definite background". As a result, GrabCut
    can only drop pixels from the mask (e.g. background that bled into it);
    it never adds new ones.

    Args:
        image_np: image array; GrabCut requires 8-bit, 3-channel input.
        dress_mask: mask with the image's height and width; non-zero means dress.

    Returns:
        The GrabCut result as a 0/255 uint8 mask, smoothed by ``refine_mask``.
        GrabCut needs both foreground and background samples to initialise.
        If the mask is empty or covers the whole image, GrabCut is skipped and
        only ``refine_mask(dress_mask)`` is returned.
    """
    seed = dress_mask > 0
    if not seed.any() or seed.all():
        # GrabCut's GMM initialisation asserts when either sample set is empty.
        return refine_mask(dress_mask)

    mask = np.where(seed, cv2.GC_PR_FGD, cv2.GC_BGD).astype(np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    # With GC_INIT_WITH_MASK the rect argument is ignored, so pass None instead
    # of computing a bounding box that is never used.
    mask, bgd_model, fgd_model = cv2.grabCut(
        image_np, mask, None, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK
    )
    refined_mask = np.where(
        (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0
    ).astype(np.uint8)
    return refine_mask(refined_mask)
def recolor_dress(image_np, dress_mask, target_color):
    """Shift the colour of the masked dress region toward ``target_color``.

    The work happens in Lab space:
    - Only the a/b (colour) channels move. The shift is the difference between
      the target's a/b and the mean a/b of the pixels where ``dress_mask > 0``.
    - Each pixel's shift is scaled by ``dress_mask / 255``.
    - L is left untouched, which keeps shading and fabric texture.

    The shifted image is then alpha-blended over the original. The alpha is a
    Gaussian-blurred copy of the mask multiplied by each pixel's L/255, so
    darker pixels keep more of their original colour.

    Args:
        image_np: RGB image array (assumed uint8, H x W x 3).
        dress_mask: mask with values in 0-255 and the image's height and width.
        target_color: colour triple in BGR order.

    Returns:
        A new uint8 RGB array. If the mask has no non-zero pixel, ``image_np``
        itself is returned.
    """
    swatch_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)
    target_a, target_b = swatch_lab[0][0][1], swatch_lab[0][0][2]

    lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    selected = dress_mask > 0
    if not selected.any():
        return image_np  # No dress detected

    mean_lab = lab[selected].mean(axis=0)
    shifts = {1: target_a - mean_lab[1], 2: target_b - mean_lab[2]}
    weight = dress_mask / 255.0
    for channel, shift in shifts.items():
        lab[..., channel] = np.clip(lab[..., channel] + weight * shift, 0, 255)

    recolored = cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2RGB)

    # Blend weight: blurred mask scaled by lightness, quantised to uint8.
    feather = cv2.GaussianBlur(dress_mask, (15, 15), 5)
    alpha = (feather * (lab[..., 0] / 255.0)).astype(np.uint8)[..., None] / 255
    return (image_np * (1 - alpha) + recolored * alpha).astype(np.uint8)
def change_dress_color(img, color):
    """Recolour the dress in ``img`` to the named ``color``.

    The pipeline is U²-Net segmentation, then GrabCut refinement, then Lab
    recolouring with blending.

    Args:
        img: PIL image from the Gradio input, or None when nothing is uploaded.
        color: one of the colour names offered in the UI. Unknown names,
            including None when no option is selected, fall back to red.

    Returns:
        A new PIL image with the dress recoloured. If no dress pixels are
        found, ``img`` is returned unchanged. If ``img`` is None, None is
        returned.
    """
    if img is None:
        return None

    # Normalise to 3-channel RGB. RGBA, greyscale or palette images would
    # yield arrays that GrabCut (8-bit, 3-channel only) and the RGB->Lab
    # conversion reject.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        rgb_img = img.convert("RGB")
    else:
        rgb_img = img
    img_np = np.array(rgb_img)

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)
    # segment_dress reports "no dress" as an all-zero mask, not None. GrabCut
    # cannot be initialised from such a mask, so stop before calling it.
    if dress_mask is None or not np.any(dress_mask):
        return img  # No dress detected

    # Further refine mask with GrabCut
    dress_mask = apply_grabcut(img_np, dress_mask)

    # Convert the selected color to BGR (the order recolor_dress expects)
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0),
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)

    # Apply recoloring with blending
    img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
    return Image.fromarray(img_recolored)
# Gradio UI: an image upload plus a radio list of colour names, wired to
# change_dress_color. The returned PIL image is shown as the output.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="pil", label="Upload Dress Image"),
        gr.Radio(
            [
                "Red", "Blue", "Green", "Yellow", "Purple",
                "Orange", "Cyan", "Magenta", "White", "Black",
            ],
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="AI-Powered Dress Color Changer",
    description=(
        "Upload an image of a dress and select a new color. "
        "The AI will change the dress color naturally while keeping the fabric texture."
    ),
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()