jadechoghari HF Staff commited on
Commit
2e07a37
·
verified ·
1 Parent(s): f792798

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +0 -6
modeling.py CHANGED
@@ -3,10 +3,6 @@ import torchaudio
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel
5
  import torch
6
- # download repo
7
- from huggingface_hub import snapshot_download
8
- repo_dir = snapshot_download(repo_id="jadechoghari/VoiceRestore", repo_type="model", local_dir="VoiceRestoreRepo")
9
- sys.path.append(repo_dir)
10
  from BigVGAN import bigvgan
11
  from BigVGAN.meldataset import get_mel_spectrogram
12
  from voice_restore import VoiceRestore
@@ -52,10 +48,8 @@ class VoiceRestore(PreTrainedModel):
52
  self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
53
  save_path = "/content/voicerestore/checkpoints/voice-restore-20d-16h-optim.pt"
54
  state_dict = torch.load(save_path, map_location=torch.device(device))
55
- print("loaded")
56
  if 'model_state_dict' in state_dict:
57
  state_dict = state_dict['model_state_dict']
58
- print("change keys")
59
 
60
  self.optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
61
  self.optimized_model.eval()
 
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel
5
  import torch
 
 
 
 
6
  from BigVGAN import bigvgan
7
  from BigVGAN.meldataset import get_mel_spectrogram
8
  from voice_restore import VoiceRestore
 
48
  self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)
49
  save_path = "/content/voicerestore/checkpoints/voice-restore-20d-16h-optim.pt"
50
  state_dict = torch.load(save_path, map_location=torch.device(device))
 
51
  if 'model_state_dict' in state_dict:
52
  state_dict = state_dict['model_state_dict']
 
53
 
54
  self.optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
55
  self.optimized_model.eval()