# NOTE: the text originally here was web file-viewer residue, not program text.
# Spaces: Running | File size: 2,720 Bytes | revision 984b1c3
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
import importlib

# The directory name "cloth-segmentation" contains a hyphen, which is not a
# valid identifier, so `from cloth-segmentation... import U2NET` is a
# SyntaxError. Resolving the module by its dotted string name avoids that.
# NOTE(review): this requires the directory holding "cloth-segmentation" to be
# on sys.path (e.g. the app is started from that directory) - confirm.
U2NET = importlib.import_module("cloth-segmentation.networks.u2net").U2NET  # Import U²-Net

# Load U²-Net model
model_path = "cloth-segmentation/models/u2net.pth"  # Ensure this path is correct
# NOTE(review): the arguments are presumably (input channels, output channels);
# verify against the U2NET constructor.
model = U2NET(3, 1)
# map_location forces the checkpoint tensors onto the CPU when loading.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()  # Switch to evaluation mode; the model is only used for inference here
def segment_dress(image_np):
    """Build a binary dress mask for an image with the module-level U²-Net.

    The image is converted to RGB, turned into a tensor, resized to 320x320
    and passed through ``model`` as a batch of one. The prediction is
    thresholded at 0.5 and the resulting mask is resized back to the first
    two dimensions of ``image_np``.

    Args:
        image_np: Image as a NumPy array accepted by ``Image.fromarray``.

    Returns:
        A ``uint8`` NumPy array with the same height and width as
        ``image_np``, built from a 0/1 thresholded mask.
    """
    target_height = image_np.shape[0]
    target_width = image_np.shape[1]

    to_tensor = transforms.ToTensor()
    to_network_size = transforms.Resize((320, 320))
    preprocess = transforms.Compose([to_tensor, to_network_size])

    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        prediction = model(batch)
        # NOTE(review): [0][0] presumably picks the primary output map of the
        # first batch item - confirm against the U2NET forward() return value.
        saliency = prediction[0][0].squeeze().cpu().numpy()

    binary_mask = (saliency > 0.5).astype(np.uint8)  # Thresholding for binary mask
    # cv2.resize takes its size as (width, height), the reverse of NumPy's shape order.
    return cv2.resize(binary_mask, (target_width, target_height))
def change_dress_color(image_path, color):
    """Recolor the segmented dress region of an image.

    The dress mask comes from ``segment_dress``. Inside the mask the hue and
    saturation of every pixel are replaced with those of the chosen color;
    the brightness channel is left as it was so folds and shading of the
    fabric remain visible.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` when
            nothing was uploaded.
        color: One of "Red", "Blue", "Green", "Yellow" or "Purple". Any other
            value (including ``None``) falls back to red.

    Returns:
        A PIL image with the masked pixels recolored, the unmodified RGB
        image when the mask is empty, or ``None`` when ``image_path`` is
        ``None``.
    """
    if image_path is None:
        return None

    # convert() returns a fully loaded copy, so the file can be closed right away.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    mask = segment_dress(img_np)
    if mask is None or not mask.any():
        return img  # No dress detected

    # Target colors as 8-bit OpenCV HSV triples (OpenCV stores hue as 0-179).
    color_map = {
        "Red": (0, 255, 255),
        "Blue": (120, 255, 255),
        "Green": (60, 255, 255),
        "Yellow": (30, 255, 255),
        "Purple": (150, 255, 255),
    }
    target_hue, target_saturation, _ = color_map.get(color, (0, 255, 255))  # Default to Red

    img_hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
    dress_pixels = mask.astype(bool)

    # Write the HSV target straight into the HSV image. The earlier version
    # converted the target to BGR first and then stored its B/G/R components
    # in the H/S/V channels, which produced white, black or out-of-range hues.
    img_hsv[dress_pixels, 0] = target_hue
    img_hsv[dress_pixels, 1] = target_saturation
    # Channel 2 (value) is deliberately not modified: overwriting it would
    # flatten the dress into a single solid color.

    img_recolored = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(img_recolored)
# Gradio Interface
# Wires change_dress_color to a two-input form: the uploaded image is handed
# to the function as a file path, the color as the selected radio label.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple"], label="Choose New Dress Color"),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance.",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()