gaur3009 commited on
Commit
1c2f991
·
verified ·
1 Parent(s): 30d0b6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -9,7 +9,15 @@ from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
9
  # Load U²-Net model
10
  model_path = "cloth_segmentation/networks/cloth_segm_u2net_latest.pth" # Ensure this path is correct
11
  model = U2NET(3, 1)
12
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
13
  model.eval()
14
 
15
  def segment_dress(image_np):
 
9
  # Load U²-Net model
10
  model_path = "cloth_segmentation/networks/cloth_segm_u2net_latest.pth" # Ensure this path is correct
11
  model = U2NET(3, 1)
12
+
13
+ # Load the state dictionary
14
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
15
+
16
+ # Remove the 'module.' prefix from the keys
17
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
18
+
19
+ # Load the modified state dictionary into the model
20
+ model.load_state_dict(state_dict)
21
  model.eval()
22
 
23
  def segment_dress(image_np):