Spaces:
Runtime error
Runtime error
Commit
·
e668e2c
1
Parent(s):
cd097dc
Update mymodels.py
Browse files- mymodels.py +3 -3
mymodels.py
CHANGED
|
@@ -258,8 +258,8 @@ class Color2Sketch(nn.Module):
|
|
| 258 |
self.decoder = Decoder()
|
| 259 |
if pretrained:
|
| 260 |
print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
|
| 261 |
-
checkpoint = torch.load('color2edge.pth')
|
| 262 |
-
self.load_state_dict(checkpoint['netG'], strict=
|
| 263 |
print("Done!")
|
| 264 |
else:
|
| 265 |
self.apply(weights_init)
|
|
@@ -395,7 +395,7 @@ class Sketch2Color(nn.Module):
|
|
| 395 |
self.decoder = Decoder()
|
| 396 |
if pretrained:
|
| 397 |
print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
|
| 398 |
-
checkpoint = torch.load('edge2color.pth')
|
| 399 |
self.load_state_dict(checkpoint['netG'], strict=True)
|
| 400 |
print("Done!")
|
| 401 |
else:
|
|
|
|
| 258 |
self.decoder = Decoder()
|
| 259 |
if pretrained:
|
| 260 |
print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
|
| 261 |
+
checkpoint = torch.load('color2edge.pth', map_location = "cuda" if torch.cuda.is_available() else "cpu")
|
| 262 |
+
self.load_state_dict(checkpoint['netG'], strict=True)
|
| 263 |
print("Done!")
|
| 264 |
else:
|
| 265 |
self.apply(weights_init)
|
|
|
|
| 395 |
self.decoder = Decoder()
|
| 396 |
if pretrained:
|
| 397 |
print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
|
| 398 |
+
checkpoint = torch.load('edge2color.pth', map_location = "cuda" if torch.cuda.is_available() else "cpu")
|
| 399 |
self.load_state_dict(checkpoint['netG'], strict=True)
|
| 400 |
print("Done!")
|
| 401 |
else:
|