File size: 5,037 Bytes
984b1c3
 
 
 
 
 
ba17c5b
984b1c3
d00e30a
d3f9ca8
984b1c3
1c2f991
ba17c5b
1c2f991
984b1c3
 
91a732d
12f978b
 
 
 
 
91a732d
984b1c3
ba17c5b
984b1c3
 
 
 
12f978b
984b1c3
 
d3f9ca8
984b1c3
 
d00e30a
4175fd1
d00e30a
908fc7b
d00e30a
91a732d
 
 
 
 
 
 
4175fd1
 
 
 
 
 
91a732d
 
12f978b
4175fd1
91a732d
d3f9ca8
ba17c5b
91a732d
12f978b
ba17c5b
d3f9ca8
d00e30a
91a732d
ba17c5b
91a732d
ba17c5b
 
908fc7b
 
d00e30a
908fc7b
 
a22d3b1
12f978b
 
 
 
d3f9ca8
ba17c5b
12f978b
ba17c5b
91a732d
a22d3b1
12f978b
a22d3b1
908fc7b
91a732d
a22d3b1
ba17c5b
91a732d
95d0b08
a22d3b1
ba17c5b
a22d3b1
984b1c3
 
 
9cebca9
d3f9ca8
d00e30a
12f978b
d00e30a
490bf43
d00e30a
91a732d
 
 
490bf43
984b1c3
d3f9ca8
 
 
984b1c3
ba17c5b
 
 
 
95d0b08
1f7ad23
984b1c3
af8c4a2
984b1c3
490bf43
984b1c3
71096d7
d3f9ca8
984b1c3
71096d7
91a732d
ba17c5b
984b1c3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  

# Load U²-Net model once at import time (CPU) and put it in inference mode.
# NOTE(review): torch.load unpickles the checkpoint, so only point this at a trusted file.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# The checkpoint's keys may carry a leading "module." prefix (presumably saved
# from a wrapped model such as nn.DataParallel). Strip only that leading prefix;
# str.replace would also corrupt any key that merely contains "module." later on.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference behaviour for layers like dropout / batch-norm

def refine_mask(mask):
    """Smooth a mask: morphological closing with a 7x7 kernel, then a 7x7 Gaussian blur.

    The closing fills small holes and gaps inside the masked region. The blur
    (sigma 1.5) softens the mask edges so later blending has no hard seams.

    Args:
        mask: single-channel mask image accepted by OpenCV (uint8 in this file).

    Returns:
        The smoothed mask, with the same shape and dtype as the input.
    """
    closing_kernel = np.ones((7, 7), dtype=np.uint8)
    closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, closing_kernel)
    return cv2.GaussianBlur(closed, (7, 7), 1.5)

def segment_dress(image_np):
    """Predict a dress mask for an image using the module-level U²-Net ``model``.

    The image is converted to a tensor and resized to 320x320, then run through
    the network without gradients. The prediction is min-max normalised to
    [0, 1] and thresholded at 0.5. The binary result is resized back to the
    input resolution and smoothed with ``refine_mask``.

    Args:
        image_np: H x W (x C) uint8 image array. It is converted to RGB before inference.

    Returns:
        uint8 mask of shape (H, W) with values in 0..255. It is never None.
    """
    height, width = image_np.shape[:2]
    pil_image = Image.fromarray(image_np).convert("RGB")

    # Same two steps, same order, as a Compose([ToTensor(), Resize((320, 320))]).
    to_tensor = transforms.ToTensor()
    resize = transforms.Resize((320, 320))
    batch = resize(to_tensor(pil_image)).unsqueeze(0)

    with torch.no_grad():
        # NOTE(review): assumes model(...)[0][0] is the primary prediction map
        # for the single batch item; confirm against the U2NET definition.
        saliency = model(batch)[0][0].squeeze().cpu().numpy()

    lo, hi = saliency.min(), saliency.max()
    saliency = (saliency - lo) / (hi - lo + 1e-8)  # epsilon guards a constant map
    binary = np.where(saliency > 0.5, 255, 0).astype(np.uint8)
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_LINEAR)

    return refine_mask(binary)

def apply_grabcut(image_np, dress_mask):
    """Refine a coarse dress mask with OpenCV GrabCut.

    Pixels where ``dress_mask > 0`` seed GrabCut as "probably foreground". All
    other pixels are fixed as definite background, so GrabCut can only shrink
    the seeded region. The GrabCut labels are turned into a 0/255 mask and
    smoothed with ``refine_mask``.

    Args:
        image_np: uint8 image of shape (H, W, 3). GrabCut requires 8-bit 3-channel input.
        dress_mask: uint8 mask of shape (H, W). Any non-zero value counts as dress.

    Returns:
        uint8 mask of shape (H, W) with values in 0..255. GrabCut cannot
        initialise its colour models when the seed is empty or covers every
        pixel. In that case the seed itself is returned (binarised and refined)
        instead of raising ``cv2.error``.
    """
    seed = dress_mask > 0
    if not seed.any() or seed.all():
        # Degenerate seed: GrabCut needs both foreground and background samples.
        return refine_mask(seed.astype(np.uint8) * 255)

    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    gc_mask = np.where(seed, cv2.GC_PR_FGD, cv2.GC_BGD).astype(np.uint8)

    # With GC_INIT_WITH_MASK the rect argument is unused, so no bounding box is needed.
    cv2.grabCut(image_np, gc_mask, None, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK)

    is_fg = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    refined_mask = np.where(is_fg, 255, 0).astype(np.uint8)
    return refine_mask(refined_mask)

