gigant commited on
Commit
fedbc55
·
verified ·
1 Parent(s): 87dbbfa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -59,19 +59,19 @@ print('loading models')
59
  ae_config = OmegaConf.load(ae_config_path)
60
  ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
61
  ae_model.eval().requires_grad_(False).to(device)
62
- ae_model.load_state_dict(torch.load(ae_model_path), weights_only=True)
63
  n_ch, side_y, side_x = 4, 32, 32
64
 
65
  # diffusion model
66
  model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
67
- model.load_state_dict(torch.load(checkpoint, map_location='cpu'), weights_only=True)
68
  model = model.to(device).eval().requires_grad_(False)
69
 
70
  # CLOOB
71
  cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
72
  cloob = model_pt.get_pt_model(cloob_config)
73
  checkpoint = pretrained.download_checkpoint(cloob_config)
74
- cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint), weights_only=True)
75
  cloob.eval().requires_grad_(False).to(device)
76
 
77
 
 
59
  ae_config = OmegaConf.load(ae_config_path)
60
  ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
61
  ae_model.eval().requires_grad_(False).to(device)
62
+ ae_model.load_state_dict(torch.load(ae_model_path, weights_only=True))
63
  n_ch, side_y, side_x = 4, 32, 32
64
 
65
  # diffusion model
66
  model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
67
+ model.load_state_dict(torch.load(checkpoint, map_location='cpu', weights_only=True))
68
  model = model.to(device).eval().requires_grad_(False)
69
 
70
  # CLOOB
71
  cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
72
  cloob = model_pt.get_pt_model(cloob_config)
73
  checkpoint = pretrained.download_checkpoint(cloob_config)
74
+ cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
75
  cloob.eval().requires_grad_(False).to(device)
76
 
77