OpenSound committed on
Commit
6d01598
·
1 Parent(s): 2c71769

Update inference_scale.py

Browse files
Files changed (1) hide show
  1. inference_scale.py +3 -6
inference_scale.py CHANGED
@@ -15,7 +15,7 @@ import time
15
 
16
 
17
  @torch.no_grad()
18
- def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, prompt_text, target_text, mask_interval, cfg_coef, aug_text, aug_context, use_watermark, tts, device, decode_config):
19
  # phonemize
20
  text_tokens = [phn2num[phn] for phn in
21
  tokenize_text(
@@ -54,6 +54,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
54
  stop_repetition=decode_config['stop_repetition'],
55
  kvcache=decode_config['kvcache'],
56
  cfg_coef=cfg_coef,
 
57
  aug_text=aug_text,
58
  ) # output is [1,K,T]
59
  logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
@@ -69,23 +70,19 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
69
  padding_length = (multiple - (current_length % multiple)) % multiple
70
  if padding_length > 0:
71
  wav = F.pad(wav, (0, padding_length), "constant", 0)
72
- # new_emb = torch.zeros((1, emb.shape[1], encoded_frames.shape[-1])).to(encoded_frames.device)
73
  new_wav = torch.zeros(1, encoded_frames.shape[-1]*320) # codec hz
74
 
75
  ori_non_mask_intervals = [(max(item[0],0), item[1]) for item in ori_masks]
76
  non_mask_intervals = [(max(item[0],0), item[1]) for item in masks]
77
  for i in range(len(ori_non_mask_intervals)):
78
- # new_emb[..., non_mask_intervals[i][0]:non_mask_intervals[i][1]] = emb[..., ori_non_mask_intervals[i][0]:ori_non_mask_intervals[i][1]]
79
  new_wav[:, non_mask_intervals[i][0]*320:non_mask_intervals[i][1]*320] = wav[:, ori_non_mask_intervals[i][0]*320:ori_non_mask_intervals[i][1]*320]
80
 
81
- # generated_sample = audio_tokenizer.wmdecode(encoded_frames, marks.to(encoded_frames.device), new_emb, scale)
82
  generated_sample = audio_tokenizer.wmdecode(encoded_frames, marks.to(encoded_frames.device), new_wav.unsqueeze(0).to(encoded_frames.device), scale)
83
 
84
  else:
85
  generated_sample = audio_tokenizer.decode(encoded_frames, scale)
86
 
87
  if tts:
88
- wav, sr = torchaudio.load(audio_fn)
89
  generated_sample = generated_sample[:,:, masks[0][1]*320:]
90
 
91
  return generated_sample
@@ -118,4 +115,4 @@ def get_mask_interval(ali_fn, word_span):
118
  return (start, end)
119
 
120
  if __name__ == "__main__":
121
- pass
 
15
 
16
 
17
  @torch.no_grad()
18
+ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, prompt_text, target_text, mask_interval, cfg_coef, cfg_stride, aug_text, aug_context, use_watermark, tts, device, decode_config):
19
  # phonemize
20
  text_tokens = [phn2num[phn] for phn in
21
  tokenize_text(
 
54
  stop_repetition=decode_config['stop_repetition'],
55
  kvcache=decode_config['kvcache'],
56
  cfg_coef=cfg_coef,
57
+ cfg_stride=cfg_stride,
58
  aug_text=aug_text,
59
  ) # output is [1,K,T]
60
  logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
 
70
  padding_length = (multiple - (current_length % multiple)) % multiple
71
  if padding_length > 0:
72
  wav = F.pad(wav, (0, padding_length), "constant", 0)
 
73
  new_wav = torch.zeros(1, encoded_frames.shape[-1]*320) # codec hz
74
 
75
  ori_non_mask_intervals = [(max(item[0],0), item[1]) for item in ori_masks]
76
  non_mask_intervals = [(max(item[0],0), item[1]) for item in masks]
77
  for i in range(len(ori_non_mask_intervals)):
 
78
  new_wav[:, non_mask_intervals[i][0]*320:non_mask_intervals[i][1]*320] = wav[:, ori_non_mask_intervals[i][0]*320:ori_non_mask_intervals[i][1]*320]
79
 
 
80
  generated_sample = audio_tokenizer.wmdecode(encoded_frames, marks.to(encoded_frames.device), new_wav.unsqueeze(0).to(encoded_frames.device), scale)
81
 
82
  else:
83
  generated_sample = audio_tokenizer.decode(encoded_frames, scale)
84
 
85
  if tts:
 
86
  generated_sample = generated_sample[:,:, masks[0][1]*320:]
87
 
88
  return generated_sample
 
115
  return (start, end)
116
 
117
  if __name__ == "__main__":
118
+ pass