gaur3009 commited on
Commit
1d6741e
·
verified ·
1 Parent(s): ffc60dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -83
app.py CHANGED
@@ -1,95 +1,88 @@
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
- from PIL import Image, ImageDraw, ImageFont
4
  import gradio as gr
5
- import torch
6
- import torchvision.transforms as transforms
7
- from skimage.filters import sobel
8
- from skimage.restoration import denoise_tv_chambolle
9
- from scipy.interpolate import Rbf
10
-
11
-
12
- # Function to estimate a normal map from cloth texture
13
- def estimate_normal_map(image):
14
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
15
- sobel_x = sobel(gray)
16
- sobel_y = sobel(gray)
17
 
18
- normal_map = np.stack([sobel_x, sobel_y, np.ones_like(sobel_x)], axis=-1)
19
- normal_map /= np.linalg.norm(normal_map, axis=-1, keepdims=True)
20
-
21
- return (normal_map * 255).astype(np.uint8)
22
-
23
-
24
- def apply_tps_warping(design, normal_map):
25
- # Resize normal map to match the design size
26
- normal_map = cv2.resize(normal_map, (design.shape[1], design.shape[0]))
27
-
28
- h, w = design.shape[:2]
29
- x, y = np.meshgrid(np.arange(w), np.arange(h))
30
-
31
- # Generate warp offsets from normal map
32
- control_x = x + (normal_map[:, :, 0] - 128) * 0.5
33
- control_y = y + (normal_map[:, :, 1] - 128) * 0.5
34
-
35
- # Apply Radial Basis Function (RBF) interpolation
36
- rbf_x = Rbf(x.flatten(), y.flatten(), control_x.flatten(), function='thin_plate')
37
- rbf_y = Rbf(x.flatten(), y.flatten(), control_y.flatten(), function='thin_plate')
38
-
39
- warped_x = rbf_x(x, y).astype(np.float32)
40
- warped_y = rbf_y(x, y).astype(np.float32)
41
-
42
- # Warp the design
43
- warped_design = cv2.remap(design, warped_x, warped_y, interpolation=cv2.INTER_LINEAR)
44
-
45
- return warped_design
46
-
47
-
48
- # Function to blend design onto the cloth using Poisson Editing
49
- def blend_design_cloth(cloth, design, x=50, y=50):
50
- cloth_bgr = np.array(cloth)
51
- design_bgr = np.array(design)
52
-
53
- normal_map = estimate_normal_map(cloth_bgr)
54
-
55
- # Resize design to fit the center of the cloth
56
- design_resized = cv2.resize(design_bgr, (cloth_bgr.shape[1] // 2, cloth_bgr.shape[0] // 5))
57
 
58
- # Convert to grayscale and create a mask
59
- design_gray = cv2.cvtColor(design_resized, cv2.COLOR_BGR2GRAY)
60
- _, mask = cv2.threshold(design_gray, 1, 255, cv2.THRESH_BINARY)
61
-
62
- # Warp design using normal map
63
- warped_design = apply_tps_warping(design_resized, normal_map)
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Blend using Poisson seamless cloning
66
- center = (x + design_resized.shape[1] // 2, y + design_resized.shape[0] // 2)
67
- blended = cv2.seamlessClone(warped_design, cloth_bgr, mask, center, cv2.MIXED_CLONE)
68
-
69
- return Image.fromarray(blended)
70
-
71
-
72
- # Gradio function
73
- def process_image(cloth_image, design_image, x=50, y=50):
74
- # Blend design onto cloth
75
- result = blend_design_cloth(cloth_image, design_image, x, y)
 
 
 
76
  return result
77
 
78
-
79
- # Gradio Interface
80
- interface = gr.Interface(
81
  fn=process_image,
82
- inputs=[
83
- gr.Image(type="pil", label="Upload Cloth Image"),
84
- gr.Image(type="pil", label="Upload Design"),
85
- gr.Slider(0, 1000, step=10, label="X Coordinate", value=50),
86
- gr.Slider(0, 1000, step=10, label="Y Coordinate", value=50),
87
- ],
88
- outputs=gr.Image(type="pil", label="Blended Output"),
89
- title="Advanced Cloth Design Blending",
90
- description="Upload a cloth image and a design to blend them naturally using advanced warping & Poisson blending.",
91
  )
92
 
93
- # Launch the app
94
  if __name__ == "__main__":
95
- interface.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import transforms, datasets
5
+ from torch.utils.data import DataLoader
6
+ import segmentation_models_pytorch as smp
7
  import cv2
8
  import numpy as np
 
9
  import gradio as gr
10
+ from skimage.transform import warp, PiecewiseAffineTransform
11
+
12
+ # Define U-Net model for cloth fold segmentation
13
+ class ClothFoldUNet(nn.Module):
14
+ def __init__(self):
15
+ super(ClothFoldUNet, self).__init__()
16
+ self.model = smp.Unet(
17
+ encoder_name="resnet34", # Pre-trained backbone
18
+ encoder_weights="imagenet",
19
+ in_channels=3,
20
+ classes=1, # Single channel output for segmentation
21
+ )
22
 
23
+ def forward(self, x):
24
+ return self.model(x)
25
+
26
+ # Load dataset (placeholder, replace with real dataset)
27
+ def get_dataloader(batch_size=8):
28
+ transform = transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.ToTensor(),
31
+ ])
32
+ dataset = datasets.FakeData(transform=transform)
33
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True)
34
+
35
+ # Train function
36
+ def train_model():
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ model = ClothFoldUNet().to(device)
39
+ optimizer = optim.Adam(model.parameters(), lr=1e-4)
40
+ criterion = nn.BCEWithLogitsLoss()
41
+ dataloader = get_dataloader()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ for epoch in range(10): # Placeholder epoch count
44
+ for images, _ in dataloader:
45
+ images = images.to(device)
46
+ optimizer.zero_grad()
47
+ outputs = model(images)
48
+ loss = criterion(outputs, torch.ones_like(outputs)) # Placeholder loss
49
+ loss.backward()
50
+ optimizer.step()
51
+ print(f"Epoch {epoch+1}: Loss {loss.item():.4f}")
52
+
53
+ # Function to apply design onto cloth using segmentation mask
54
+ def apply_design(image, design, mask):
55
+ mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
56
+ design = cv2.resize(design, (image.shape[1], image.shape[0]))
57
+ mask = np.expand_dims(mask, axis=-1)
58
+ blended = (mask * design) + ((1 - mask) * image)
59
+ return blended.astype(np.uint8)
60
 
61
+ # Gradio Interface
62
+ def process_image(image, design):
63
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+ model = ClothFoldUNet().to(device)
65
+ model.eval()
66
+ transform = transforms.Compose([
67
+ transforms.ToPILImage(),
68
+ transforms.Resize((256, 256)),
69
+ transforms.ToTensor(),
70
+ ])
71
+ img_tensor = transform(image).unsqueeze(0).to(device)
72
+ with torch.no_grad():
73
+ mask = model(img_tensor).squeeze().cpu().numpy()
74
+ result = apply_design(np.array(image), np.array(design), mask)
75
  return result
76
 
77
+ iface = gr.Interface(
 
 
78
  fn=process_image,
79
+ inputs=["image", "image"],
80
+ outputs="image",
81
+ title="AI Cloth Design Blending",
82
+ description="Upload a cloth image and a design to blend the design onto the cloth while considering the folds."
 
 
 
 
 
83
  )
84
 
85
+ # Run Gradio app
86
  if __name__ == "__main__":
87
+ train_model()
88
+ iface.launch()