import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import cv2
import numpy as np
import gradio as gr
# U-Net model for cloth fold segmentation
class ClothFoldUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = smp.Unet(
            encoder_name="resnet34",     # pre-trained backbone
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,                   # single-channel segmentation logits
        )

    def forward(self, x):
        return self.model(x)
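
# Quick shape sanity check (a sketch; the helper name is ours, not part of
# the original app): a 3x256x256 input should yield a 1x256x256 logit map,
# since the U-Net decoder restores the input resolution for sizes divisible
# by 32.
def _smoke_test_unet():
    model = ClothFoldUNet().eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))
    assert out.shape == (1, 1, 256, 256), out.shape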
# Load dataset (placeholder; replace with a real cloth/fold dataset)
def get_dataloader(batch_size=8):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    dataset = datasets.FakeData(transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
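
# Sketch of a real image/mask dataset to swap in for FakeData. The class
# name, file layout, and use of paired image/mask paths are assumptions,
# not part of the original app.
from PIL import Image  # used only by the sketch below

class ClothFoldDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, mask_paths, size=(256, 256)):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.tf = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")
        # Both come back as float tensors in [0, 1]; the mask is 1xHxW,
        # matching the model's single-channel logits for BCEWithLogitsLoss.
        return self.tf(image), self.tf(mask)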
# Train function (placeholder loop; returns the trained model so the
# Gradio handler can reuse its weights instead of a fresh, untrained net)
def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ClothFoldUNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.BCEWithLogitsLoss()
    dataloader = get_dataloader()
    model.train()
    for epoch in range(10):  # placeholder epoch count
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Placeholder target: an all-ones mask. Replace with real fold
            # masks from the dataset once available.
            loss = criterion(outputs, torch.ones_like(outputs))
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}: Loss {loss.item():.4f}")
    return model
# Apply a design onto the cloth, using the segmentation mask as an alpha map
def apply_design(image, design, mask):
    # Resize mask and design to match the cloth image, then alpha-blend.
    mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
    design = cv2.resize(design, (image.shape[1], image.shape[0]))
    mask = np.clip(mask, 0.0, 1.0)[..., np.newaxis]  # HxWx1 alpha in [0, 1]
    blended = mask * design + (1.0 - mask) * image
    return np.clip(blended, 0, 255).astype(np.uint8)
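
# Illustrative usage (the arrays are made up for the example): blending a
# solid red design onto a white cloth with a uniform half-strength mask
# should give an even pink.
def _demo_apply_design():
    cloth = np.full((256, 256, 3), 255, dtype=np.uint8)
    design = np.zeros_like(cloth)
    design[..., 0] = 255  # pure red
    mask = np.full((256, 256), 0.5, dtype=np.float32)
    return apply_design(cloth, design, mask)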
# Gradio interface
MODEL = None  # populated by train_model() in the __main__ block below

def process_image(image, design):
    # Gradio passes images as RGB numpy arrays (HxWx3, uint8).
    device = next(MODEL.parameters()).device
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)
    MODEL.eval()
    with torch.no_grad():
        # Sigmoid maps raw logits to a soft fold mask in [0, 1].
        mask = torch.sigmoid(MODEL(img_tensor)).squeeze().cpu().numpy()
    return apply_design(np.array(image), np.array(design), mask)
iface = gr.Interface(
    fn=process_image,
    inputs=["image", "image"],
    outputs="image",
    title="AI Cloth Design Blending",
    description="Upload a cloth image and a design to blend the design onto the cloth while considering the folds.",
)
# Train the model, then launch the Gradio app
if __name__ == "__main__":
    MODEL = train_model()
    iface.launch()