gaur3009 commited on
Commit
4175fd1
·
verified ·
1 Parent(s): a22d3b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -34,6 +34,7 @@ def segment_dress(image_np):
34
  with torch.no_grad():
35
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
36
 
 
37
  dress_mask = (output > 0.5).astype(np.uint8) * 255
38
  dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_LINEAR)
39
 
@@ -44,12 +45,16 @@ def apply_grabcut(image_np, dress_mask):
44
  bgd_model = np.zeros((1, 65), np.float64)
45
  fgd_model = np.zeros((1, 65), np.float64)
46
 
47
- mask = np.where(dress_mask > 0, cv2.GC_FGD, cv2.GC_BGD).astype('uint8')
48
- rect = (10, 10, image_np.shape[1] - 10, image_np.shape[0] - 10)
 
 
 
 
49
 
50
  cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK)
51
 
52
- refined_mask = np.where((mask == 2) | (mask == 0), 0, 255).astype("uint8")
53
  return refine_mask(refined_mask)
54
 
55
  def recolor_dress(image_np, dress_mask, target_color):
 
34
  with torch.no_grad():
35
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
36
 
37
+ output = (output - output.min()) / (output.max() - output.min() + 1e-8) # Normalize to [0, 1]
38
  dress_mask = (output > 0.5).astype(np.uint8) * 255
39
  dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_LINEAR)
40
 
 
45
  bgd_model = np.zeros((1, 65), np.float64)
46
  fgd_model = np.zeros((1, 65), np.float64)
47
 
48
+ mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype('uint8')
49
+
50
+ # Get bounding box of the mask
51
+ coords = cv2.findNonZero(dress_mask)
52
+ x, y, w, h = cv2.boundingRect(coords)
53
+ rect = (x, y, w, h)
54
 
55
  cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_MASK)
56
 
57
+ refined_mask = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
58
  return refine_mask(refined_mask)
59
 
60
  def recolor_dress(image_np, dress_mask, target_color):