# Newgen_2025 / app.py
# Commit author: gaur3009
# Last commit: "Update app.py" (1d6741e, verified)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import cv2
import numpy as np
import gradio as gr
from skimage.transform import warp, PiecewiseAffineTransform
# Define U-Net model for cloth fold segmentation
class ClothFoldUNet(nn.Module):
    """Thin wrapper around ``smp.Unet`` that predicts a cloth-fold mask.

    The defaults reproduce the original fixed configuration: a ResNet-34
    encoder initialised with ImageNet weights, 3 input channels and a single
    output channel. No ``activation`` is passed to ``smp.Unet``. The training
    code pairs this model with ``BCEWithLogitsLoss``, so the output should be
    treated as logits (smp's documented default activation is None).

    Args:
        encoder_name: smp encoder backbone identifier.
        encoder_weights: pre-trained weights for the encoder, or None for
            random initialisation. Using None avoids a weight download, e.g.
            when loading a checkpoint afterwards.
        in_channels: number of channels in the input tensor.
        classes: number of output mask channels.
    """

    def __init__(
        self,
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
    ):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )

    def forward(self, x):
        """Run the wrapped U-Net on a batch tensor and return its raw output."""
        return self.model(x)
# Load dataset (placeholder, replace with real dataset)
def get_dataloader(batch_size=8):
    """Return a shuffled DataLoader over torchvision's synthetic ``FakeData``.

    This is a stand-in for a real cloth dataset. Every sample is resized to
    256x256 and converted to a tensor.

    Args:
        batch_size: number of samples per batch.
    """
    pipeline = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor()]
    )
    fake_set = datasets.FakeData(transform=pipeline)
    loader = DataLoader(fake_set, batch_size=batch_size, shuffle=True)
    return loader
# Training loop (placeholder data and placeholder all-ones target)
def train_model(epochs=10, batch_size=8, device=None):
    """Train a fresh ClothFoldUNet on the placeholder data loader and return it.

    Both the data (``get_dataloader``) and the target (``torch.ones_like`` of
    the output) are placeholders. The loss therefore only pushes every output
    toward "fold". Replace them with a real dataset and real masks.

    Args:
        epochs: number of passes over the data loader.
        batch_size: forwarded to ``get_dataloader``.
        device: torch device to train on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The trained ClothFoldUNet, left in training mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ClothFoldUNet().to(device)
    model.train()  # explicit train mode (affects layers such as BatchNorm)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.BCEWithLogitsLoss()
    dataloader = get_dataloader(batch_size=batch_size)
    for epoch in range(epochs):
        # Accumulate on-device so we don't force a host sync (.item()) per batch.
        epoch_loss = torch.zeros((), device=device)
        num_batches = 0
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, torch.ones_like(outputs))  # Placeholder loss
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach()
            num_batches += 1
        # Report the epoch mean instead of the last batch's loss. An empty
        # loader yields NaN instead of referencing an unbound `loss`.
        mean_loss = (epoch_loss / num_batches).item() if num_batches else float("nan")
        print(f"Epoch {epoch+1}: Loss {mean_loss:.4f}")
    return model
# Function to apply design onto cloth using segmentation mask
def apply_design(image, design, mask):
    """Alpha-blend ``design`` over ``image`` using ``mask`` as per-pixel weight.

    Args:
        image: cloth image array of shape (H, W, C). Presumably uint8 RGB as
            delivered by Gradio; confirm.
        design: design image. It is resized to the image's (W, H) if needed and
            is assumed to have the same channel count as ``image``.
        mask: 2-D array of blend weights, resized to the image's (W, H) if
            needed. It is converted to float and clipped to [0, 1]:
            1 selects the design, 0 keeps the cloth.

    Returns:
        uint8 array with the image's spatial size. Values are saturated to
        [0, 255] rather than wrapped.
    """
    height, width = image.shape[:2]
    # Float mask: integer/bool masks would otherwise overflow or be rejected
    # by cv2.resize.
    mask = np.asarray(mask, dtype=np.float64)
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(mask, (width, height))
    if design.shape[:2] != (height, width):
        design = cv2.resize(design, (width, height))
    # Unbounded weights (e.g. raw network logits) would extrapolate beyond
    # both images; clip them to a proper convex blend.
    weights = np.clip(mask, 0.0, 1.0)[..., np.newaxis]
    blended = weights * design + (1.0 - weights) * image
    # astype(np.uint8) wraps out-of-range values, so saturate first.
    return np.clip(blended, 0, 255).astype(np.uint8)
# Gradio request handler: segment the cloth image and blend the design onto it
# One model per device, reused across requests instead of rebuilding (and
# re-initialising) a ClothFoldUNet on every call.
_MODEL_CACHE = {}


def _get_cached_model(device):
    """Return an eval-mode ClothFoldUNet on ``device``, building it on first use."""
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = ClothFoldUNet().to(device)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def process_image(image, design, model=None):
    """Segment ``image`` with the U-Net and blend ``design`` onto it.

    This is the Gradio handler for the two image inputs.

    Args:
        image: cloth image as delivered by Gradio's "image" input. Presumably
            an RGB uint8 NumPy array; confirm.
        design: design image in the same form.
        model: optional ClothFoldUNet to use, e.g. one returned by training.
            When omitted, a cached default instance is used.
            NOTE(review): no checkpoint is loaded for the default instance,
            so it has not been trained on fold data.
            A supplied model's train/eval mode is restored after inference.

    Returns:
        The blended uint8 image, or None when either input is missing.
    """
    if image is None or design is None:
        # Gradio passes None for an empty input. Return no output instead of
        # failing inside ToPILImage.
        return None
    if model is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = _get_cached_model(device)
    else:
        device = next(model.parameters()).device
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(img_tensor)
    finally:
        model.train(was_training)
    # The model is trained with BCEWithLogitsLoss, so its outputs are logits.
    # Sigmoid turns them into [0, 1] blend weights.
    mask = torch.sigmoid(logits).squeeze().cpu().numpy()
    return apply_design(np.array(image), np.array(design), mask)
# Interface wiring: the cloth image and the design image are the two inputs
# passed to process_image. The blended result is shown as a single image.
_interface_kwargs = {
    "fn": process_image,
    "inputs": ["image", "image"],
    "outputs": "image",
    "title": "AI Cloth Design Blending",
    "description": "Upload a cloth image and a design to blend the design onto the cloth while considering the folds.",
}
iface = gr.Interface(**_interface_kwargs)


def _main():
    """Run the placeholder training loop, then start the Gradio server."""
    # NOTE(review): whatever train_model() produces is discarded here. It is
    # not handed to process_image, so the served interface does not use it.
    train_model()
    iface.launch()


# Run Gradio app
if __name__ == "__main__":
    _main()