def recolor_dress(image_np, dress_mask, target_color):
    """Change the dress colour while keeping texture and lighting intact.

    The work is done in OpenCV's 8-bit LAB space:

    1. The mean a/b of the masked pixels is shifted onto the target's a/b,
       weighted per pixel by ``dress_mask / 255``.
    2. The L (lightness) channel is not changed, so shading and fabric texture survive.
    3. The recoloured image is blended back over the original through a
       Gaussian-feathered copy of the mask.
    4. The feathered weight is scaled by each pixel's lightness, so darker
       pixels receive less of the new colour.

    Because only a/b are shifted, achromatic targets such as white or black
    mostly desaturate the dress rather than changing its brightness.

    Args:
        image_np: RGB image, uint8, shape (H, W, 3).
        dress_mask: uint8 mask of shape (H, W) with values in 0..255. Values > 0
            mark dress pixels, and the value also acts as a blend weight.
        target_color: BGR triple (uint8) of the desired colour.

    Returns:
        The recoloured RGB image as uint8. If the mask selects no pixels,
        ``image_np`` itself is returned.
    """
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0].astype(np.float64)
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)

    dress_pixels = img_lab[dress_mask > 0]
    if len(dress_pixels) == 0:
        return image_np  # No dress detected

    _, mean_a, mean_b = np.mean(dress_pixels, axis=0)
    weight = dress_mask.astype(np.float64) / 255.0

    # Shift chroma in float and round once. Writing floats straight into the
    # uint8 LAB array would truncate and bias a/b downward.
    shifted = img_lab.astype(np.float64)
    shifted[..., 1] += weight * (target_lab[1] - mean_a)
    shifted[..., 2] += weight * (target_lab[2] - mean_b)
    shifted_lab = np.clip(np.rint(shifted), 0, 255).astype(np.uint8)
    img_recolored = cv2.cvtColor(shifted_lab, cv2.COLOR_LAB2RGB)

    # Feathered, lightness-weighted alpha, quantised to 0..255 before use.
    lightness_mask = img_lab[..., 0] / 255.0
    feathered_mask = cv2.GaussianBlur(dress_mask, (15, 15), 5)
    adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
    alpha = adaptive_feather[..., None] / 255.0

    blended = image_np * (1 - alpha) + img_recolored * alpha
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

def change_dress_color(img, color):
    """Gradio callback: recolour the dress in ``img`` to the colour named by ``color``.

    The pipeline is: U²-Net segmentation, then GrabCut refinement, then a LAB
    chroma shift blended back over the photo.

    Args:
        img: PIL image (as supplied by ``gr.Image(type="pil")``) or None. Any
            mode is accepted; it is converted to 3-channel RGB before
            processing, because the OpenCV steps require 3-channel 8-bit input.
        color: one of the names in ``color_map``. Unknown values, or None,
            fall back to red.

    Returns:
        A new RGB PIL image. Returns None when ``img`` is None, and the input
        ``img`` unchanged when segmentation finds no dress pixels.
    """
    if img is None:
        return None

    # RGBA/L/P uploads would otherwise give 4- or 1-channel arrays that break
    # cv2.cvtColor(RGB2LAB) and cv2.grabCut.
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(np.asarray(img))
    img_np = np.array(pil_img.convert("RGB"))

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)

    # An all-zero mask means nothing was segmented; skip GrabCut/recolouring.
    if dress_mask is None or not np.any(dress_mask):
        return img  # No dress detected

    # Further refine mask with GrabCut
    dress_mask = apply_grabcut(img_np, dress_mask)

    # Colour table is in BGR order, matching recolor_dress's BGR2LAB conversion.
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)

    # Apply recoloring with blending
    img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)

    return Image.fromarray(img_recolored)

# Gradio UI: an image upload and a colour picker go in, and one recoloured
# image comes out. Each submission is handled by change_dress_color.
demo = gr.Interface(
    title="AI-Powered Dress Color Changer",
    description="Upload an image of a dress and select a new color. The AI will change the dress color naturally while keeping the fabric texture.",
    fn=change_dress_color,
    inputs=[
        gr.Image(type="pil", label="Upload Dress Image"),
        gr.Radio(
            ["Red", "Blue", "Green", "Yellow", "Purple",
             "Orange", "Cyan", "Magenta", "White", "Black"],
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()