Spaces:
Running
on
Zero
Running
on
Zero
Update inference_scale.py
Browse files- inference_scale.py +3 -6
inference_scale.py
CHANGED
@@ -15,7 +15,7 @@ import time
|
|
15 |
|
16 |
|
17 |
@torch.no_grad()
|
18 |
-
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, prompt_text, target_text, mask_interval, cfg_coef, aug_text, aug_context, use_watermark, tts, device, decode_config):
|
19 |
# phonemize
|
20 |
text_tokens = [phn2num[phn] for phn in
|
21 |
tokenize_text(
|
@@ -54,6 +54,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
|
|
54 |
stop_repetition=decode_config['stop_repetition'],
|
55 |
kvcache=decode_config['kvcache'],
|
56 |
cfg_coef=cfg_coef,
|
|
|
57 |
aug_text=aug_text,
|
58 |
) # output is [1,K,T]
|
59 |
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
|
@@ -69,23 +70,19 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
|
|
69 |
padding_length = (multiple - (current_length % multiple)) % multiple
|
70 |
if padding_length > 0:
|
71 |
wav = F.pad(wav, (0, padding_length), "constant", 0)
|
72 |
-
# new_emb = torch.zeros((1, emb.shape[1], encoded_frames.shape[-1])).to(encoded_frames.device)
|
73 |
new_wav = torch.zeros(1, encoded_frames.shape[-1]*320) # codec hz
|
74 |
|
75 |
ori_non_mask_intervals = [(max(item[0],0), item[1]) for item in ori_masks]
|
76 |
non_mask_intervals = [(max(item[0],0), item[1]) for item in masks]
|
77 |
for i in range(len(ori_non_mask_intervals)):
|
78 |
-
# new_emb[..., non_mask_intervals[i][0]:non_mask_intervals[i][1]] = emb[..., ori_non_mask_intervals[i][0]:ori_non_mask_intervals[i][1]]
|
79 |
new_wav[:, non_mask_intervals[i][0]*320:non_mask_intervals[i][1]*320] = wav[:, ori_non_mask_intervals[i][0]*320:ori_non_mask_intervals[i][1]*320]
|
80 |
|
81 |
-
# generated_sample = audio_tokenizer.wmdecode(encoded_frames, marks.to(encoded_frames.device), new_emb, scale)
|
82 |
generated_sample = audio_tokenizer.wmdecode(encoded_frames, marks.to(encoded_frames.device), new_wav.unsqueeze(0).to(encoded_frames.device), scale)
|
83 |
|
84 |
else:
|
85 |
generated_sample = audio_tokenizer.decode(encoded_frames, scale)
|
86 |
|
87 |
if tts:
|
88 |
-
wav, sr = torchaudio.load(audio_fn)
|
89 |
generated_sample = generated_sample[:,:, masks[0][1]*320:]
|
90 |
|
91 |
return generated_sample
|
@@ -118,4 +115,4 @@ def get_mask_interval(ali_fn, word_span):
|
|
118 |
return (start, end)
|
119 |
|
120 |
if __name__ == "__main__":
|
121 |
-
pass
|
|
|
15 |
|
16 |
|
17 |
@torch.no_grad()
|
18 |
+
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, prompt_text, target_text, mask_interval, cfg_coef, cfg_stride, aug_text, aug_context, use_watermark, tts, device, decode_config):
|
19 |
# phonemize
|
20 |
text_tokens = [phn2num[phn] for phn in
|
21 |
tokenize_text(
|
|
|
54 |
stop_repetition=decode_config['stop_repetition'],
|
55 |
kvcache=decode_config['kvcache'],
|
56 |
cfg_coef=cfg_coef,
|
57 |
+
cfg_stride=cfg_stride,
|
58 |
aug_text=aug_text,
|
59 |
) # output is [1,K,T]
|
60 |
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
|
|
|
70 |
padding_length = (multiple - (current_length % multiple)) % multiple
|
71 |
if padding_length > 0:
|
72 |
wav = F.pad(wav, (0, padding_length), "constant", 0)
|
|
|
73 |
new_wav = torch.zeros(1, encoded_frames.shape[-1]*320) # codec hz
|
74 |
|
75 |
ori_non_mask_intervals = [(max(item[0],0), item[1]) for item in ori_masks]
|
76 |
non_mask_intervals = [(max(item[0],0), item[1]) for item in masks]
|
77 |
for i in range(len(ori_non_mask_intervals)):
|
|
|
78 |
new_wav[:, non_mask_intervals[i][0]*320:non_mask_intervals[i][1]*320] = wav[:, ori_non_mask_intervals[i][0]*320:ori_non_mask_intervals[i][1]*320]
|
79 |
|
|
|
80 |
generated_sample = audio_tokenizer.wmdecode(encoded_frames, marks.to(encoded_frames.device), new_wav.unsqueeze(0).to(encoded_frames.device), scale)
|
81 |
|
82 |
else:
|
83 |
generated_sample = audio_tokenizer.decode(encoded_frames, scale)
|
84 |
|
85 |
if tts:
|
|
|
86 |
generated_sample = generated_sample[:,:, masks[0][1]*320:]
|
87 |
|
88 |
return generated_sample
|
|
|
115 |
return (start, end)
|
116 |
|
117 |
if __name__ == "__main__":
|
118 |
+
pass
|