shredder-31 commited on
Commit
d7027f7
·
verified ·
1 Parent(s): 35ee781

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -2
main.py CHANGED
@@ -17,8 +17,9 @@ def initialize():
17
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
- checkpoint_path = "checkpoint_epoch_40.pth"
21
- model = ImgCap(cnn_feature_size=1024, lstm_hidden_size=1024, embedding_dim=1024, num_layers=2, vocab_size=len(vocab))
 
22
  model, _, _, _, _, _, _ = load_checkpoint(checkpoint_path=checkpoint_path, model=model)
23
 
24
  return model, vocab, transforms
 
17
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
+ checkpoint_path = "checkpoint_epoch_30.pth"
21
+ model = ImgCap(feature_size=2048, lstm_hidden_size=1024, embedding_dim=1024, num_layers=2, vocab_size=len(vocab))
22
+
23
  model, _, _, _, _, _, _ = load_checkpoint(checkpoint_path=checkpoint_path, model=model)
24
 
25
  return model, vocab, transforms