|
import collections, torch, torchvision, numpy |
|
|
|
def load_places_vgg16(weight_file): |
|
model = torchvision.models.vgg16(num_classes=365) |
|
model.features = torch.nn.Sequential(collections.OrderedDict(zip([ |
|
'conv1_1', 'relu1_1', |
|
'conv1_2', 'relu1_2', |
|
'pool1', |
|
'conv2_1', 'relu2_1', |
|
'conv2_2', 'relu2_2', |
|
'pool2', |
|
'conv3_1', 'relu3_1', |
|
'conv3_2', 'relu3_2', |
|
'conv3_3', 'relu3_3', |
|
'pool3', |
|
'conv4_1', 'relu4_1', |
|
'conv4_2', 'relu4_2', |
|
'conv4_3', 'relu4_3', |
|
'pool4', |
|
'conv5_1', 'relu5_1', |
|
'conv5_2', 'relu5_2', |
|
'conv5_3', 'relu5_3', |
|
'pool5'], |
|
model.features))) |
|
|
|
model.classifier = torch.nn.Sequential(collections.OrderedDict(zip([ |
|
'fc6', 'relu6', |
|
'drop6', |
|
'fc7', 'relu7', |
|
'drop7', |
|
'fc8a'], |
|
model.classifier))) |
|
|
|
state_dict = torch.load(weight_file) |
|
converted_state_dict = ({ |
|
l: torch.from_numpy(numpy.array(v)).view_as(p) |
|
for k, v in state_dict.items() |
|
for l, p in model.named_parameters() if k in l}) |
|
model.load_state_dict(converted_state_dict) |
|
|
|
|
|
|
|
return model |
|
|