Commit
·
e668e2c
1
Parent(s):
cd097dc
Update mymodels.py
Browse files- mymodels.py +3 -3
mymodels.py
CHANGED
@@ -258,8 +258,8 @@ class Color2Sketch(nn.Module):
|
|
258 |
self.decoder = Decoder()
|
259 |
if pretrained:
|
260 |
print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
|
261 |
-
checkpoint = torch.load('color2edge.pth')
|
262 |
-
self.load_state_dict(checkpoint['netG'], strict=
|
263 |
print("Done!")
|
264 |
else:
|
265 |
self.apply(weights_init)
|
@@ -395,7 +395,7 @@ class Sketch2Color(nn.Module):
|
|
395 |
self.decoder = Decoder()
|
396 |
if pretrained:
|
397 |
print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
|
398 |
-
checkpoint = torch.load('edge2color.pth')
|
399 |
self.load_state_dict(checkpoint['netG'], strict=True)
|
400 |
print("Done!")
|
401 |
else:
|
|
|
258 |
self.decoder = Decoder()
|
259 |
if pretrained:
|
260 |
print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
|
261 |
+
checkpoint = torch.load('color2edge.pth', map_location = "cuda" if torch.cuda.is_available() else "cpu")
|
262 |
+
self.load_state_dict(checkpoint['netG'], strict=True)
|
263 |
print("Done!")
|
264 |
else:
|
265 |
self.apply(weights_init)
|
|
|
395 |
self.decoder = Decoder()
|
396 |
if pretrained:
|
397 |
print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
|
398 |
+
checkpoint = torch.load('edge2color.pth', map_location = "cuda" if torch.cuda.is_available() else "cpu")
|
399 |
self.load_state_dict(checkpoint['netG'], strict=True)
|
400 |
print("Done!")
|
401 |
else:
|