# detecting_dress / app.py
# Author: gaur3009 — commit dfa09bc ("Update app.py", verified), 2.72 kB.
import os

import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms

from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
# Load U²-Net once at import time so every request reuses the same weights.
# The checkpoint path is resolved relative to this file rather than the
# current working directory, so the app starts correctly from any directory.
model_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "cloth_segmentation", "networks", "u2net.pth",
)
# Presumably U2NET(in_channels, out_channels): 3 = RGB input, 1 = mask output
# — TODO confirm against cloth_segmentation.networks.u2net.
model = U2NET(3, 1)
# torch.load unpickles the file, so model_path must point at a trusted checkpoint.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()  # inference mode for layers such as dropout / batch-norm
def segment_dress(image_np, threshold=0.5, input_size=(320, 320)):
    """Predict a binary dress mask for an image using the module-level U²-Net.

    Args:
        image_np: Image as a NumPy array accepted by ``PIL.Image.fromarray``;
            it is converted to RGB before inference.
        threshold: Network outputs strictly greater than this value are
            treated as dress pixels.
        input_size: ``(height, width)`` the image is resized to before it is
            fed to the network.

    Returns:
        ``np.uint8`` array of 0/1 values with the same height and width as
        ``image_np``.
    """
    # Resize the PIL image first (PIL's resampler, no aliasing), then convert
    # it to a float tensor in [0, 1].
    # NOTE(review): the reference U²-Net pipeline also normalises inputs with
    # ImageNet mean/std; confirm which preprocessing this checkpoint expects.
    preprocess = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        # Assumes the model returns a sequence whose first element is the
        # (1, 1, H, W) primary prediction — TODO confirm against U2NET.forward.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    mask = (output > threshold).astype(np.uint8)

    # Nearest-neighbour is the conventional choice for label masks: it never
    # invents intermediate values along the edges when upsampling.
    height, width = image_np.shape[:2]
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
# Target colours in OpenCV's 8-bit HSV encoding, where hue is stored on a
# 0-179 scale (degrees / 2): 0 = red, 30 = yellow, 60 = green, 120 = blue,
# 150 = purple.
_DRESS_COLORS_HSV = {
    "Red": (0, 255, 255),
    "Blue": (120, 255, 255),
    "Green": (60, 255, 255),
    "Yellow": (30, 255, 255),
    "Purple": (150, 255, 255),
}


def change_dress_color(image_path, color):
    """Recolour the dress found in an image file.

    The dress is located with ``segment_dress``. Inside that mask, hue and
    saturation are replaced with those of the chosen colour. The original
    brightness (HSV value channel) is kept, so folds and shading stay
    visible; very dark garments therefore stay dark.

    Args:
        image_path: Path to the uploaded image, or ``None`` if nothing was
            uploaded.
        color: One of ``"Red"``, ``"Blue"``, ``"Green"``, ``"Yellow"``,
            ``"Purple"``. Any other value (including ``None``) falls back to
            red.

    Returns:
        An RGB ``PIL.Image``: the recoloured image, or the unchanged image
        when no dress pixels are detected. Returns ``None`` when
        ``image_path`` is ``None``.
    """
    if image_path is None:
        return None

    # Open inside a context manager so the file handle is released.
    # convert() returns an independent copy that stays valid after close.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_np = np.array(img)

    mask = segment_dress(img_np)
    if mask is None or not np.any(mask):
        return img  # No dress detected

    hue, saturation, _value = _DRESS_COLORS_HSV.get(
        color, _DRESS_COLORS_HSV["Red"]
    )
    dress = mask.astype(bool)

    # The target colour is already expressed in HSV, so its components go
    # straight into the matching HSV channels of the image. Converting the
    # target to BGR first would put B/G/R values into H/S/V and produce
    # white/black/grey instead of the requested colour.
    img_hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
    img_hsv[dress, 0] = hue
    img_hsv[dress, 1] = saturation
    # Channel 2 (value/brightness) is deliberately left untouched to keep texture.
    img_recolored = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(img_recolored)
# --- Gradio UI ---------------------------------------------------------------
# Inputs: an uploaded image (passed to the handler as a file path) and a
# colour choice. Output: the recoloured image as a PIL object.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(label="Upload Dress Image", type="filepath"),
        gr.Radio(
            choices=["Red", "Blue", "Green", "Yellow", "Purple"],
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(label="Color Changed Dress", type="pil"),
    title="Dress Color Changer",
    description=(
        "Upload an image of a dress and select a new color "
        "to change its appearance."
    ),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()