RASPIAUDIO commited on
Commit
274613e
·
1 Parent(s): 14dc6ff

french version 1.00

Browse files
Files changed (2) hide show
  1. app.py +1 -5
  2. model/utils.py +2 -2
app.py CHANGED
@@ -31,11 +31,7 @@ def gpu_decorator(func):
31
  else:
32
  return func
33
 
34
- device = (
35
- "cpu"
36
- if torch.cuda.is_available()
37
- else "mps" if torch.backends.mps.is_available() else "cpu"
38
- )
39
 
40
  print(f"Using {device} device")
41
 
 
31
  else:
32
  return func
33
 
34
+ device = device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
35
 
36
  print(f"Using {device} device")
37
 
model/utils.py CHANGED
@@ -572,9 +572,9 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
572
  if ckpt_type == "safetensors":
573
  ema_model.load_state_dict(checkpoint)
574
  else:
575
- ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
  ema_model.copy_params_from_ema_to_model()
577
  else:
578
- model.load_state_dict(checkpoint['model_state_dict'])
579
 
580
  return model
 
572
  if ckpt_type == "safetensors":
573
  ema_model.load_state_dict(checkpoint)
574
  else:
575
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'], strict=False)
576
  ema_model.copy_params_from_ema_to_model()
577
  else:
578
+ model.load_state_dict(checkpoint['model_state_dict'], strict=False)
579
 
580
  return model