import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Load the pretrained U²-Net cloth-segmentation model (CPU inference)
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Strip the 'module.' prefix left over from training with nn.DataParallel
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
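# Note: inference runs on the CPU here. If a GPU is available, the model and the
# input tensor in segment_dress() could be moved over with .to("cuda") to speed
# up segmentation.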
def refine_mask(mask):
    """Refine a mask with morphological closing followed by a Gaussian blur."""
    kernel = np.ones((7, 7), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # Close holes inside the dress
    mask = cv2.GaussianBlur(mask, (7, 7), 1.5)              # Feather the edges slightly
    return mask
def segment_dress(image_np):
    """Segment the dress with U²-Net; returns a 0-255 uint8 mask, or None if nothing is found."""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        # U²-Net returns several side outputs; [0] is the fused prediction
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    output = (output - output.min()) / (output.max() - output.min() + 1e-8)  # Normalize to [0, 1]
    dress_mask = (output > 0.5).astype(np.uint8) * 255
    if dress_mask.max() == 0:
        return None  # No dress detected
    dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_LINEAR)
    return refine_mask(dress_mask)
def apply_grabcut(image_np, dress_mask):
    """Refine the mask with GrabCut to reduce color bleeding at the edges."""
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    # Seed GrabCut: mask pixels become probable foreground, everything else definite background
    mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype('uint8')
    # Get the bounding box of the mask (ignored in mask-init mode, but required by the API)
    coords = cv2.findNonZero(dress_mask)
    if coords is None:
        return dress_mask  # Empty mask; nothing to refine
    x, y, w, h = cv2.boundingRect(coords)
    rect = (x, y, w, h)
    cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK)
    refined_mask = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
    return refine_mask(refined_mask)
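# A small debugging aid (not part of the original pipeline): wrap any of the
# intermediate 0-255 masks in a PIL image for quick visual inspection, e.g.
# mask_to_image(segment_dress(np.array(img))).show()
def mask_to_image(mask):
    """Convert a uint8 mask to a PIL image so it can be viewed or saved."""
    return Image.fromarray(mask)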
def recolor_dress(image_np, dress_mask, target_color):
    """Change the dress color while keeping texture and lighting intact."""
    # Convert the target color (BGR) to LAB
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    # Convert the image to LAB; float32 avoids uint8 wrap-around during the shift
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    # Compute the mean LAB values of the dress pixels
    dress_pixels = img_lab[dress_mask > 0]
    if len(dress_pixels) == 0:
        return image_np  # No dress detected
    _, mean_A, mean_B = np.mean(dress_pixels, axis=0)
    # Shift the chroma channels toward the target color, weighted by the soft mask
    a_shift = target_color_lab[1] - mean_A
    b_shift = target_color_lab[2] - mean_B
    img_lab[..., 1] = np.clip(img_lab[..., 1] + (dress_mask / 255.0) * a_shift, 0, 255)
    img_lab[..., 2] = np.clip(img_lab[..., 2] + (dress_mask / 255.0) * b_shift, 0, 255)
    # Convert back to RGB
    img_recolored = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
    # Create a feathered, lightness-weighted mask so shadowed areas blend less strongly
    lightness_mask = img_lab[..., 0] / 255.0
    feathered_mask = cv2.GaussianBlur(dress_mask, (15, 15), 5)
    adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
    # Alpha-blend the recolored dress with the original image
    alpha = adaptive_feather[..., None] / 255.0
    img_final = (image_np * (1 - alpha) + img_recolored * alpha).astype(np.uint8)
    return img_final
def change_dress_color(img, color):
    """Main entry point: change the dress color as naturally as possible."""
    if img is None:
        return None
    img_np = np.array(img.convert("RGB"))  # Ensure a 3-channel RGB array
    # Get the dress segmentation mask
    dress_mask = segment_dress(img_np)
    if dress_mask is None:
        return img  # No dress detected
    # Further refine the mask with GrabCut
    dress_mask = apply_grabcut(img_np, dress_mask)
    # Map the selected color name to BGR
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)
    # Apply the recoloring with blended edges
    img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
    return Image.fromarray(img_recolored)
# Gradio interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="pil", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
                 label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color-Changed Dress"),
    title="AI-Powered Dress Color Changer",
    description="Upload an image of a dress and select a new color. The AI will change the dress color naturally while keeping the fabric texture."
)
if __name__ == "__main__":
    demo.launch()
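# Gradio can also expose a temporary public URL via demo.launch(share=True),
# which is handy for testing the app outside the Space.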