ajayarora1235 commited on
Commit
f9f05d9
·
1 Parent(s): 1cc40d5

new hubert method

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +1 -1
  3. hubert.pth +3 -0
  4. vc_infer_pipeline.py +3 -2
.gitattributes CHANGED
@@ -2,3 +2,4 @@ ilariasuitewallpaper.jpg filter=lfs diff=lfs merge=lfs -text
2
  ilariaaisuite.png filter=lfs diff=lfs merge=lfs -text
3
  pretrained_models/giga330M.pth filter=lfs diff=lfs merge=lfs -text
4
  pretrained_models/encodec_4cb2048_giga.th filter=lfs diff=lfs merge=lfs -text
 
 
2
  ilariaaisuite.png filter=lfs diff=lfs merge=lfs -text
3
  pretrained_models/giga330M.pth filter=lfs diff=lfs merge=lfs -text
4
  pretrained_models/encodec_4cb2048_giga.th filter=lfs diff=lfs merge=lfs -text
5
+ hubert.pth filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -246,7 +246,7 @@ associated_links = {}
246
  def load_hubert():
247
  global hubert_model
248
  # Load the model
249
- hubert_model = torch.load("hubert_base.pt", map_location=config.device)
250
 
251
  # Prepare the model
252
  hubert_model = hubert_model.to(config.device)
 
246
  def load_hubert():
247
  global hubert_model
248
  # Load the model
249
+ hubert_model = torch.load("hubert_base.pth", map_location=config.device)
250
 
251
  # Prepare the model
252
  hubert_model = hubert_model.to(config.device)
hubert.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e579cfcfb99bfca12e392d89854f7ed722ebb08c74daa8d54b4b4165436e8f7
3
+ size 377560373
vc_infer_pipeline.py CHANGED
@@ -396,8 +396,9 @@ class VC(object):
396
  }
397
  t0 = ttime()
398
  with torch.no_grad():
399
- logits = model.extract_features(**inputs)
400
- feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
 
401
  if protect < 0.5 and pitch != None and pitchf != None:
402
  feats0 = feats.clone()
403
  if (
 
396
  }
397
  t0 = ttime()
398
  with torch.no_grad():
399
+ feats = model(inputs["source"])["last_hidden_state"]
400
+ # logits = model.extract_features(**inputs)
401
+ # feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
402
  if protect < 0.5 and pitch != None and pitchf != None:
403
  feats0 = feats.clone()
404
  if (