vindemiatrix commited on
Commit
0d5a203
·
verified ·
1 Parent(s): 0286d01

Update vits_inference.py

Browse files
Files changed (1) hide show
  1. vits_inference.py +16 -29
vits_inference.py CHANGED
@@ -1,35 +1,22 @@
1
- import torch
2
- import torchaudio
3
- from vits import SynthesizerTrn
4
- from mel_processing import spectrogram_torch
5
  import argparse
 
 
 
6
 
7
- def load_model(model_path, config_path):
8
- hps = torch.load(config_path)
9
- model = SynthesizerTrn(
10
- hps.data.filter_length // 2 + 1,
11
- hps.train.segment_size // hps.data.hop_length,
12
- **hps.model).cuda()
13
-
14
- model.load_state_dict(torch.load(model_path, map_location="cuda"))
15
- model.eval()
16
- return model, hps
17
 
18
- def synthesize_text(text, model, hps):
19
- text = torch.LongTensor([ord(c) for c in text]).unsqueeze(0).cuda()
20
- with torch.no_grad():
21
- audio = model.infer(text, noise_scale=0.667, length_scale=1.0)[0]
22
- return audio.cpu().numpy()
23
 
24
- def main():
25
- parser = argparse.ArgumentParser()
26
- parser.add_argument("--text", type=str, required=True, help="Texto a ser dublado")
27
- parser.add_argument("--output_audio", type=str, required=True, help="Arquivo de saída")
28
- args = parser.parse_args()
29
 
30
- model, hps = load_model("vits_model.pth", "vits_config.pth")
31
- audio_data = synthesize_text(args.text, model, hps)
32
- torchaudio.save(args.output_audio, torch.tensor(audio_data), sample_rate=22050)
33
 
34
- if __name__ == "__main__":
35
- main()
 
 
 
 
 
1
  import argparse
2
+ from vits import SynthesizerTrn
3
+ import torchaudio
4
+ import os
5
 
6
+ # Argumentos
7
+ parser = argparse.ArgumentParser(description="Dublagem com VITS")
8
+ parser.add_argument("--text", type=str, required=True, help="Texto para dublagem")
9
+ parser.add_argument("--input_audio", type=str, required=True, help="Áudio original para clonagem de voz")
10
+ parser.add_argument("--output_audio", type=str, required=True, help="Áudio de saída dublado")
11
+ parser.add_argument("--language", type=str, required=True, help="Idioma da dublagem")
 
 
 
 
12
 
13
+ args = parser.parse_args()
 
 
 
 
14
 
15
+ # Carregar modelo VITS
16
+ model = SynthesizerTrn(args.language)
 
 
 
17
 
18
+ # Processar dublagem
19
+ waveform, sample_rate = model.synthesize(args.text, args.input_audio)
20
+ torchaudio.save(args.output_audio, waveform, sample_rate)
21
 
22
+ print(f"✅ Dublagem concluída: {args.output_audio}")