Arnab Das commited on
Commit
948bfd2
·
1 Parent(s): 3f96ef5
manipulate_model/model.py CHANGED
@@ -1,11 +1,12 @@
1
  import torch
2
  import torch.nn as nn
 
3
 
4
  from manipulate_model.encoder.encoder import Encoder
5
  from manipulate_model.decoder.decoder import Decoder
6
 
7
 
8
- class Model(nn.Module):
9
  def __init__(self, config):
10
  super(Model, self).__init__()
11
  self.config = config
 
1
  import torch
2
  import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
 
5
  from manipulate_model.encoder.encoder import Encoder
6
  from manipulate_model.decoder.decoder import Decoder
7
 
8
 
9
+ class Model(nn.Module, PyTorchModelHubMixin):
10
  def __init__(self, config):
11
  super(Model, self).__init__()
12
  self.config = config
manipulate_model/utils.py CHANGED
@@ -17,7 +17,7 @@ def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
17
  if isinstance(config.model.decoder, str):
18
  config.model.decoder = OmegaConf.load(config.model.decoder)
19
 
20
- model = Model(config)
21
  #weights = torch.load(os.path.join(model_root, "weights.pt"))
22
  #model.load_state_dict(weights["model_state_dict"])
23
 
 
17
  if isinstance(config.model.decoder, str):
18
  config.model.decoder = OmegaConf.load(config.model.decoder)
19
 
20
+ model = Model.from_pretrained("arnabdas8901/manipulation_detection_transformer", config)
21
  #weights = torch.load(os.path.join(model_root, "weights.pt"))
22
  #model.load_state_dict(weights["model_state_dict"])
23