File size: 2,722 Bytes
984b1c3
 
 
 
 
 
ac99583
984b1c3
 
dfa09bc
984b1c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net

# Build the U²-Net network and load its pretrained checkpoint onto the CPU.
# U2NET(3, 1) presumably means 3 input channels and 1 output map — confirm
# against cloth_segmentation.networks.u2net. The checkpoint path is relative,
# so it is resolved against the current working directory (run from the
# project root).
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
model.load_state_dict(
    torch.load(model_path, map_location="cpu")
)
# Inference only: put layers such as dropout / batch-norm into eval mode.
model.eval()

# Preprocessing is identical for every request, so build the pipeline once
# instead of on each call.
_U2NET_INPUT_SIZE = (320, 320)
_U2NET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(_U2NET_INPUT_SIZE),
])


def segment_dress(image_np, threshold=0.5):
    """Segment the dress from the image using U²-Net.

    Args:
        image_np: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (presumably H x W x 3 uint8 RGB, as passed by
            ``change_dress_color``); it is converted to RGB before inference.
        threshold: Cut-off applied to the model's first output map; pixels
            strictly above it are marked as dress. Defaults to 0.5, the value
            previously hard-coded.

    Returns:
        A ``uint8`` mask with the same height and width as ``image_np``,
        containing 1 for pixels classified as dress and 0 elsewhere.
    """
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = _U2NET_TRANSFORM(image).unsqueeze(0)

    with torch.no_grad():
        # NOTE(review): assumes the model's first output is a single-channel
        # probability map in [0, 1] — confirm against the U2NET implementation.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    mask = (output > threshold).astype(np.uint8)  # binary mask at 320x320

    # Upscale back to the original resolution. Nearest-neighbour keeps the
    # mask strictly binary instead of interpolating between labels at edges.
    height, width = image_np.shape[:2]
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

def change_dress_color(image_path, color):
    """Change the dress color based on the detected dress mask.

    Args:
        image_path: Path of the uploaded image (the Gradio ``filepath`` image
            input), or ``None`` when nothing has been uploaded.
        color: Target color name: "Red", "Blue", "Green", "Yellow" or
            "Purple". Any other value, including ``None``, falls back to red.

    Returns:
        A PIL RGB image with the dress pixels recolored; the unmodified RGB
        image when the segmentation mask is empty; or ``None`` when no image
        was provided.
    """
    if image_path is None:
        return None

    # Close the file handle promptly; convert() returns a fully loaded copy
    # that stays valid after the with-block exits.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_np = np.array(img)

    mask = segment_dress(img_np)
    # segment_dress returns an array, so an all-zero mask (not None) is what
    # "no dress detected" actually looks like.
    if mask is None or not mask.any():
        return img

    # Target colors in OpenCV's 8-bit HSV space (H in 0-179, S and V in 0-255).
    color_map = {
        "Red": (0, 255, 255),
        "Blue": (120, 255, 255),
        "Green": (60, 255, 255),
        "Yellow": (30, 255, 255),
        "Purple": (150, 255, 255),
    }
    target_hue, target_sat, _ = color_map.get(color, (0, 255, 255))  # Default to Red

    # Work directly in HSV. The target must NOT be converted to BGR first:
    # writing B/G/R values into the H/S/V channels yields white or black
    # instead of the requested color.
    img_hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
    dress = mask.astype(bool)
    # Replace hue and saturation on dress pixels only, and keep the original
    # value (brightness) channel so folds, shadows and texture survive.
    # Limitation: near-black regions stay dark whatever hue is chosen.
    img_hsv[dress, :2] = (target_hue, target_sat)

    img_recolored = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(img_recolored)

# ---------------------------------------------------------------------------
# Web UI: an image upload plus a color choice are fed to change_dress_color,
# and the recolored result is displayed as a PIL image.
# ---------------------------------------------------------------------------
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(label="Upload Dress Image", type="filepath"),
        gr.Radio(
            ["Red", "Blue", "Green", "Yellow", "Purple"],
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(label="Color Changed Dress", type="pil"),
    title="Dress Color Changer",
    description=(
        "Upload an image of a dress and select a new color "
        "to change its appearance."
    ),
)


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()