rajatsingh0702 commited on
Commit
e668e2c
·
1 Parent(s): cd097dc

Update mymodels.py

Browse files
Files changed (1) hide show
  1. mymodels.py +3 -3
mymodels.py CHANGED
@@ -258,8 +258,8 @@ class Color2Sketch(nn.Module):
258
  self.decoder = Decoder()
259
  if pretrained:
260
  print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
261
- checkpoint = torch.load('color2edge.pth')
262
- self.load_state_dict(checkpoint['netG'], strict=False)
263
  print("Done!")
264
  else:
265
  self.apply(weights_init)
@@ -395,7 +395,7 @@ class Sketch2Color(nn.Module):
395
  self.decoder = Decoder()
396
  if pretrained:
397
  print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
398
- checkpoint = torch.load('edge2color.pth')
399
  self.load_state_dict(checkpoint['netG'], strict=True)
400
  print("Done!")
401
  else:
 
258
  self.decoder = Decoder()
259
  if pretrained:
260
  print('Loading pretrained {0} model...'.format('Color2Sketch'), end=' ')
261
+ checkpoint = torch.load('color2edge.pth', map_location = "cuda" if torch.cuda.is_available() else "cpu")
262
+ self.load_state_dict(checkpoint['netG'], strict=True)
263
  print("Done!")
264
  else:
265
  self.apply(weights_init)
 
395
  self.decoder = Decoder()
396
  if pretrained:
397
  print('Loading pretrained {0} model...'.format('Sketch2Color'), end=' ')
398
+ checkpoint = torch.load('edge2color.pth', map_location = "cuda" if torch.cuda.is_available() else "cpu")
399
  self.load_state_dict(checkpoint['netG'], strict=True)
400
  print("Done!")
401
  else: