File size: 1,567 Bytes
de79343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import gdown
import torch

from networks import U2NET
from utils.saving_utils import save_checkpoint

# Ensure the checkpoint directory exists before anything is written into it.
os.makedirs("prev_checkpoints", exist_ok=True)

# Fetch the pretrained U2NET checkpoint from Google Drive; quiet=False keeps
# gdown's download progress output visible.
gdown.download(
    url="https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
    output="./prev_checkpoints/u2net.pth",
    quiet=False,
)

# Build a randomly initialised U2NET (3 input channels, 4 output channels) and
# write it to prev_checkpoints/u2net_random.pth. It is reloaded from that path
# below so its parameters can be matched against the pretrained ones.
u_net = U2NET(in_ch=3, out_ch=4)
save_checkpoint(u_net, os.path.join("prev_checkpoints", "u2net_random.pth"))

# u2net.pth contains the pretrained weights downloaded from Google Drive.
trained_net_pth = os.path.join("prev_checkpoints", "u2net.pth")
# u2net_random.pth contains the randomly initialised out_ch=4 network.
custom_net_pth = os.path.join("prev_checkpoints", "u2net_random.pth")

# map_location="cpu" lets a checkpoint saved from a GPU be loaded on a
# CPU-only machine, and keeps every tensor on one device for the merged
# checkpoint saved later.
net_state_dict = torch.load(trained_net_pth, map_location="cpu")
count = len(net_state_dict)
print("Total number of layers in trained model are: {}".format(count))

custom_state_dict = torch.load(custom_net_pth, map_location="cpu")
count = len(custom_state_dict)
# This count describes the custom (random) network, not the trained one.
print("Total number of layers in custom model are: {}".format(count))

# Copy every pretrained tensor into the custom state dict when the custom
# network has a parameter of the same name AND the same shape. Anything that
# does not match keeps its random initialisation (presumably the output-side
# layers, since the custom net was built with out_ch=4 — TODO confirm).
total_count = len(custom_state_dict)  # layers in the custom network
update_count = 0
for k, v in net_state_dict.items():
    # Skip names the custom network lacks instead of raising KeyError.
    if k in custom_state_dict and custom_state_dict[k].shape == v.shape:
        update_count += 1
        custom_state_dict[k] = v

print(
    "Out of {} layers in custom network, {} layers weights are recovered from trained model".format(
        total_count, update_count
    )
)
surgery_net_pth = os.path.join("prev_checkpoints", "cloth_segm_unet_surgery.pth")
torch.save(custom_state_dict, surgery_net_pth)
print("cloth_segm_unet_surgery.pth is generated in prev_checkpoints directory!")