Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- cosyvoice/bin/average_model.py +92 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/frontend.py +211 -0
- cosyvoice/dataset/processor.py +435 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +0 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/embedding.py +294 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +318 -0
- cosyvoice/utils/executor.py +172 -0
- cosyvoice/utils/file_utils.py +89 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/mask.py +267 -0
- cosyvoice/utils/train_utils.py +345 -0
- docker/Dockerfile +51 -0
- examples/libritts/cosyvoice/conf/cosyvoice.yaml +257 -0
- examples/libritts/cosyvoice/local/download_and_untar.sh +97 -0
- examples/libritts/cosyvoice/run.sh +126 -0
- examples/magicdata-read/cosyvoice/conf/ds_stage2.json +42 -0
- examples/magicdata-read/cosyvoice/local/download_and_untar.sh +97 -0
- examples/magicdata-read/cosyvoice/local/prepare_data.py +52 -0
- examples/magicdata-read/cosyvoice/path.sh +3 -0
- examples/magicdata-read/cosyvoice/run.sh +111 -0
- examples/magicdata-read/cosyvoice/tts_text.json +18 -0
- runtime/python/fastapi/server.py +101 -0
- runtime/python/grpc/client.py +106 -0
- runtime/python/grpc/server.py +96 -0
- third_party/Matcha-TTS/.env.example +6 -0
- third_party/Matcha-TTS/.github/codecov.yml +15 -0
- third_party/Matcha-TTS/.gitignore +163 -0
- third_party/Matcha-TTS/.pre-commit-config.yaml +59 -0
- third_party/Matcha-TTS/.project-root +2 -0
- third_party/Matcha-TTS/Makefile +42 -0
- third_party/Matcha-TTS/configs/__init__.py +1 -0
- third_party/Matcha-TTS/configs/callbacks/default.yaml +5 -0
- third_party/Matcha-TTS/configs/callbacks/model_summary.yaml +5 -0
- third_party/Matcha-TTS/configs/debug/fdr.yaml +9 -0
- third_party/Matcha-TTS/configs/debug/limit.yaml +12 -0
- third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml +14 -0
- third_party/Matcha-TTS/configs/hydra/default.yaml +19 -0
- third_party/Matcha-TTS/configs/logger/comet.yaml +12 -0
- third_party/Matcha-TTS/configs/logger/many_loggers.yaml +9 -0
- third_party/Matcha-TTS/configs/logger/mlflow.yaml +12 -0
- third_party/Matcha-TTS/configs/logger/neptune.yaml +9 -0
cosyvoice/bin/average_model.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import glob

import yaml
import torch


def get_args():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    parser.add_argument('--val_best',
                        action="store_true",
                        help='averaged model')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')

    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    val_scores = []
    if args.val_best:
        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
        yamls = [
            f for f in yamls
            if not (os.path.basename(f).startswith('train')
                    or os.path.basename(f).startswith('init'))
        ]
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
                loss = float(dic_yaml['loss_dict']['loss'])
                epoch = int(dic_yaml['epoch'])
                step = int(dic_yaml['step'])
                tag = dic_yaml['tag']
                val_scores += [[epoch, step, loss, tag]]
        sorted_val_scores = sorted(val_scores,
                                   key=lambda x: x[2],
                                   reverse=False)
        print("best val (epoch, step, loss, tag) = " +
              str(sorted_val_scores[:args.num]))
        path_list = [
            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
            for score in sorted_val_scores[:args.num]
        ]
        print(path_list)
    avg = {}
    num = args.num
    assert num == len(path_list)
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            if k not in avg.keys():
                avg[k] = states[k].clone()
            else:
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
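
The averaging loop above simply sums each parameter tensor over the selected checkpoints and divides by the count. A toy illustration of that key-by-key averaging, using two stand-in state dicts instead of real checkpoints:

# Toy illustration of the averaging performed in main(); the state dicts
# here are fabricated stand-ins, not real model checkpoints.
import torch

states_a = {'w': torch.tensor([1.0, 2.0]), 'b': torch.tensor([0.0])}
states_b = {'w': torch.tensor([3.0, 4.0]), 'b': torch.tensor([2.0])}

avg = {}
for states in (states_a, states_b):
    for k, v in states.items():
        avg[k] = v.clone() if k not in avg else avg[k] + v
for k in avg:
    avg[k] = torch.true_divide(avg[k], 2)
print(avg)  # {'w': tensor([2., 3.]), 'b': tensor([1.])}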
cosyvoice/cli/__init__.py
ADDED
File without changes
cosyvoice/cli/frontend.py
ADDED
@@ -0,0 +1,211 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Generator
import json
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    print("failed to import ttsfrd, use WeTextProcessing instead")
    from tn.chinese.normalizer import Normalizer as ZhNormalizer
    from tn.english.normalizer import Normalizer as EnNormalizer
    use_ttsfrd = False
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


class CosyVoiceFrontEnd:

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 allowed_special: str = 'all'):
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
            self.en_tn_model = EnNormalizer()
            self.inflect_parser = inflect.engine()

    def _extract_text_token(self, text):
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will return _extract_text_token_generator!')
            # NOTE add a dummy text_token_len for compatibility
            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
        else:
            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
            return text_token, text_token_len

    def _extract_text_token_generator(self, text_generator):
        for text in text_generator:
            text_token, _ = self._extract_text_token(text)
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

    def _extract_speech_token(self, speech):
        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True, text_frontend=True):
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will skip text_normalize!')
            return [text]
        if text_frontend is False:
            return [text] if split is True else text
        text = text.strip()
        if self.use_ttsfrd:
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)
                text = text.replace(".", "。")
                text = text.replace(" - ", ",")
                text = remove_bracket(text)
                text = re.sub(r'[,,、]+$', '。', text)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
        texts = [i for i in texts if not is_only_punctuation(i)]
        return texts if split is True else text

    def frontend_sft(self, tts_text, spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        if resample_rate == 24000:
            # cosyvoice2, force speech_feat % speech_token = 2
            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                       'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input
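
A construction sketch for this frontend follows. The ONNX paths, the spk2info path, and the mel extractor are placeholders (the recipe wires the real feat_extractor and model paths through its YAML config), so treat it as illustrative rather than the project's documented entry point.

# Hedged sketch: every path below is hypothetical and the MelSpectrogram is a
# stand-in for whatever feat_extractor the recipe configures.
from functools import partial
import torchaudio
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.tokenizer.tokenizer import get_tokenizer

feat_extractor = torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_mels=80)  # stand-in mel extractor
frontend = CosyVoiceFrontEnd(
    get_tokenizer=partial(get_tokenizer, multilingual=True),
    feat_extractor=feat_extractor,
    campplus_model='pretrained_models/campplus.onnx',                  # hypothetical path
    speech_tokenizer_model='pretrained_models/speech_tokenizer.onnx',  # hypothetical path
    spk2info='pretrained_models/spk2info.pt')                          # hypothetical path

prompt_speech_16k, _ = torchaudio.load('prompt.wav')  # expected: 16 kHz mono, shape (1, samples)
for text in frontend.text_normalize('Hello, this is a test.', split=True):
    model_input = frontend.frontend_zero_shot(text, 'prompt transcript', prompt_speech_16k, resample_rate=22050)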
cosyvoice/dataset/processor.py
ADDED
@@ -0,0 +1,435 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random

import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw


AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train', tts_data={}):
    """ Give url or local file, return file descriptor
        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, stream}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['wav'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        yield sample


def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """

    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The sample left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))


def padding(data, use_spk_embedding, mode='train', gan=False):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
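
These stages are generators that each consume and re-yield sample dicts, so a training pipeline is just their composition. A rough sketch of that chaining; the stage order and numeric arguments here are illustrative, with the real values living in the recipe config (e.g. examples/libritts/cosyvoice/conf/cosyvoice.yaml):

# Hedged composition sketch; get_tokenizer and feat_extractor are whatever
# callables the recipe configures, and all numbers are placeholders.
import cosyvoice.dataset.processor as processor

def build_pipeline(parquet_list, get_tokenizer, feat_extractor):
    data = [{'src': url} for url in parquet_list]          # one dict per parquet shard
    data = processor.parquet_opener(data)                  # parquet rows -> sample dicts
    data = processor.tokenize(data, get_tokenizer, allowed_special='all')
    data = processor.filter(data, max_length=40960, min_length=10)
    data = processor.resample(data, resample_rate=22050)
    data = processor.compute_fbank(data, feat_extractor=feat_extractor)
    data = processor.parse_embedding(data, normalize=True)
    data = processor.shuffle(data, shuffle_size=1000)
    data = processor.sort(data, sort_size=500)
    data = processor.batch(data, batch_type='dynamic', max_frames_in_batch=12000)
    data = processor.padding(data, use_spk_embedding=False)
    return data                                            # iterator of collated batches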
cosyvoice/hifigan/hifigan.py
ADDED
@@ -0,0 +1,67 @@
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
from cosyvoice.utils.losses import tpr_loss, mel_loss


class HiFiGan(nn.Module):
    def __init__(self, generator, discriminator, mel_spec_transform,
                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
        super(HiFiGan, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.mel_spec_transform = mel_spec_transform
        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
        self.feat_match_loss_weight = feat_match_loss_weight
        self.tpr_loss_weight = tpr_loss_weight
        self.tpr_loss_tau = tpr_loss_tau

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        if batch['turn'] == 'generator':
            return self.forward_generator(batch, device)
        else:
            return self.forward_discriminator(batch, device)

    def forward_generator(self, batch, device):
        real_speech = batch['speech'].to(device)
        pitch_feat = batch['pitch_feat'].to(device)
        # 1. calculate generator outputs
        generated_speech, generated_f0 = self.generator(batch, device)
        # 2. calculate discriminator outputs
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
        # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
        loss_gen, _ = generator_loss(y_d_gs)
        loss_fm = feature_loss(fmap_rs, fmap_gs)
        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
        if self.tpr_loss_weight != 0:
            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
        else:
            loss_tpr = torch.zeros(1).to(device)
        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
        loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
            self.multi_mel_spectral_recon_loss_weight * loss_mel + \
            self.tpr_loss_weight * loss_tpr + loss_f0
        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}

    def forward_discriminator(self, batch, device):
        real_speech = batch['speech'].to(device)
        # 1. calculate generator outputs
        with torch.no_grad():
            generated_speech, generated_f0 = self.generator(batch, device)
        # 2. calculate discriminator outputs
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
        # 3. calculate discriminator losses, tpr losses [Optional]
        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
        if self.tpr_loss_weight != 0:
            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
        else:
            loss_tpr = torch.zeros(1).to(device)
        loss = loss_disc + self.tpr_loss_weight * loss_tpr
        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
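
The 'turn' key decides which loss branch runs, so a training loop alternates two passes per batch. A minimal sketch of that alternation, assuming separate generator and discriminator optimizers (the repo's actual loop presumably lives in cosyvoice/utils/executor.py, which is part of this upload but not shown here):

# Hedged sketch of the alternating update implied by batch['turn'];
# hifigan, batch and the optimizers are placeholders supplied by the caller.
def gan_train_step(hifigan, batch, device, optim_g, optim_d):
    # discriminator turn: the generator runs under no_grad inside forward_discriminator
    batch['turn'] = 'discriminator'
    d_info = hifigan(batch, device)
    optim_d.zero_grad()
    d_info['loss'].backward()
    optim_d.step()

    # generator turn: adversarial + feature-matching + mel + f0 losses
    batch['turn'] = 'generator'
    g_info = hifigan(batch, device)
    optim_g.zero_grad()
    g_info['loss'].backward()
    optim_g.step()
    return d_info, g_info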
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
ADDED
The diff for this file is too large to render.
cosyvoice/tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,279 @@
import base64
import os
from functools import lru_cache
from typing import Optional
import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
}


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )


class QwenTokenizer():
    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()
        # NOTE: non-chat model, all these special tokens keep randomly initialized.
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
                "<laughter>", "</laughter>",
                "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]"
            ]
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
        return text


@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool
) -> QwenTokenizer:
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
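
A quick usage sketch for the two tokenizer entry points; the Qwen tokenizer directory is a placeholder for wherever the pretrained tokenizer files were downloaded, so this is illustrative only.

# Hedged sketch; the tokenizer directory path below is hypothetical.
from cosyvoice.tokenizer.tokenizer import get_tokenizer, get_qwen_tokenizer

whisper_style = get_tokenizer(multilingual=True)                 # tiktoken-backed whisper Tokenizer
ids = whisper_style.encode('你好, hello', allowed_special='all')
print(len(ids))

qwen = get_qwen_tokenizer('path/to/qwen_tokenizer_dir', skip_special_tokens=True)  # hypothetical path
print(qwen.decode(qwen.encode('hello [laughter]')))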
cosyvoice/transformer/activation.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
#               2020 Northwestern Polytechnical University (Pengcheng Guo)
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
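
A small shape check for the Snake activation on a (batch, channels, time) tensor; the sizes below are arbitrary.

# Minimal sketch: Snake keeps the (B, C, T) shape of its input.
import torch
from cosyvoice.transformer.activation import Snake

act = Snake(in_features=8, alpha_logscale=True)
x = torch.randn(2, 8, 100)   # (B, C, T)
y = act(x)
print(y.shape)               # torch.Size([2, 8, 100])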
cosyvoice/transformer/embedding.py
ADDED
@@ -0,0 +1,294 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positional Encoding Module."""

import math
from typing import Tuple, Union

import torch
import torch.nn.functional as F
import numpy as np


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        self.pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = self.pe.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: for compatibility to RelPositionalEncoding
        """

        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a
        non-streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        #   https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
pos_emb = self.pe[:, offset:offset + size]
|
103 |
+
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
|
104 |
+
assert offset + size <= self.max_len
|
105 |
+
pos_emb = self.pe[:, offset:offset + size]
|
106 |
+
else: # for batched streaming decoding on GPU
|
107 |
+
assert torch.max(offset) + size <= self.max_len
|
108 |
+
index = offset.unsqueeze(1) + \
|
109 |
+
torch.arange(0, size).to(offset.device) # B X T
|
110 |
+
flag = index > 0
|
111 |
+
# remove negative offset
|
112 |
+
index = index * flag
|
113 |
+
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
|
114 |
+
|
115 |
+
if apply_dropout:
|
116 |
+
pos_emb = self.dropout(pos_emb)
|
117 |
+
return pos_emb
|
118 |
+
|
119 |
+
|
120 |
+
class RelPositionalEncoding(PositionalEncoding):
|
121 |
+
"""Relative positional encoding module.
|
122 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
123 |
+
Args:
|
124 |
+
d_model (int): Embedding dimension.
|
125 |
+
dropout_rate (float): Dropout rate.
|
126 |
+
max_len (int): Maximum input length.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
130 |
+
"""Initialize class."""
|
131 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
132 |
+
|
133 |
+
def forward(self,
|
134 |
+
x: torch.Tensor,
|
135 |
+
offset: Union[int, torch.Tensor] = 0) \
|
136 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
137 |
+
"""Compute positional encoding.
|
138 |
+
Args:
|
139 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
140 |
+
Returns:
|
141 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
142 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
143 |
+
"""
|
144 |
+
self.pe = self.pe.to(x.device)
|
145 |
+
x = x * self.xscale
|
146 |
+
pos_emb = self.position_encoding(offset, x.size(1), False)
|
147 |
+
return self.dropout(x), self.dropout(pos_emb)
|
148 |
+
|
149 |
+
|
150 |
+
class WhisperPositionalEncoding(PositionalEncoding):
|
151 |
+
""" Sinusoids position encoding used in openai-whisper.encoder
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
|
155 |
+
super().__init__(d_model, dropout_rate, max_len)
|
156 |
+
self.xscale = 1.0
|
157 |
+
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
|
158 |
+
inv_timescales = torch.exp(-log_timescale_increment *
|
159 |
+
torch.arange(d_model // 2))
|
160 |
+
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
|
161 |
+
inv_timescales[np.newaxis, :]
|
162 |
+
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
163 |
+
delattr(self, "pe")
|
164 |
+
self.register_buffer("pe", pe.unsqueeze(0))
|
165 |
+
|
166 |
+
|
167 |
+
class LearnablePositionalEncoding(PositionalEncoding):
|
168 |
+
""" Learnable position encoding used in openai-whisper.decoder
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
|
172 |
+
super().__init__(d_model, dropout_rate, max_len)
|
173 |
+
# NOTE(xcsong): overwrite self.pe & self.xscale
|
174 |
+
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
|
175 |
+
self.xscale = 1.0
|
176 |
+
|
177 |
+
|
178 |
+
class NoPositionalEncoding(torch.nn.Module):
|
179 |
+
""" No position encoding
|
180 |
+
"""
|
181 |
+
|
182 |
+
def __init__(self, d_model: int, dropout_rate: float):
|
183 |
+
super().__init__()
|
184 |
+
self.d_model = d_model
|
185 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
186 |
+
|
187 |
+
def forward(self,
|
188 |
+
x: torch.Tensor,
|
189 |
+
offset: Union[int, torch.Tensor] = 0) \
|
190 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
191 |
+
""" Just return zero vector for interface compatibility
|
192 |
+
"""
|
193 |
+
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
|
194 |
+
return self.dropout(x), pos_emb
|
195 |
+
|
196 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
197 |
+
size: int) -> torch.Tensor:
|
198 |
+
return torch.zeros(1, size, self.d_model)
|
199 |
+
|
200 |
+
|
201 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
202 |
+
"""Relative positional encoding module (new implementation).
|
203 |
+
|
204 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
205 |
+
|
206 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
207 |
+
|
208 |
+
Args:
|
209 |
+
d_model (int): Embedding dimension.
|
210 |
+
dropout_rate (float): Dropout rate.
|
211 |
+
max_len (int): Maximum input length.
|
212 |
+
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
216 |
+
"""Construct an PositionalEncoding object."""
|
217 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
218 |
+
self.d_model = d_model
|
219 |
+
self.xscale = math.sqrt(self.d_model)
|
220 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
221 |
+
self.pe = None
|
222 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
223 |
+
|
224 |
+
def extend_pe(self, x: torch.Tensor):
|
225 |
+
"""Reset the positional encodings."""
|
226 |
+
if self.pe is not None:
|
227 |
+
# self.pe contains both positive and negative parts
|
228 |
+
# the length of self.pe is 2 * input_len - 1
|
229 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
230 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
231 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
232 |
+
return
|
233 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
234 |
+
# position of key vector. We use position relative positions when keys
|
235 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
236 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
237 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
238 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
239 |
+
div_term = torch.exp(
|
240 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
241 |
+
* -(math.log(10000.0) / self.d_model)
|
242 |
+
)
|
243 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
244 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
245 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
246 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
247 |
+
|
248 |
+
# Reserve the order of positive indices and concat both positive and
|
249 |
+
# negative indices. This is used to support the shifting trick
|
250 |
+
# as in https://arxiv.org/abs/1901.02860
|
251 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
252 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
253 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
254 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
255 |
+
|
256 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
257 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
258 |
+
"""Add positional encoding.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
265 |
+
|
266 |
+
"""
|
267 |
+
self.extend_pe(x)
|
268 |
+
x = x * self.xscale
|
269 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
270 |
+
return self.dropout(x), self.dropout(pos_emb)
|
271 |
+
|
272 |
+
def position_encoding(self,
|
273 |
+
offset: Union[int, torch.Tensor],
|
274 |
+
size: int) -> torch.Tensor:
|
275 |
+
""" For getting encoding in a streaming fashion
|
276 |
+
|
277 |
+
Attention!!!!!
|
278 |
+
we apply dropout only once at the whole utterance level in a none
|
279 |
+
streaming way, but will call this function several times with
|
280 |
+
increasing input size in a streaming scenario, so the dropout will
|
281 |
+
be applied several times.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
offset (int or torch.tensor): start offset
|
285 |
+
size (int): required size of position encoding
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
torch.Tensor: Corresponding encoding
|
289 |
+
"""
|
290 |
+
pos_emb = self.pe[
|
291 |
+
:,
|
292 |
+
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
|
293 |
+
]
|
294 |
+
return pos_emb
|
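A small shape sketch for the encodings above (illustrative only, not part of the diff; it assumes this file is importable as cosyvoice.transformer.embedding and uses arbitrary sizes):

# Hedged sketch: the shapes documented in the classes above.
import torch
from cosyvoice.transformer.embedding import (PositionalEncoding,
                                              EspnetRelPositionalEncoding)

x = torch.randn(2, 50, 256)            # (batch, time, d_model)

abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
y, pos = abs_pe(x)                     # dropout disabled so shapes are exact
assert pos.shape == (1, 50, 256)

rel_pe = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0)
y_rel, pos_rel = rel_pe(x)
# The ESPnet-style relative encoding keeps both positive and negative
# positions, so its length is 2 * time - 1.
assert pos_rel.shape == (1, 2 * 50 - 1, 256)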
cosyvoice/transformer/encoder.py
ADDED
@@ -0,0 +1,474 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
import torch.utils.checkpoint as ckpt

from cosyvoice.transformer.convolution import ConvolutionModule
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.mask import add_optional_chunk_mask


class BaseEncoder(torch.nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of decoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use fixed chunk(chunk_size > 0)
                or dynamic chunk size(use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
                dynamic chunk training
            key_bias: whether to use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
        """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
                                                  mask_pad)
        else:
            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    @torch.jit.unused
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    chunk_masks: torch.Tensor,
                                    pos_emb: torch.Tensor,
                                    mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
                                                    chunk_masks, pos_emb,
                                                    mask_pad)
        return xs

    @torch.jit.export
    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.

        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                               size=attention_key_size)
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        return (xs, r_att_cache, r_cnn_cache)

    @torch.jit.unused
    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion

        Here we should pay special attention to the computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache because:
            1. We can control the subsampling module to output the right result by
               overlapping input instead of caching left context; even though it
               wastes some computation, subsampling only takes a very
               small fraction of computation in the whole model.
            2. Typically, there are several convolution layers with subsampling
               in the subsampling module, and it is tricky and complicated to cache
               different convolution layers with different subsampling
               rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it work
               with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            chunk_size (int): decoding chunk size
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache,
             cnn_cache) = self.forward_chunk(chunk_xs, offset,
                                             required_cache_size, att_cache,
                                             cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks


class TransformerEncoder(BaseEncoder):
    """Transformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        key_bias: bool = True,
        selfattention_layer_type: str = "selfattn",
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
    ):
        """ Construct TransformerEncoder

        See Encoder for the meaning of each parameter.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        self.encoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
                                                                      output_size,
                                                                      attention_dropout_rate,
                                                                      key_bias),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate, activation),
                dropout_rate, normalize_before) for _ in range(num_blocks)
        ])


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Construct ConformerEncoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configuration
                compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            key_bias: whether to use bias in attention.linear_k, False for whisper models.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
cosyvoice/transformer/encoder_layer.py
ADDED
@@ -0,0 +1,236 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from typing import Optional, Tuple

import torch
from torch import nn


class TransformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)
        self.norm2 = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): just for interface compatibility
                to ConformerEncoderLayer
            mask_pad (torch.Tensor): not used in the transformer layer,
                just for a unified api with conformer.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2), not used here, it's for interface
                compatibility to ConformerEncoderLayer.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).

        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, fake_cnn_cache


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
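The layer interface above can be exercised with a stand-in attention module, as in the sketch below. This is illustrative only: the tiny StubSelfAttn is a hypothetical placeholder that merely mimics the (query, key, value, mask, pos_emb, cache) call contract of the real attention classes, which live elsewhere in the repo; it assumes TransformerEncoderLayer from this file is in scope.

# Hedged sketch: drive TransformerEncoderLayer with a stand-in attention module.
import torch
from torch import nn

class StubSelfAttn(nn.Module):
    """Hypothetical stand-in; only the call signature matters here."""
    def __init__(self, size: int):
        super().__init__()
        self.proj = nn.Linear(size, size)

    def forward(self, q, k, v, mask, pos_emb=torch.empty(0),
                cache=torch.zeros((0, 0, 0, 0))):
        # Real attention returns (output, new_cache); keep the same contract.
        return self.proj(q), cache

size = 32
layer = TransformerEncoderLayer(size,
                                StubSelfAttn(size),
                                nn.Sequential(nn.Linear(size, 64), nn.ReLU(),
                                              nn.Linear(64, size)),
                                dropout_rate=0.1)
x = torch.randn(2, 10, size)
mask = torch.ones(2, 10, 10, dtype=torch.bool)
pos_emb = torch.zeros(1, 10, size)
y, mask_out, att_cache, cnn_cache = layer(x, mask, pos_emb)
assert y.shape == x.shape            # the layer preserves (#batch, time, size)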
cosyvoice/transformer/label_smoothing_loss.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the label's data distribution is:
    [0,1,2] ->
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]

    In the smoothed version of the CE loss, some probability mass
    is taken from the true label prob (1.0) and divided
    among the other labels.

    e.g.
    smoothing=0.1
    [0,1,2] ->
    [
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
        [0.05, 0.05, 0.9],
    ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            normalize loss by sequence length if True
            normalize loss by batch size if False
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model outputs and data label tensors are flattened to
        (batch*seqlen, class) shape and a mask is applied to the
        padding part which should not be calculated for loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_id (batch, seqlen)
        Returns:
            loss (torch.Tensor) : The KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
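A short numeric check of the smoothing behaviour documented above (a sketch, not part of the diff; shapes and values are arbitrary): with smoothing=0.1 and 3 classes, a non-padding label is assigned confidence 0.9 and each other class gets 0.1/(size-1) = 0.05.

# Hedged usage sketch for LabelSmoothingLoss.
import torch

criterion = LabelSmoothingLoss(size=3, padding_idx=-1, smoothing=0.1,
                               normalize_length=True)
logits = torch.randn(2, 4, 3)                      # (batch, seqlen, class)
target = torch.tensor([[0, 1, 2, -1],              # -1 marks padding positions
                       [2, 2, 0, -1]])
loss = criterion(logits, target)                   # scalar KL loss over 6 tokens
print(float(loss))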
cosyvoice/transformer/positionwise_feed_forward.py
ADDED
@@ -0,0 +1,115 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward is applied on each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))


class MoEFFNLayer(torch.nn.Module):
    """
    Mixture of experts with a positionwise feed forward layer
    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
    The output dim is the same as the input dim.

    Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
                  https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
    Args:
        n_expert: number of experts.
        n_expert_per_token: The actual number of experts used for each frame
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        n_expert: int,
        n_expert_per_token: int,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        super(MoEFFNLayer, self).__init__()
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
                                    activation) for _ in range(n_expert))
        self.n_expert_per_token = n_expert_per_token

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.
        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)

        """
        B, L, D = xs.size(
        )  # batch size, sequence length, embedding dimension (idim)
        xs = xs.view(-1, D)  # (B*L, D)
        router = self.gate(xs)  # (B*L, n_expert)
        logits, indices = torch.topk(
            router, self.n_expert_per_token
        )  # probs:(B*L, n_expert), indices: (B*L, n_expert)
        weights = torch.nn.functional.softmax(
            logits, dim=1,
            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
        output = torch.zeros_like(xs)  # (B*L, D)
        for i, expert in enumerate(self.experts):
            mask = indices == i
            batch_idx, ith_expert = torch.where(mask)
            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
                xs[batch_idx])
        return output.view(B, L, D)
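A shape sketch for the two feed-forward variants above (illustrative only, not part of the diff; it assumes both classes are in scope and the sizes are arbitrary):

# Hedged sketch: both layers map (B, L, D) -> (B, L, D).
import torch

ffn = PositionwiseFeedForward(idim=64, hidden_units=256, dropout_rate=0.1)
moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2,
                  idim=64, hidden_units=256, dropout_rate=0.1)

x = torch.randn(2, 8, 64)
assert ffn(x).shape == (2, 8, 64)
# The MoE variant routes each of the B*L frames to its top-2 experts and
# mixes their outputs with softmax weights from the gate.
assert moe(x).shape == (2, 8, 64)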
cosyvoice/transformer/subsampling.py
ADDED
@@ -0,0 +1,383 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""

from typing import Tuple, Union

import torch


class BaseSubsampling(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class EmbedinigNoSubsampling(BaseSubsampling):
    """Embedding input without subsampling
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.embed = torch.nn.Embedding(idim, odim)
        self.pos_enc = pos_enc_class

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .

        """
        x = self.embed(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a linear object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class Conv1dSubsampling2(BaseSubsampling):
    """Convolutional 1D subsampling (to 1/2 length).
    It is designed for Whisper, ref:
    https://github.com/openai/whisper/blob/main/whisper/model.py

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
odim (int): Output dimension.
|
124 |
+
dropout_rate (float): Dropout rate.
|
125 |
+
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
129 |
+
pos_enc_class: torch.nn.Module):
|
130 |
+
"""Construct an Conv1dSubsampling2 object."""
|
131 |
+
super().__init__()
|
132 |
+
self.conv = torch.nn.Sequential(
|
133 |
+
torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
|
134 |
+
torch.nn.GELU(),
|
135 |
+
torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
|
136 |
+
torch.nn.GELU(),
|
137 |
+
)
|
138 |
+
self.pos_enc = pos_enc_class
|
139 |
+
# The right context for every conv layer is computed by:
|
140 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
141 |
+
self.subsampling_rate = 2
|
142 |
+
# 4 = (3 - 1) * 1 + (3 - 1) * 1
|
143 |
+
self.right_context = 4
|
144 |
+
|
145 |
+
def forward(
|
146 |
+
self,
|
147 |
+
x: torch.Tensor,
|
148 |
+
x_mask: torch.Tensor,
|
149 |
+
offset: Union[int, torch.Tensor] = 0
|
150 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
151 |
+
"""Subsample x.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
155 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
159 |
+
where time' = time // 2.
|
160 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
161 |
+
where time' = time // 2.
|
162 |
+
torch.Tensor: positional encoding
|
163 |
+
|
164 |
+
"""
|
165 |
+
time = x.size(1)
|
166 |
+
x = x.transpose(1, 2) # (b, f, t)
|
167 |
+
x = self.conv(x)
|
168 |
+
x = x.transpose(1, 2) # (b, t, f)
|
169 |
+
x, pos_emb = self.pos_enc(x, offset)
|
170 |
+
return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
|
171 |
+
|
172 |
+
|
173 |
+
class Conv2dSubsampling4(BaseSubsampling):
|
174 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
175 |
+
|
176 |
+
Args:
|
177 |
+
idim (int): Input dimension.
|
178 |
+
odim (int): Output dimension.
|
179 |
+
dropout_rate (float): Dropout rate.
|
180 |
+
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
184 |
+
pos_enc_class: torch.nn.Module):
|
185 |
+
"""Construct an Conv2dSubsampling4 object."""
|
186 |
+
super().__init__()
|
187 |
+
self.conv = torch.nn.Sequential(
|
188 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
189 |
+
torch.nn.ReLU(),
|
190 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
191 |
+
torch.nn.ReLU(),
|
192 |
+
)
|
193 |
+
self.out = torch.nn.Sequential(
|
194 |
+
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
195 |
+
self.pos_enc = pos_enc_class
|
196 |
+
# The right context for every conv layer is computed by:
|
197 |
+
# (kernel_size - 1) * frame_rate_of_this_layer
|
198 |
+
self.subsampling_rate = 4
|
199 |
+
# 6 = (3 - 1) * 1 + (3 - 1) * 2
|
200 |
+
self.right_context = 6
|
201 |
+
|
202 |
+
def forward(
|
203 |
+
self,
|
204 |
+
x: torch.Tensor,
|
205 |
+
x_mask: torch.Tensor,
|
206 |
+
offset: Union[int, torch.Tensor] = 0
|
207 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
208 |
+
"""Subsample x.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
212 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
216 |
+
where time' = time // 4.
|
217 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
218 |
+
where time' = time // 4.
|
219 |
+
torch.Tensor: positional encoding
|
220 |
+
|
221 |
+
"""
|
222 |
+
x = x.unsqueeze(1) # (b, c=1, t, f)
|
223 |
+
x = self.conv(x)
|
224 |
+
b, c, t, f = x.size()
|
225 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
226 |
+
x, pos_emb = self.pos_enc(x, offset)
|
227 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
|
228 |
+
|
229 |
+
|
230 |
+
class Conv2dSubsampling6(BaseSubsampling):
|
231 |
+
"""Convolutional 2D subsampling (to 1/6 length).
|
232 |
+
Args:
|
233 |
+
idim (int): Input dimension.
|
234 |
+
odim (int): Output dimension.
|
235 |
+
dropout_rate (float): Dropout rate.
|
236 |
+
pos_enc (torch.nn.Module): Custom position encoding layer.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
240 |
+
pos_enc_class: torch.nn.Module):
|
241 |
+
"""Construct an Conv2dSubsampling6 object."""
|
242 |
+
super().__init__()
|
243 |
+
self.conv = torch.nn.Sequential(
|
244 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
245 |
+
torch.nn.ReLU(),
|
246 |
+
torch.nn.Conv2d(odim, odim, 5, 3),
|
247 |
+
torch.nn.ReLU(),
|
248 |
+
)
|
249 |
+
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
|
250 |
+
odim)
|
251 |
+
self.pos_enc = pos_enc_class
|
252 |
+
# 10 = (3 - 1) * 1 + (5 - 1) * 2
|
253 |
+
self.subsampling_rate = 6
|
254 |
+
self.right_context = 10
|
255 |
+
|
256 |
+
def forward(
|
257 |
+
self,
|
258 |
+
x: torch.Tensor,
|
259 |
+
x_mask: torch.Tensor,
|
260 |
+
offset: Union[int, torch.Tensor] = 0
|
261 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
262 |
+
"""Subsample x.
|
263 |
+
Args:
|
264 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
265 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
269 |
+
where time' = time // 6.
|
270 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
271 |
+
where time' = time // 6.
|
272 |
+
torch.Tensor: positional encoding
|
273 |
+
"""
|
274 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
275 |
+
x = self.conv(x)
|
276 |
+
b, c, t, f = x.size()
|
277 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
278 |
+
x, pos_emb = self.pos_enc(x, offset)
|
279 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
|
280 |
+
|
281 |
+
|
282 |
+
class Conv2dSubsampling8(BaseSubsampling):
|
283 |
+
"""Convolutional 2D subsampling (to 1/8 length).
|
284 |
+
|
285 |
+
Args:
|
286 |
+
idim (int): Input dimension.
|
287 |
+
odim (int): Output dimension.
|
288 |
+
dropout_rate (float): Dropout rate.
|
289 |
+
|
290 |
+
"""
|
291 |
+
|
292 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
293 |
+
pos_enc_class: torch.nn.Module):
|
294 |
+
"""Construct an Conv2dSubsampling8 object."""
|
295 |
+
super().__init__()
|
296 |
+
self.conv = torch.nn.Sequential(
|
297 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
298 |
+
torch.nn.ReLU(),
|
299 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
300 |
+
torch.nn.ReLU(),
|
301 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
302 |
+
torch.nn.ReLU(),
|
303 |
+
)
|
304 |
+
self.linear = torch.nn.Linear(
|
305 |
+
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
|
306 |
+
self.pos_enc = pos_enc_class
|
307 |
+
self.subsampling_rate = 8
|
308 |
+
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
|
309 |
+
self.right_context = 14
|
310 |
+
|
311 |
+
def forward(
|
312 |
+
self,
|
313 |
+
x: torch.Tensor,
|
314 |
+
x_mask: torch.Tensor,
|
315 |
+
offset: Union[int, torch.Tensor] = 0
|
316 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
317 |
+
"""Subsample x.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
321 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
325 |
+
where time' = time // 8.
|
326 |
+
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
327 |
+
where time' = time // 8.
|
328 |
+
torch.Tensor: positional encoding
|
329 |
+
"""
|
330 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
331 |
+
x = self.conv(x)
|
332 |
+
b, c, t, f = x.size()
|
333 |
+
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
334 |
+
x, pos_emb = self.pos_enc(x, offset)
|
335 |
+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
|
336 |
+
|
337 |
+
|
338 |
+
class LegacyLinearNoSubsampling(BaseSubsampling):
|
339 |
+
"""Linear transform the input without subsampling
|
340 |
+
|
341 |
+
Args:
|
342 |
+
idim (int): Input dimension.
|
343 |
+
odim (int): Output dimension.
|
344 |
+
dropout_rate (float): Dropout rate.
|
345 |
+
|
346 |
+
"""
|
347 |
+
|
348 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
349 |
+
pos_enc_class: torch.nn.Module):
|
350 |
+
"""Construct an linear object."""
|
351 |
+
super().__init__()
|
352 |
+
self.out = torch.nn.Sequential(
|
353 |
+
torch.nn.Linear(idim, odim),
|
354 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
355 |
+
torch.nn.Dropout(dropout_rate),
|
356 |
+
torch.nn.ReLU(),
|
357 |
+
)
|
358 |
+
self.pos_enc = pos_enc_class
|
359 |
+
self.right_context = 0
|
360 |
+
self.subsampling_rate = 1
|
361 |
+
|
362 |
+
def forward(
|
363 |
+
self,
|
364 |
+
x: torch.Tensor,
|
365 |
+
x_mask: torch.Tensor,
|
366 |
+
offset: Union[int, torch.Tensor] = 0
|
367 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
368 |
+
"""Input x.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
372 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
376 |
+
where time' = time .
|
377 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
378 |
+
where time' = time .
|
379 |
+
|
380 |
+
"""
|
381 |
+
x = self.out(x)
|
382 |
+
x, pos_emb = self.pos_enc(x, offset)
|
383 |
+
return x, pos_emb, x_mask
|
cosyvoice/transformer/upsample_encoder.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
2 |
+
# 2022 Xingchen Song ([email protected])
|
3 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
17 |
+
"""Encoder definition."""
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import functional as F
|
23 |
+
|
24 |
+
from cosyvoice.transformer.convolution import ConvolutionModule
|
25 |
+
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
26 |
+
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
27 |
+
from cosyvoice.utils.class_utils import (
|
28 |
+
COSYVOICE_EMB_CLASSES,
|
29 |
+
COSYVOICE_SUBSAMPLE_CLASSES,
|
30 |
+
COSYVOICE_ATTENTION_CLASSES,
|
31 |
+
COSYVOICE_ACTIVATION_CLASSES,
|
32 |
+
)
|
33 |
+
from cosyvoice.utils.mask import make_pad_mask
|
34 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
35 |
+
|
36 |
+
|
37 |
+
class Upsample1D(nn.Module):
|
38 |
+
"""A 1D upsampling layer with an optional convolution.
|
39 |
+
|
40 |
+
Parameters:
|
41 |
+
channels (`int`):
|
42 |
+
number of channels in the inputs and outputs.
|
43 |
+
use_conv (`bool`, default `False`):
|
44 |
+
option to use a convolution.
|
45 |
+
use_conv_transpose (`bool`, default `False`):
|
46 |
+
option to use a convolution transpose.
|
47 |
+
out_channels (`int`, optional):
|
48 |
+
number of output channels. Defaults to `channels`.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, channels: int, out_channels: int, stride: int = 2):
|
52 |
+
super().__init__()
|
53 |
+
self.channels = channels
|
54 |
+
self.out_channels = out_channels
|
55 |
+
self.stride = stride
|
56 |
+
# In this mode, first repeat interpolate, than conv with stride=1
|
57 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
58 |
+
|
59 |
+
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
|
60 |
+
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
61 |
+
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
62 |
+
outputs = self.conv(outputs)
|
63 |
+
return outputs, input_lengths * self.stride
|
64 |
+
|
65 |
+
|
66 |
+
class PreLookaheadLayer(nn.Module):
|
67 |
+
def __init__(self, channels: int, pre_lookahead_len: int = 1):
|
68 |
+
super().__init__()
|
69 |
+
self.channels = channels
|
70 |
+
self.pre_lookahead_len = pre_lookahead_len
|
71 |
+
self.conv1 = nn.Conv1d(
|
72 |
+
channels, channels,
|
73 |
+
kernel_size=pre_lookahead_len + 1,
|
74 |
+
stride=1, padding=0,
|
75 |
+
)
|
76 |
+
self.conv2 = nn.Conv1d(
|
77 |
+
channels, channels,
|
78 |
+
kernel_size=3, stride=1, padding=0,
|
79 |
+
)
|
80 |
+
|
81 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
82 |
+
"""
|
83 |
+
inputs: (batch_size, seq_len, channels)
|
84 |
+
"""
|
85 |
+
outputs = inputs.transpose(1, 2).contiguous()
|
86 |
+
# look ahead
|
87 |
+
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
88 |
+
outputs = F.leaky_relu(self.conv1(outputs))
|
89 |
+
# outputs
|
90 |
+
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
|
91 |
+
outputs = self.conv2(outputs)
|
92 |
+
outputs = outputs.transpose(1, 2).contiguous()
|
93 |
+
|
94 |
+
# residual connection
|
95 |
+
outputs = outputs + inputs
|
96 |
+
return outputs
|
97 |
+
|
98 |
+
|
99 |
+
class UpsampleConformerEncoder(torch.nn.Module):
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
input_size: int,
|
104 |
+
output_size: int = 256,
|
105 |
+
attention_heads: int = 4,
|
106 |
+
linear_units: int = 2048,
|
107 |
+
num_blocks: int = 6,
|
108 |
+
dropout_rate: float = 0.1,
|
109 |
+
positional_dropout_rate: float = 0.1,
|
110 |
+
attention_dropout_rate: float = 0.0,
|
111 |
+
input_layer: str = "conv2d",
|
112 |
+
pos_enc_layer_type: str = "rel_pos",
|
113 |
+
normalize_before: bool = True,
|
114 |
+
static_chunk_size: int = 0,
|
115 |
+
use_dynamic_chunk: bool = False,
|
116 |
+
global_cmvn: torch.nn.Module = None,
|
117 |
+
use_dynamic_left_chunk: bool = False,
|
118 |
+
positionwise_conv_kernel_size: int = 1,
|
119 |
+
macaron_style: bool = True,
|
120 |
+
selfattention_layer_type: str = "rel_selfattn",
|
121 |
+
activation_type: str = "swish",
|
122 |
+
use_cnn_module: bool = True,
|
123 |
+
cnn_module_kernel: int = 15,
|
124 |
+
causal: bool = False,
|
125 |
+
cnn_module_norm: str = "batch_norm",
|
126 |
+
key_bias: bool = True,
|
127 |
+
gradient_checkpointing: bool = False,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Args:
|
131 |
+
input_size (int): input dim
|
132 |
+
output_size (int): dimension of attention
|
133 |
+
attention_heads (int): the number of heads of multi head attention
|
134 |
+
linear_units (int): the hidden units number of position-wise feed
|
135 |
+
forward
|
136 |
+
num_blocks (int): the number of decoder blocks
|
137 |
+
dropout_rate (float): dropout rate
|
138 |
+
attention_dropout_rate (float): dropout rate in attention
|
139 |
+
positional_dropout_rate (float): dropout rate after adding
|
140 |
+
positional encoding
|
141 |
+
input_layer (str): input layer type.
|
142 |
+
optional [linear, conv2d, conv2d6, conv2d8]
|
143 |
+
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
144 |
+
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
145 |
+
normalize_before (bool):
|
146 |
+
True: use layer_norm before each sub-block of a layer.
|
147 |
+
False: use layer_norm after each sub-block of a layer.
|
148 |
+
static_chunk_size (int): chunk size for static chunk training and
|
149 |
+
decoding
|
150 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
151 |
+
training or not, You can only use fixed chunk(chunk_size > 0)
|
152 |
+
or dyanmic chunk size(use_dynamic_chunk = True)
|
153 |
+
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
154 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
155 |
+
dynamic chunk training
|
156 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
157 |
+
gradient_checkpointing: rerunning a forward-pass segment for each
|
158 |
+
checkpointed segment during backward.
|
159 |
+
"""
|
160 |
+
super().__init__()
|
161 |
+
self._output_size = output_size
|
162 |
+
|
163 |
+
self.global_cmvn = global_cmvn
|
164 |
+
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
165 |
+
input_size,
|
166 |
+
output_size,
|
167 |
+
dropout_rate,
|
168 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
169 |
+
positional_dropout_rate),
|
170 |
+
)
|
171 |
+
|
172 |
+
self.normalize_before = normalize_before
|
173 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
174 |
+
self.static_chunk_size = static_chunk_size
|
175 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
176 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
177 |
+
self.gradient_checkpointing = gradient_checkpointing
|
178 |
+
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
179 |
+
# self-attention module definition
|
180 |
+
encoder_selfattn_layer_args = (
|
181 |
+
attention_heads,
|
182 |
+
output_size,
|
183 |
+
attention_dropout_rate,
|
184 |
+
key_bias,
|
185 |
+
)
|
186 |
+
# feed-forward module definition
|
187 |
+
positionwise_layer_args = (
|
188 |
+
output_size,
|
189 |
+
linear_units,
|
190 |
+
dropout_rate,
|
191 |
+
activation,
|
192 |
+
)
|
193 |
+
# convolution module definition
|
194 |
+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
195 |
+
cnn_module_norm, causal)
|
196 |
+
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
|
197 |
+
self.encoders = torch.nn.ModuleList([
|
198 |
+
ConformerEncoderLayer(
|
199 |
+
output_size,
|
200 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
201 |
+
*encoder_selfattn_layer_args),
|
202 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
203 |
+
PositionwiseFeedForward(
|
204 |
+
*positionwise_layer_args) if macaron_style else None,
|
205 |
+
ConvolutionModule(
|
206 |
+
*convolution_layer_args) if use_cnn_module else None,
|
207 |
+
dropout_rate,
|
208 |
+
normalize_before,
|
209 |
+
) for _ in range(num_blocks)
|
210 |
+
])
|
211 |
+
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
212 |
+
self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
213 |
+
input_size,
|
214 |
+
output_size,
|
215 |
+
dropout_rate,
|
216 |
+
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
217 |
+
positional_dropout_rate),
|
218 |
+
)
|
219 |
+
self.up_encoders = torch.nn.ModuleList([
|
220 |
+
ConformerEncoderLayer(
|
221 |
+
output_size,
|
222 |
+
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
223 |
+
*encoder_selfattn_layer_args),
|
224 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
225 |
+
PositionwiseFeedForward(
|
226 |
+
*positionwise_layer_args) if macaron_style else None,
|
227 |
+
ConvolutionModule(
|
228 |
+
*convolution_layer_args) if use_cnn_module else None,
|
229 |
+
dropout_rate,
|
230 |
+
normalize_before,
|
231 |
+
) for _ in range(4)
|
232 |
+
])
|
233 |
+
|
234 |
+
def output_size(self) -> int:
|
235 |
+
return self._output_size
|
236 |
+
|
237 |
+
def forward(
|
238 |
+
self,
|
239 |
+
xs: torch.Tensor,
|
240 |
+
xs_lens: torch.Tensor,
|
241 |
+
decoding_chunk_size: int = 0,
|
242 |
+
num_decoding_left_chunks: int = -1,
|
243 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
244 |
+
"""Embed positions in tensor.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
xs: padded input tensor (B, T, D)
|
248 |
+
xs_lens: input length (B)
|
249 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
250 |
+
0: default for training, use random dynamic chunk.
|
251 |
+
<0: for decoding, use full chunk.
|
252 |
+
>0: for decoding, use fixed chunk size as set.
|
253 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
254 |
+
the chunk size is decoding_chunk_size.
|
255 |
+
>=0: use num_decoding_left_chunks
|
256 |
+
<0: use all left chunks
|
257 |
+
Returns:
|
258 |
+
encoder output tensor xs, and subsampled masks
|
259 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
260 |
+
masks: torch.Tensor batch padding mask after subsample
|
261 |
+
(B, 1, T' ~= T/subsample_rate)
|
262 |
+
NOTE(xcsong):
|
263 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
264 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
265 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
266 |
+
"""
|
267 |
+
T = xs.size(1)
|
268 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
269 |
+
if self.global_cmvn is not None:
|
270 |
+
xs = self.global_cmvn(xs)
|
271 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
272 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
273 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
274 |
+
self.use_dynamic_chunk,
|
275 |
+
self.use_dynamic_left_chunk,
|
276 |
+
decoding_chunk_size,
|
277 |
+
self.static_chunk_size,
|
278 |
+
num_decoding_left_chunks)
|
279 |
+
# lookahead + conformer encoder
|
280 |
+
xs = self.pre_lookahead_layer(xs)
|
281 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
282 |
+
|
283 |
+
# upsample + conformer encoder
|
284 |
+
xs = xs.transpose(1, 2).contiguous()
|
285 |
+
xs, xs_lens = self.up_layer(xs, xs_lens)
|
286 |
+
xs = xs.transpose(1, 2).contiguous()
|
287 |
+
T = xs.size(1)
|
288 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
289 |
+
xs, pos_emb, masks = self.up_embed(xs, masks)
|
290 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
291 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
292 |
+
self.use_dynamic_chunk,
|
293 |
+
self.use_dynamic_left_chunk,
|
294 |
+
decoding_chunk_size,
|
295 |
+
self.static_chunk_size * self.up_layer.stride,
|
296 |
+
num_decoding_left_chunks)
|
297 |
+
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
298 |
+
|
299 |
+
if self.normalize_before:
|
300 |
+
xs = self.after_norm(xs)
|
301 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
302 |
+
# return the masks before encoder layers, and the masks will be used
|
303 |
+
# for cross attention with decoder later
|
304 |
+
return xs, masks
|
305 |
+
|
306 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
307 |
+
pos_emb: torch.Tensor,
|
308 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
309 |
+
for layer in self.encoders:
|
310 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
311 |
+
return xs
|
312 |
+
|
313 |
+
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
314 |
+
pos_emb: torch.Tensor,
|
315 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
316 |
+
for layer in self.up_encoders:
|
317 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
318 |
+
return xs
|
cosyvoice/utils/executor.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import logging
|
17 |
+
from contextlib import nullcontext
|
18 |
+
import os
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.distributed as dist
|
22 |
+
|
23 |
+
from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
|
24 |
+
|
25 |
+
|
26 |
+
class Executor:
|
27 |
+
|
28 |
+
def __init__(self, gan: bool = False):
|
29 |
+
self.gan = gan
|
30 |
+
self.step = 0
|
31 |
+
self.epoch = 0
|
32 |
+
self.rank = int(os.environ.get('RANK', 0))
|
33 |
+
self.device = torch.device('cuda:{}'.format(self.rank))
|
34 |
+
|
35 |
+
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
|
36 |
+
''' Train one epoch
|
37 |
+
'''
|
38 |
+
|
39 |
+
lr = optimizer.param_groups[0]['lr']
|
40 |
+
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
41 |
+
logging.info('using accumulate grad, new batch size is {} times'
|
42 |
+
' larger than before'.format(info_dict['accum_grad']))
|
43 |
+
# A context manager to be used in conjunction with an instance of
|
44 |
+
# torch.nn.parallel.DistributedDataParallel to be able to train
|
45 |
+
# with uneven inputs across participating processes.
|
46 |
+
model.train()
|
47 |
+
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
48 |
+
with model_context():
|
49 |
+
for batch_idx, batch_dict in enumerate(train_data_loader):
|
50 |
+
info_dict["tag"] = "TRAIN"
|
51 |
+
info_dict["step"] = self.step
|
52 |
+
info_dict["epoch"] = self.epoch
|
53 |
+
info_dict["batch_idx"] = batch_idx
|
54 |
+
if cosyvoice_join(group_join, info_dict):
|
55 |
+
break
|
56 |
+
|
57 |
+
# Disable gradient synchronizations across DDP processes.
|
58 |
+
# Within this context, gradients will be accumulated on module
|
59 |
+
# variables, which will later be synchronized.
|
60 |
+
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
61 |
+
context = model.no_sync
|
62 |
+
# Used for single gpu training and DDP gradient synchronization
|
63 |
+
# processes.
|
64 |
+
else:
|
65 |
+
context = nullcontext
|
66 |
+
|
67 |
+
with context():
|
68 |
+
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
69 |
+
info_dict = batch_backward(model, scaler, info_dict)
|
70 |
+
|
71 |
+
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
72 |
+
log_per_step(writer, info_dict)
|
73 |
+
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
74 |
+
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
75 |
+
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
76 |
+
dist.barrier()
|
77 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
78 |
+
model.train()
|
79 |
+
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
80 |
+
self.step += 1
|
81 |
+
dist.barrier()
|
82 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
83 |
+
|
84 |
+
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
85 |
+
writer, info_dict, scaler, group_join):
|
86 |
+
''' Train one epoch
|
87 |
+
'''
|
88 |
+
|
89 |
+
lr = optimizer.param_groups[0]['lr']
|
90 |
+
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
91 |
+
logging.info('using accumulate grad, new batch size is {} times'
|
92 |
+
' larger than before'.format(info_dict['accum_grad']))
|
93 |
+
# A context manager to be used in conjunction with an instance of
|
94 |
+
# torch.nn.parallel.DistributedDataParallel to be able to train
|
95 |
+
# with uneven inputs across participating processes.
|
96 |
+
model.train()
|
97 |
+
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
98 |
+
with model_context():
|
99 |
+
for batch_idx, batch_dict in enumerate(train_data_loader):
|
100 |
+
info_dict["tag"] = "TRAIN"
|
101 |
+
info_dict["step"] = self.step
|
102 |
+
info_dict["epoch"] = self.epoch
|
103 |
+
info_dict["batch_idx"] = batch_idx
|
104 |
+
if cosyvoice_join(group_join, info_dict):
|
105 |
+
break
|
106 |
+
|
107 |
+
# Disable gradient synchronizations across DDP processes.
|
108 |
+
# Within this context, gradients will be accumulated on module
|
109 |
+
# variables, which will later be synchronized.
|
110 |
+
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
111 |
+
context = model.no_sync
|
112 |
+
# Used for single gpu training and DDP gradient synchronization
|
113 |
+
# processes.
|
114 |
+
else:
|
115 |
+
context = nullcontext
|
116 |
+
|
117 |
+
with context():
|
118 |
+
batch_dict['turn'] = 'discriminator'
|
119 |
+
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
120 |
+
info_dict = batch_backward(model, scaler, info_dict)
|
121 |
+
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
|
122 |
+
optimizer.zero_grad()
|
123 |
+
log_per_step(writer, info_dict)
|
124 |
+
with context():
|
125 |
+
batch_dict['turn'] = 'generator'
|
126 |
+
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
127 |
+
info_dict = batch_backward(model, scaler, info_dict)
|
128 |
+
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
129 |
+
optimizer_d.zero_grad()
|
130 |
+
log_per_step(writer, info_dict)
|
131 |
+
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
132 |
+
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
133 |
+
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
134 |
+
dist.barrier()
|
135 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
136 |
+
model.train()
|
137 |
+
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
138 |
+
self.step += 1
|
139 |
+
dist.barrier()
|
140 |
+
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
141 |
+
|
142 |
+
@torch.inference_mode()
|
143 |
+
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
|
144 |
+
''' Cross validation on
|
145 |
+
'''
|
146 |
+
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
|
147 |
+
model.eval()
|
148 |
+
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
|
149 |
+
for batch_idx, batch_dict in enumerate(cv_data_loader):
|
150 |
+
info_dict["tag"] = "CV"
|
151 |
+
info_dict["step"] = self.step
|
152 |
+
info_dict["epoch"] = self.epoch
|
153 |
+
info_dict["batch_idx"] = batch_idx
|
154 |
+
|
155 |
+
num_utts = len(batch_dict["utts"])
|
156 |
+
total_num_utts += num_utts
|
157 |
+
|
158 |
+
if self.gan is True:
|
159 |
+
batch_dict['turn'] = 'generator'
|
160 |
+
info_dict = batch_forward(model, batch_dict, None, info_dict)
|
161 |
+
|
162 |
+
for k, v in info_dict['loss_dict'].items():
|
163 |
+
if k not in total_loss_dict:
|
164 |
+
total_loss_dict[k] = []
|
165 |
+
total_loss_dict[k].append(v.item() * num_utts)
|
166 |
+
log_per_step(None, info_dict)
|
167 |
+
for k, v in total_loss_dict.items():
|
168 |
+
total_loss_dict[k] = sum(v) / total_num_utts
|
169 |
+
info_dict['loss_dict'] = total_loss_dict
|
170 |
+
log_per_save(writer, info_dict)
|
171 |
+
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
|
172 |
+
save_model(model, model_name, info_dict)
|
cosyvoice/utils/file_utils.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import json
|
17 |
+
import torchaudio
|
18 |
+
import logging
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
logging.basicConfig(level=logging.DEBUG,
|
21 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
22 |
+
|
23 |
+
|
24 |
+
def read_lists(list_file):
|
25 |
+
lists = []
|
26 |
+
with open(list_file, 'r', encoding='utf8') as fin:
|
27 |
+
for line in fin:
|
28 |
+
lists.append(line.strip())
|
29 |
+
return lists
|
30 |
+
|
31 |
+
|
32 |
+
def read_json_lists(list_file):
|
33 |
+
lists = read_lists(list_file)
|
34 |
+
results = {}
|
35 |
+
for fn in lists:
|
36 |
+
with open(fn, 'r', encoding='utf8') as fin:
|
37 |
+
results.update(json.load(fin))
|
38 |
+
return results
|
39 |
+
|
40 |
+
|
41 |
+
def load_wav(wav, target_sr):
|
42 |
+
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
|
43 |
+
speech = speech.mean(dim=0, keepdim=True)
|
44 |
+
if sample_rate != target_sr:
|
45 |
+
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
46 |
+
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
47 |
+
return speech
|
48 |
+
|
49 |
+
|
50 |
+
def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
51 |
+
import tensorrt as trt
|
52 |
+
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
|
53 |
+
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
|
54 |
+
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
|
55 |
+
input_names = ["x", "mask", "mu", "t", "spks", "cond"]
|
56 |
+
|
57 |
+
logging.info("Converting onnx to trt...")
|
58 |
+
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
59 |
+
logger = trt.Logger(trt.Logger.INFO)
|
60 |
+
builder = trt.Builder(logger)
|
61 |
+
network = builder.create_network(network_flags)
|
62 |
+
parser = trt.OnnxParser(network, logger)
|
63 |
+
config = builder.create_builder_config()
|
64 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
|
65 |
+
if fp16:
|
66 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
67 |
+
profile = builder.create_optimization_profile()
|
68 |
+
# load onnx model
|
69 |
+
with open(onnx_model, "rb") as f:
|
70 |
+
if not parser.parse(f.read()):
|
71 |
+
for error in range(parser.num_errors):
|
72 |
+
print(parser.get_error(error))
|
73 |
+
raise ValueError('failed to parse {}'.format(onnx_model))
|
74 |
+
# set input shapes
|
75 |
+
for i in range(len(input_names)):
|
76 |
+
profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
|
77 |
+
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
78 |
+
# set input and output data type
|
79 |
+
for i in range(network.num_inputs):
|
80 |
+
input_tensor = network.get_input(i)
|
81 |
+
input_tensor.dtype = tensor_dtype
|
82 |
+
for i in range(network.num_outputs):
|
83 |
+
output_tensor = network.get_output(i)
|
84 |
+
output_tensor.dtype = tensor_dtype
|
85 |
+
config.add_optimization_profile(profile)
|
86 |
+
engine_bytes = builder.build_serialized_network(network, config)
|
87 |
+
# save trt engine
|
88 |
+
with open(trt_model, "wb") as f:
|
89 |
+
f.write(engine_bytes)
|
cosyvoice/utils/frontend_utils.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
import regex
|
17 |
+
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
18 |
+
|
19 |
+
|
20 |
+
# whether contain chinese character
|
21 |
+
def contains_chinese(text):
|
22 |
+
return bool(chinese_char_pattern.search(text))
|
23 |
+
|
24 |
+
|
25 |
+
# replace special symbol
|
26 |
+
def replace_corner_mark(text):
|
27 |
+
text = text.replace('²', '平方')
|
28 |
+
text = text.replace('³', '立方')
|
29 |
+
return text
|
30 |
+
|
31 |
+
|
32 |
+
# remove meaningless symbol
|
33 |
+
def remove_bracket(text):
|
34 |
+
text = text.replace('(', '').replace(')', '')
|
35 |
+
text = text.replace('【', '').replace('】', '')
|
36 |
+
text = text.replace('`', '').replace('`', '')
|
37 |
+
text = text.replace("——", " ")
|
38 |
+
return text
|
39 |
+
|
40 |
+
|
41 |
+
# spell Arabic numerals
|
42 |
+
def spell_out_number(text: str, inflect_parser):
|
43 |
+
new_text = []
|
44 |
+
st = None
|
45 |
+
for i, c in enumerate(text):
|
46 |
+
if not c.isdigit():
|
47 |
+
if st is not None:
|
48 |
+
num_str = inflect_parser.number_to_words(text[st: i])
|
49 |
+
new_text.append(num_str)
|
50 |
+
st = None
|
51 |
+
new_text.append(c)
|
52 |
+
else:
|
53 |
+
if st is None:
|
54 |
+
st = i
|
55 |
+
if st is not None and st < len(text):
|
56 |
+
num_str = inflect_parser.number_to_words(text[st:])
|
57 |
+
new_text.append(num_str)
|
58 |
+
return ''.join(new_text)
|
59 |
+
|
60 |
+
|
61 |
+
# split paragrah logic:
|
62 |
+
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
63 |
+
# 2. cal sentence len according to lang
|
64 |
+
# 3. split sentence according to puncatation
|
65 |
+
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
|
66 |
+
def calc_utt_length(_text: str):
|
67 |
+
if lang == "zh":
|
68 |
+
return len(_text)
|
69 |
+
else:
|
70 |
+
return len(tokenize(_text))
|
71 |
+
|
72 |
+
def should_merge(_text: str):
|
73 |
+
if lang == "zh":
|
74 |
+
return len(_text) < merge_len
|
75 |
+
else:
|
76 |
+
return len(tokenize(_text)) < merge_len
|
77 |
+
|
78 |
+
if lang == "zh":
|
79 |
+
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
80 |
+
else:
|
81 |
+
pounc = ['.', '?', '!', ';', ':']
|
82 |
+
if comma_split:
|
83 |
+
pounc.extend([',', ','])
|
84 |
+
|
85 |
+
if text[-1] not in pounc:
|
86 |
+
if lang == "zh":
|
87 |
+
text += "。"
|
88 |
+
else:
|
89 |
+
text += "."
|
90 |
+
|
91 |
+
st = 0
|
92 |
+
utts = []
|
93 |
+
for i, c in enumerate(text):
|
94 |
+
if c in pounc:
|
95 |
+
if len(text[st: i]) > 0:
|
96 |
+
utts.append(text[st: i] + c)
|
97 |
+
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
98 |
+
tmp = utts.pop(-1)
|
99 |
+
utts.append(tmp + text[i + 1])
|
100 |
+
st = i + 2
|
101 |
+
else:
|
102 |
+
st = i + 1
|
103 |
+
|
104 |
+
final_utts = []
|
105 |
+
cur_utt = ""
|
106 |
+
for utt in utts:
|
107 |
+
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
|
108 |
+
final_utts.append(cur_utt)
|
109 |
+
cur_utt = ""
|
110 |
+
cur_utt = cur_utt + utt
|
111 |
+
if len(cur_utt) > 0:
|
112 |
+
if should_merge(cur_utt) and len(final_utts) != 0:
|
113 |
+
final_utts[-1] = final_utts[-1] + cur_utt
|
114 |
+
else:
|
115 |
+
final_utts.append(cur_utt)
|
116 |
+
|
117 |
+
return final_utts
|
118 |
+
|
119 |
+
|
120 |
+
# remove blank between chinese character
|
121 |
+
def replace_blank(text: str):
|
122 |
+
out_str = []
|
123 |
+
for i, c in enumerate(text):
|
124 |
+
if c == " ":
|
125 |
+
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
126 |
+
(text[i - 1].isascii() and text[i - 1] != " ")):
|
127 |
+
out_str.append(c)
|
128 |
+
else:
|
129 |
+
out_str.append(c)
|
130 |
+
return "".join(out_str)
|
131 |
+
|
132 |
+
|
133 |
+
def is_only_punctuation(text):
|
134 |
+
# Regular expression: Match strings that consist only of punctuation marks or are empty.
|
135 |
+
punctuation_pattern = r'^[\p{P}\p{S}]*$'
|
136 |
+
return bool(regex.fullmatch(punctuation_pattern, text))
|
cosyvoice/utils/mask.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from cosyvoice.utils.file_utils import logging
|
19 |
+
'''
|
20 |
+
def subsequent_mask(
|
21 |
+
size: int,
|
22 |
+
device: torch.device = torch.device("cpu"),
|
23 |
+
) -> torch.Tensor:
|
24 |
+
"""Create mask for subsequent steps (size, size).
|
25 |
+
|
26 |
+
This mask is used only in decoder which works in an auto-regressive mode.
|
27 |
+
This means the current step could only do attention with its left steps.
|
28 |
+
|
29 |
+
In encoder, fully attention is used when streaming is not necessary and
|
30 |
+
the sequence is not long. In this case, no attention mask is needed.
|
31 |
+
|
32 |
+
When streaming is need, chunk-based attention is used in encoder. See
|
33 |
+
subsequent_chunk_mask for the chunk-based attention mask.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
size (int): size of mask
|
37 |
+
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
38 |
+
dtype (torch.device): result dtype
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
torch.Tensor: mask
|
42 |
+
|
43 |
+
Examples:
|
44 |
+
>>> subsequent_mask(3)
|
45 |
+
[[1, 0, 0],
|
46 |
+
[1, 1, 0],
|
47 |
+
[1, 1, 1]]
|
48 |
+
"""
|
49 |
+
ret = torch.ones(size, size, device=device, dtype=torch.bool)
|
50 |
+
return torch.tril(ret)
|
51 |
+
'''
|
52 |
+
|
53 |
+
|
54 |
+
def subsequent_mask(
|
55 |
+
size: int,
|
56 |
+
device: torch.device = torch.device("cpu"),
|
57 |
+
) -> torch.Tensor:
|
58 |
+
"""Create mask for subsequent steps (size, size).
|
59 |
+
|
60 |
+
This mask is used only in decoder which works in an auto-regressive mode.
|
61 |
+
This means the current step could only do attention with its left steps.
|
62 |
+
|
63 |
+
In encoder, fully attention is used when streaming is not necessary and
|
64 |
+
the sequence is not long. In this case, no attention mask is needed.
|
65 |
+
|
66 |
+
When streaming is need, chunk-based attention is used in encoder. See
|
67 |
+
subsequent_chunk_mask for the chunk-based attention mask.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
size (int): size of mask
|
71 |
+
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
72 |
+
dtype (torch.device): result dtype
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
torch.Tensor: mask
|
76 |
+
|
77 |
+
Examples:
|
78 |
+
>>> subsequent_mask(3)
|
79 |
+
[[1, 0, 0],
|
80 |
+
[1, 1, 0],
|
81 |
+
[1, 1, 1]]
|
82 |
+
"""
|
83 |
+
arange = torch.arange(size, device=device)
|
84 |
+
mask = arange.expand(size, size)
|
85 |
+
arange = arange.unsqueeze(-1)
|
86 |
+
mask = mask <= arange
|
87 |
+
return mask
|
88 |
+
|
89 |
+
|
90 |
+
def subsequent_chunk_mask_deprecated(
|
91 |
+
size: int,
|
92 |
+
chunk_size: int,
|
93 |
+
num_left_chunks: int = -1,
|
94 |
+
device: torch.device = torch.device("cpu"),
|
95 |
+
) -> torch.Tensor:
|
96 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
97 |
+
this is for streaming encoder
|
98 |
+
|
99 |
+
Args:
|
100 |
+
size (int): size of mask
|
101 |
+
chunk_size (int): size of chunk
|
102 |
+
num_left_chunks (int): number of left chunks
|
103 |
+
<0: use full chunk
|
104 |
+
>=0: use num_left_chunks
|
105 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
torch.Tensor: mask
|
109 |
+
|
110 |
+
Examples:
|
111 |
+
>>> subsequent_chunk_mask(4, 2)
|
112 |
+
[[1, 1, 0, 0],
|
113 |
+
[1, 1, 0, 0],
|
114 |
+
[1, 1, 1, 1],
|
115 |
+
[1, 1, 1, 1]]
|
116 |
+
"""
|
117 |
+
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
118 |
+
for i in range(size):
|
119 |
+
if num_left_chunks < 0:
|
120 |
+
start = 0
|
121 |
+
else:
|
122 |
+
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
123 |
+
ending = min((i // chunk_size + 1) * chunk_size, size)
|
124 |
+
ret[i, start:ending] = True
|
125 |
+
return ret
|
126 |
+
|
127 |
+
|
128 |
+
def subsequent_chunk_mask(
|
129 |
+
size: int,
|
130 |
+
chunk_size: int,
|
131 |
+
num_left_chunks: int = -1,
|
132 |
+
device: torch.device = torch.device("cpu"),
|
133 |
+
) -> torch.Tensor:
|
134 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
135 |
+
this is for streaming encoder
|
136 |
+
|
137 |
+
Args:
|
138 |
+
size (int): size of mask
|
139 |
+
chunk_size (int): size of chunk
|
140 |
+
num_left_chunks (int): number of left chunks
|
141 |
+
<0: use full chunk
|
142 |
+
>=0: use num_left_chunks
|
143 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
torch.Tensor: mask
|
147 |
+
|
148 |
+
Examples:
|
149 |
+
>>> subsequent_chunk_mask(4, 2)
|
150 |
+
[[1, 1, 0, 0],
|
151 |
+
[1, 1, 0, 0],
|
152 |
+
[1, 1, 1, 1],
|
153 |
+
[1, 1, 1, 1]]
|
154 |
+
"""
|
155 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
156 |
+
# actually this is not needed after we have inference cache implemented, will remove it later
|
157 |
+
pos_idx = torch.arange(size, device=device)
|
158 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
159 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
160 |
+
return ret
|
161 |
+
|
162 |
+
|
163 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
|
164 |
+
masks: torch.Tensor,
|
165 |
+
use_dynamic_chunk: bool,
|
166 |
+
use_dynamic_left_chunk: bool,
|
167 |
+
decoding_chunk_size: int,
|
168 |
+
static_chunk_size: int,
|
169 |
+
num_decoding_left_chunks: int,
|
170 |
+
enable_full_context: bool = True):
|
171 |
+
""" Apply optional mask for encoder.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
175 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
176 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
177 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
178 |
+
training.
|
179 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
180 |
+
0: default for training, use random dynamic chunk.
|
181 |
+
<0: for decoding, use full chunk.
|
182 |
+
>0: for decoding, use fixed chunk size as set.
|
183 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
184 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
185 |
+
this parameter will be ignored
|
186 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
187 |
+
the chunk size is decoding_chunk_size.
|
188 |
+
>=0: use num_decoding_left_chunks
|
189 |
+
<0: use all left chunks
|
190 |
+
enable_full_context (bool):
|
191 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
192 |
+
False: chunk size ~ U[1, 25]
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
torch.Tensor: chunk mask of the input xs.
|
196 |
+
"""
|
197 |
+
# Whether to use chunk mask or not
|
198 |
+
if use_dynamic_chunk:
|
199 |
+
max_len = xs.size(1)
|
200 |
+
if decoding_chunk_size < 0:
|
201 |
+
chunk_size = max_len
|
202 |
+
num_left_chunks = -1
|
203 |
+
elif decoding_chunk_size > 0:
|
204 |
+
chunk_size = decoding_chunk_size
|
205 |
+
num_left_chunks = num_decoding_left_chunks
|
206 |
+
else:
|
207 |
+
# chunk size is either [1, 25] or full context(max_len).
|
208 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
209 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
210 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
211 |
+
num_left_chunks = -1
|
212 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
213 |
+
chunk_size = max_len
|
214 |
+
else:
|
215 |
+
chunk_size = chunk_size % 25 + 1
|
216 |
+
if use_dynamic_left_chunk:
|
217 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
218 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
219 |
+
(1, )).item()
|
220 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
221 |
+
num_left_chunks,
|
222 |
+
xs.device) # (L, L)
|
223 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
224 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
225 |
+
elif static_chunk_size > 0:
|
226 |
+
num_left_chunks = num_decoding_left_chunks
|
227 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
228 |
+
num_left_chunks,
|
229 |
+
xs.device) # (L, L)
|
230 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
231 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
232 |
+
else:
|
233 |
+
chunk_masks = masks
|
234 |
+
assert chunk_masks.dtype == torch.bool
|
235 |
+
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
236 |
+
logging.warning('chunk_masks is all false at some timestep, force set to true, make sure they are masked in future computation!')
|
237 |
+
chunk_masks[chunk_masks.sum(dim=-1)==0] = True
|
238 |
+
return chunk_masks
|
239 |
+
|
240 |
+
|
241 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
242 |
+
"""Make mask tensor containing indices of padded part.
|
243 |
+
|
244 |
+
See description of make_non_pad_mask.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
248 |
+
Returns:
|
249 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
250 |
+
|
251 |
+
Examples:
|
252 |
+
>>> lengths = [5, 3, 2]
|
253 |
+
>>> make_pad_mask(lengths)
|
254 |
+
masks = [[0, 0, 0, 0 ,0],
|
255 |
+
[0, 0, 0, 1, 1],
|
256 |
+
[0, 0, 1, 1, 1]]
|
257 |
+
"""
|
258 |
+
batch_size = lengths.size(0)
|
259 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
260 |
+
seq_range = torch.arange(0,
|
261 |
+
max_len,
|
262 |
+
dtype=torch.int64,
|
263 |
+
device=lengths.device)
|
264 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
265 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
266 |
+
mask = seq_range_expand >= seq_length_expand
|
267 |
+
return mask
|
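The two helpers above are combined when building the encoder attention mask; below is a minimal usage sketch (an illustration, not part of the diff), assuming the file shown here is importable as cosyvoice.utils.mask:

import torch
from cosyvoice.utils.mask import make_pad_mask, add_optional_chunk_mask

lengths = torch.tensor([5, 3, 2])                # per-utterance frame counts
xs = torch.zeros(3, 5, 80)                       # padded batch (B, L, D)
masks = (~make_pad_mask(lengths)).unsqueeze(1)   # (B, 1, L), True at valid frames
chunk_masks = add_optional_chunk_mask(xs, masks,
                                      use_dynamic_chunk=False,
                                      use_dynamic_left_chunk=False,
                                      decoding_chunk_size=0,
                                      static_chunk_size=2,
                                      num_decoding_left_chunks=-1)
print(chunk_masks.shape, chunk_masks.dtype)      # torch.Size([3, 5, 5]) torch.bool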
cosyvoice/utils/train_utils.py
ADDED
@@ -0,0 +1,345 @@
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
# 2023 Horizon Inc. (authors: Xingchen Song)
|
3 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import torch
|
20 |
+
import json
|
21 |
+
import re
|
22 |
+
import datetime
|
23 |
+
import yaml
|
24 |
+
|
25 |
+
import deepspeed
|
26 |
+
import torch.optim as optim
|
27 |
+
import torch.distributed as dist
|
28 |
+
|
29 |
+
from torch.utils.tensorboard import SummaryWriter
|
30 |
+
from torch.utils.data import DataLoader
|
31 |
+
from torch.nn.utils import clip_grad_norm_
|
32 |
+
|
33 |
+
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
|
34 |
+
|
35 |
+
from cosyvoice.dataset.dataset import Dataset
|
36 |
+
from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
|
37 |
+
|
38 |
+
|
39 |
+
def init_distributed(args):
|
40 |
+
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
41 |
+
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
42 |
+
rank = int(os.environ.get('RANK', 0))
|
43 |
+
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
|
44 |
+
', rank {}, world_size {}'.format(rank, world_size))
|
45 |
+
if args.train_engine == 'torch_ddp':
|
46 |
+
torch.cuda.set_device(local_rank)
|
47 |
+
dist.init_process_group(args.dist_backend)
|
48 |
+
else:
|
49 |
+
deepspeed.init_distributed(dist_backend=args.dist_backend)
|
50 |
+
return world_size, local_rank, rank
|
51 |
+
|
52 |
+
|
53 |
+
def init_dataset_and_dataloader(args, configs, gan):
|
54 |
+
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
|
55 |
+
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
|
56 |
+
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
|
57 |
+
|
58 |
+
# do not use persistent_workers=True, as the whisper tokenizer opens the tiktoken file each time the for loop starts
|
59 |
+
train_data_loader = DataLoader(train_dataset,
|
60 |
+
batch_size=None,
|
61 |
+
pin_memory=args.pin_memory,
|
62 |
+
num_workers=args.num_workers,
|
63 |
+
prefetch_factor=args.prefetch)
|
64 |
+
cv_data_loader = DataLoader(cv_dataset,
|
65 |
+
batch_size=None,
|
66 |
+
pin_memory=args.pin_memory,
|
67 |
+
num_workers=args.num_workers,
|
68 |
+
prefetch_factor=args.prefetch)
|
69 |
+
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
|
70 |
+
|
71 |
+
|
72 |
+
def check_modify_and_save_config(args, configs):
|
73 |
+
if args.train_engine == "torch_ddp":
|
74 |
+
configs['train_conf']["dtype"] = 'fp32'
|
75 |
+
else:
|
76 |
+
with open(args.deepspeed_config, 'r') as fin:
|
77 |
+
ds_configs = json.load(fin)
|
78 |
+
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
|
79 |
+
configs['train_conf']["dtype"] = "fp16"
|
80 |
+
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
|
81 |
+
configs['train_conf']["dtype"] = "bf16"
|
82 |
+
else:
|
83 |
+
configs['train_conf']["dtype"] = "fp32"
|
84 |
+
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
|
85 |
+
# if use deepspeed, override ddp config
|
86 |
+
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
|
87 |
+
configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
|
88 |
+
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
|
89 |
+
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
|
90 |
+
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
|
91 |
+
return configs
|
92 |
+
|
93 |
+
|
94 |
+
def wrap_cuda_model(args, model):
|
95 |
+
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
|
96 |
+
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
97 |
+
if args.train_engine == "torch_ddp": # native pytorch ddp
|
98 |
+
assert (torch.cuda.is_available())
|
99 |
+
model.cuda()
|
100 |
+
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
|
101 |
+
else:
|
102 |
+
if int(os.environ.get('RANK', 0)) == 0:
|
103 |
+
logging.info("Estimating model states memory needs (zero2)...")
|
104 |
+
estimate_zero2_model_states_mem_needs_all_live(
|
105 |
+
model,
|
106 |
+
num_gpus_per_node=local_world_size,
|
107 |
+
num_nodes=world_size // local_world_size)
|
108 |
+
return model
|
109 |
+
|
110 |
+
|
111 |
+
def init_optimizer_and_scheduler(args, configs, model, gan):
|
112 |
+
if gan is False:
|
113 |
+
if configs['train_conf']['optim'] == 'adam':
|
114 |
+
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
|
115 |
+
elif configs['train_conf']['optim'] == 'adamw':
|
116 |
+
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
|
117 |
+
else:
|
118 |
+
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
119 |
+
|
120 |
+
if configs['train_conf']['scheduler'] == 'warmuplr':
|
121 |
+
scheduler_type = WarmupLR
|
122 |
+
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
123 |
+
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
124 |
+
scheduler_type = NoamHoldAnnealing
|
125 |
+
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
126 |
+
elif configs['train_conf']['scheduler'] == 'constantlr':
|
127 |
+
scheduler_type = ConstantLR
|
128 |
+
scheduler = ConstantLR(optimizer)
|
129 |
+
else:
|
130 |
+
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
131 |
+
|
132 |
+
# use deepspeed optimizer for speedup
|
133 |
+
if args.train_engine == "deepspeed":
|
134 |
+
def scheduler(opt):
|
135 |
+
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
|
136 |
+
model, optimizer, _, scheduler = deepspeed.initialize(
|
137 |
+
args=args,
|
138 |
+
model=model,
|
139 |
+
optimizer=None,
|
140 |
+
lr_scheduler=scheduler,
|
141 |
+
model_parameters=model.parameters())
|
142 |
+
|
143 |
+
optimizer_d, scheduler_d = None, None
|
144 |
+
|
145 |
+
else:
|
146 |
+
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
|
147 |
+
if configs['train_conf']['optim'] == 'adam':
|
148 |
+
optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
149 |
+
elif configs['train_conf']['optim'] == 'adamw':
|
150 |
+
optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
151 |
+
else:
|
152 |
+
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
153 |
+
|
154 |
+
if configs['train_conf']['scheduler'] == 'warmuplr':
|
155 |
+
scheduler_type = WarmupLR
|
156 |
+
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
157 |
+
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
158 |
+
scheduler_type = NoamHoldAnnealing
|
159 |
+
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
160 |
+
elif configs['train_conf']['scheduler'] == 'constantlr':
|
161 |
+
scheduler_type = ConstantLR
|
162 |
+
scheduler = ConstantLR(optimizer)
|
163 |
+
else:
|
164 |
+
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
165 |
+
|
166 |
+
if configs['train_conf']['optim_d'] == 'adam':
|
167 |
+
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
168 |
+
elif configs['train_conf']['optim_d'] == 'adamw':
|
169 |
+
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
170 |
+
else:
|
171 |
+
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
172 |
+
|
173 |
+
if configs['train_conf']['scheduler_d'] == 'warmuplr':
|
174 |
+
scheduler_type = WarmupLR
|
175 |
+
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
176 |
+
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
|
177 |
+
scheduler_type = NoamHoldAnnealing
|
178 |
+
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
179 |
+
elif configs['train_conf']['scheduler_d'] == 'constantlr':
|
180 |
+
scheduler_type = ConstantLR
|
181 |
+
scheduler_d = ConstantLR(optimizer_d)
|
182 |
+
else:
|
183 |
+
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
184 |
+
return model, optimizer, scheduler, optimizer_d, scheduler_d
|
185 |
+
|
186 |
+
|
187 |
+
def init_summarywriter(args):
|
188 |
+
writer = None
|
189 |
+
if int(os.environ.get('RANK', 0)) == 0:
|
190 |
+
os.makedirs(args.model_dir, exist_ok=True)
|
191 |
+
writer = SummaryWriter(args.tensorboard_dir)
|
192 |
+
return writer
|
193 |
+
|
194 |
+
|
195 |
+
def save_model(model, model_name, info_dict):
|
196 |
+
rank = int(os.environ.get('RANK', 0))
|
197 |
+
model_dir = info_dict["model_dir"]
|
198 |
+
save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
|
199 |
+
|
200 |
+
if info_dict["train_engine"] == "torch_ddp":
|
201 |
+
if rank == 0:
|
202 |
+
torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
|
203 |
+
else:
|
204 |
+
with torch.no_grad():
|
205 |
+
model.save_checkpoint(save_dir=model_dir,
|
206 |
+
tag=model_name,
|
207 |
+
client_state=info_dict)
|
208 |
+
if rank == 0:
|
209 |
+
info_path = re.sub('.pt$', '.yaml', save_model_path)
|
210 |
+
info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
|
211 |
+
with open(info_path, 'w') as fout:
|
212 |
+
data = yaml.dump(info_dict)
|
213 |
+
fout.write(data)
|
214 |
+
logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
|
215 |
+
|
216 |
+
|
217 |
+
def cosyvoice_join(group_join, info_dict):
|
218 |
+
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
219 |
+
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
220 |
+
rank = int(os.environ.get('RANK', 0))
|
221 |
+
|
222 |
+
if info_dict["batch_idx"] != 0:
|
223 |
+
# we try to join all ranks in both ddp and deepspeed mode, in case different ranks have different lr
|
224 |
+
try:
|
225 |
+
dist.monitored_barrier(group=group_join,
|
226 |
+
timeout=group_join.options._timeout)
|
227 |
+
return False
|
228 |
+
except RuntimeError as e:
|
229 |
+
logging.info("Detected uneven workload distribution: {}\n".format(e) +
|
230 |
+
"Break current worker to manually join all workers, " +
|
231 |
+
"world_size {}, current rank {}, current local_rank {}\n".
|
232 |
+
format(world_size, rank, local_rank))
|
233 |
+
return True
|
234 |
+
else:
|
235 |
+
return False
|
236 |
+
|
237 |
+
|
238 |
+
def batch_forward(model, batch, scaler, info_dict):
|
239 |
+
device = int(os.environ.get('LOCAL_RANK', 0))
|
240 |
+
|
241 |
+
dtype = info_dict["dtype"]
|
242 |
+
if dtype == "fp16":
|
243 |
+
dtype = torch.float16
|
244 |
+
elif dtype == "bf16":
|
245 |
+
dtype = torch.bfloat16
|
246 |
+
else: # fp32
|
247 |
+
dtype = torch.float32
|
248 |
+
|
249 |
+
if info_dict['train_engine'] == 'torch_ddp':
|
250 |
+
autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
|
251 |
+
else:
|
252 |
+
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
253 |
+
|
254 |
+
with autocast:
|
255 |
+
info_dict['loss_dict'] = model(batch, device)
|
256 |
+
return info_dict
|
257 |
+
|
258 |
+
|
259 |
+
def batch_backward(model, scaler, info_dict):
|
260 |
+
if info_dict["train_engine"] == "deepspeed":
|
261 |
+
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
|
262 |
+
else:
|
263 |
+
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
|
264 |
+
if scaler is not None:
|
265 |
+
scaler.scale(scaled_loss).backward()
|
266 |
+
else:
|
267 |
+
scaled_loss.backward()
|
268 |
+
|
269 |
+
info_dict['loss_dict']['loss'] = scaled_loss
|
270 |
+
return info_dict
|
271 |
+
|
272 |
+
|
273 |
+
def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
|
274 |
+
grad_norm = 0.0
|
275 |
+
if info_dict['train_engine'] == "deepspeed":
|
276 |
+
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
|
277 |
+
model.step()
|
278 |
+
grad_norm = model.get_global_grad_norm()
|
279 |
+
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
|
280 |
+
# Use mixed precision training
|
281 |
+
if scaler is not None:
|
282 |
+
scaler.unscale_(optimizer)
|
283 |
+
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
284 |
+
# We don't check grad here since that if the gradient
|
285 |
+
# has inf/nan values, scaler.step will skip
|
286 |
+
# optimizer.step().
|
287 |
+
if torch.isfinite(grad_norm):
|
288 |
+
scaler.step(optimizer)
|
289 |
+
scaler.update()
|
290 |
+
else:
|
291 |
+
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
292 |
+
if torch.isfinite(grad_norm):
|
293 |
+
optimizer.step()
|
294 |
+
optimizer.zero_grad()
|
295 |
+
scheduler.step()
|
296 |
+
info_dict["lr"] = optimizer.param_groups[0]['lr']
|
297 |
+
info_dict["grad_norm"] = grad_norm
|
298 |
+
return info_dict
|
299 |
+
|
300 |
+
|
301 |
+
def log_per_step(writer, info_dict):
|
302 |
+
tag = info_dict["tag"]
|
303 |
+
epoch = info_dict.get('epoch', 0)
|
304 |
+
step = info_dict["step"]
|
305 |
+
batch_idx = info_dict["batch_idx"]
|
306 |
+
loss_dict = info_dict['loss_dict']
|
307 |
+
rank = int(os.environ.get('RANK', 0))
|
308 |
+
|
309 |
+
# only rank 0 write to tensorboard to avoid multi-process write
|
310 |
+
if writer is not None:
|
311 |
+
if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
|
312 |
+
(info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
|
313 |
+
for k in ['epoch', 'lr', 'grad_norm']:
|
314 |
+
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
|
315 |
+
for k, v in loss_dict.items():
|
316 |
+
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
|
317 |
+
|
318 |
+
# TRAIN & CV, Shell log (stdout)
|
319 |
+
if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
|
320 |
+
log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
|
321 |
+
for name, value in loss_dict.items():
|
322 |
+
log_str += '{} {:.6f} '.format(name, value)
|
323 |
+
if tag == "TRAIN":
|
324 |
+
log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
|
325 |
+
info_dict["lr"], info_dict['grad_norm'])
|
326 |
+
log_str += ' rank {}'.format(rank)
|
327 |
+
logging.debug(log_str)
|
328 |
+
|
329 |
+
|
330 |
+
def log_per_save(writer, info_dict):
|
331 |
+
tag = info_dict["tag"]
|
332 |
+
epoch = info_dict["epoch"]
|
333 |
+
step = info_dict["step"]
|
334 |
+
loss_dict = info_dict["loss_dict"]
|
335 |
+
lr = info_dict['lr']
|
336 |
+
rank = int(os.environ.get('RANK', 0))
|
337 |
+
logging.info(
|
338 |
+
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
|
339 |
+
epoch, step + 1, lr, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]), rank))
|
340 |
+
|
341 |
+
if writer is not None:
|
342 |
+
for k in ['epoch', 'lr']:
|
343 |
+
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
|
344 |
+
for k, v in loss_dict.items():
|
345 |
+
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
|
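As a rough illustration of how the helpers above fit together, here is a minimal sketch of one torch_ddp training step (not code from the diff); it assumes an info_dict already populated with the train_conf fields (train_engine, dtype, accum_grad, batch_idx, grad_clip, log_interval, tag, step, epoch), which the Executor in cosyvoice/utils/executor.py presumably provides:

from cosyvoice.utils.train_utils import (batch_forward, batch_backward,
                                         update_parameter_and_lr, log_per_step)

def train_one_step(model, optimizer, scheduler, scaler, writer, batch, info_dict):
    # forward under autocast, backward (optionally via GradScaler), then clip/step/zero_grad
    info_dict = batch_forward(model, batch, scaler, info_dict)
    info_dict = batch_backward(model, scaler, info_dict)
    info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
    log_per_step(writer, info_dict)   # tensorboard on rank 0, periodic shell log otherwise
    return info_dict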
docker/Dockerfile
ADDED
@@ -0,0 +1,51 @@
1 |
+
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
2 |
+
|
3 |
+
ARG VENV_NAME="cosyvoice"
|
4 |
+
ENV VENV=$VENV_NAME
|
5 |
+
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
|
6 |
+
|
7 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
8 |
+
ENV PYTHONUNBUFFERED=1
|
9 |
+
SHELL ["/bin/bash", "--login", "-c"]
|
10 |
+
|
11 |
+
RUN apt-get update -y --fix-missing
|
12 |
+
RUN apt-get install -y git build-essential curl wget ffmpeg unzip git-lfs sox libsox-dev && \
|
13 |
+
apt-get clean && \
|
14 |
+
git lfs install
|
15 |
+
|
16 |
+
# ==================================================================
|
17 |
+
# conda install and conda forge channel as default
|
18 |
+
# ------------------------------------------------------------------
|
19 |
+
# Install miniforge
|
20 |
+
RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \
|
21 |
+
/bin/bash ~/miniforge.sh -b -p /opt/conda && \
|
22 |
+
rm ~/miniforge.sh && \
|
23 |
+
ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
|
24 |
+
echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \
|
25 |
+
echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
|
26 |
+
echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \
|
27 |
+
echo "conda activate ${VENV}" >> $HOME/.bashrc
|
28 |
+
|
29 |
+
ENV PATH /opt/conda/bin:$PATH
|
30 |
+
|
31 |
+
RUN conda config --add channels conda-forge && \
|
32 |
+
conda config --set channel_priority strict
|
33 |
+
# ------------------------------------------------------------------
|
34 |
+
# ~conda
|
35 |
+
# ==================================================================
|
36 |
+
|
37 |
+
RUN conda create -y -n ${VENV} python=3.10
|
38 |
+
ENV CONDA_DEFAULT_ENV=${VENV}
|
39 |
+
ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH
|
40 |
+
|
41 |
+
WORKDIR /workspace
|
42 |
+
|
43 |
+
ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS"
|
44 |
+
|
45 |
+
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
46 |
+
|
47 |
+
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
|
48 |
+
RUN conda activate ${VENV} && cd CosyVoice && \
|
49 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
50 |
+
|
51 |
+
WORKDIR /workspace/CosyVoice
|
examples/libritts/cosyvoice/conf/cosyvoice.yaml
ADDED
@@ -0,0 +1,257 @@
1 |
+
# set random seed, so that you may reproduce your result.
|
2 |
+
__set_seed1: !apply:random.seed [1986]
|
3 |
+
__set_seed2: !apply:numpy.random.seed [1986]
|
4 |
+
__set_seed3: !apply:torch.manual_seed [1986]
|
5 |
+
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
6 |
+
|
7 |
+
# fixed params
|
8 |
+
sample_rate: 22050
|
9 |
+
text_encoder_input_size: 512
|
10 |
+
llm_input_size: 1024
|
11 |
+
llm_output_size: 1024
|
12 |
+
spk_embed_dim: 192
|
13 |
+
|
14 |
+
# model params
|
15 |
+
# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users may find all corresponding classes/functions from one single yaml.
|
16 |
+
# for system/third_party class/function, we do not require this.
|
17 |
+
llm: !new:cosyvoice.llm.llm.TransformerLM
|
18 |
+
text_encoder_input_size: !ref <text_encoder_input_size>
|
19 |
+
llm_input_size: !ref <llm_input_size>
|
20 |
+
llm_output_size: !ref <llm_output_size>
|
21 |
+
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
|
22 |
+
speech_token_size: 4096
|
23 |
+
length_normalized_loss: True
|
24 |
+
lsm_weight: 0
|
25 |
+
spk_embed_dim: !ref <spk_embed_dim>
|
26 |
+
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
27 |
+
input_size: !ref <text_encoder_input_size>
|
28 |
+
output_size: 1024
|
29 |
+
attention_heads: 16
|
30 |
+
linear_units: 4096
|
31 |
+
num_blocks: 6
|
32 |
+
dropout_rate: 0.1
|
33 |
+
positional_dropout_rate: 0.1
|
34 |
+
attention_dropout_rate: 0.0
|
35 |
+
normalize_before: True
|
36 |
+
input_layer: 'linear'
|
37 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
38 |
+
selfattention_layer_type: 'rel_selfattn'
|
39 |
+
use_cnn_module: False
|
40 |
+
macaron_style: False
|
41 |
+
use_dynamic_chunk: False
|
42 |
+
use_dynamic_left_chunk: False
|
43 |
+
static_chunk_size: 1
|
44 |
+
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
45 |
+
input_size: !ref <llm_input_size>
|
46 |
+
output_size: !ref <llm_output_size>
|
47 |
+
attention_heads: 16
|
48 |
+
linear_units: 4096
|
49 |
+
num_blocks: 14
|
50 |
+
dropout_rate: 0.1
|
51 |
+
positional_dropout_rate: 0.1
|
52 |
+
attention_dropout_rate: 0.0
|
53 |
+
input_layer: 'linear_legacy'
|
54 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
55 |
+
selfattention_layer_type: 'rel_selfattn'
|
56 |
+
static_chunk_size: 1
|
57 |
+
sampling: !name:cosyvoice.utils.common.ras_sampling
|
58 |
+
top_p: 0.8
|
59 |
+
top_k: 25
|
60 |
+
win_size: 10
|
61 |
+
tau_r: 0.1
|
62 |
+
|
63 |
+
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
64 |
+
input_size: 512
|
65 |
+
output_size: 80
|
66 |
+
spk_embed_dim: !ref <spk_embed_dim>
|
67 |
+
output_type: 'mel'
|
68 |
+
vocab_size: 4096
|
69 |
+
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
|
70 |
+
only_mask_loss: True
|
71 |
+
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
72 |
+
output_size: 512
|
73 |
+
attention_heads: 8
|
74 |
+
linear_units: 2048
|
75 |
+
num_blocks: 6
|
76 |
+
dropout_rate: 0.1
|
77 |
+
positional_dropout_rate: 0.1
|
78 |
+
attention_dropout_rate: 0.1
|
79 |
+
normalize_before: True
|
80 |
+
input_layer: 'linear'
|
81 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
82 |
+
selfattention_layer_type: 'rel_selfattn'
|
83 |
+
input_size: 512
|
84 |
+
use_cnn_module: False
|
85 |
+
macaron_style: False
|
86 |
+
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
87 |
+
channels: 80
|
88 |
+
sampling_ratios: [1, 1, 1, 1]
|
89 |
+
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
90 |
+
in_channels: 240
|
91 |
+
n_spks: 1
|
92 |
+
spk_emb_dim: 80
|
93 |
+
cfm_params: !new:omegaconf.DictConfig
|
94 |
+
content:
|
95 |
+
sigma_min: 1e-06
|
96 |
+
solver: 'euler'
|
97 |
+
t_scheduler: 'cosine'
|
98 |
+
training_cfg_rate: 0.2
|
99 |
+
inference_cfg_rate: 0.7
|
100 |
+
reg_loss_type: 'l1'
|
101 |
+
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
102 |
+
in_channels: 320
|
103 |
+
out_channels: 80
|
104 |
+
channels: [256, 256]
|
105 |
+
dropout: 0.0
|
106 |
+
attention_head_dim: 64
|
107 |
+
n_blocks: 4
|
108 |
+
num_mid_blocks: 12
|
109 |
+
num_heads: 8
|
110 |
+
act_fn: 'gelu'
|
111 |
+
|
112 |
+
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
113 |
+
in_channels: 80
|
114 |
+
base_channels: 512
|
115 |
+
nb_harmonics: 8
|
116 |
+
sampling_rate: !ref <sample_rate>
|
117 |
+
nsf_alpha: 0.1
|
118 |
+
nsf_sigma: 0.003
|
119 |
+
nsf_voiced_threshold: 10
|
120 |
+
upsample_rates: [8, 8]
|
121 |
+
upsample_kernel_sizes: [16, 16]
|
122 |
+
istft_params:
|
123 |
+
n_fft: 16
|
124 |
+
hop_len: 4
|
125 |
+
resblock_kernel_sizes: [3, 7, 11]
|
126 |
+
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
127 |
+
source_resblock_kernel_sizes: [7, 11]
|
128 |
+
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
129 |
+
lrelu_slope: 0.1
|
130 |
+
audio_limit: 0.99
|
131 |
+
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
132 |
+
num_class: 1
|
133 |
+
in_channels: 80
|
134 |
+
cond_channels: 512
|
135 |
+
|
136 |
+
# gan related module
|
137 |
+
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
138 |
+
n_fft: 1024
|
139 |
+
num_mels: 80
|
140 |
+
sampling_rate: !ref <sample_rate>
|
141 |
+
hop_size: 256
|
142 |
+
win_size: 1024
|
143 |
+
fmin: 0
|
144 |
+
fmax: null
|
145 |
+
center: False
|
146 |
+
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
147 |
+
generator: !ref <hift>
|
148 |
+
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
149 |
+
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
150 |
+
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
|
151 |
+
mel_spec_transform: [
|
152 |
+
!ref <mel_spec_transform1>
|
153 |
+
]
|
154 |
+
|
155 |
+
# processor functions
|
156 |
+
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
157 |
+
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
158 |
+
multilingual: True
|
159 |
+
num_languages: 100
|
160 |
+
language: 'en'
|
161 |
+
task: 'transcribe'
|
162 |
+
allowed_special: 'all'
|
163 |
+
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
164 |
+
get_tokenizer: !ref <get_tokenizer>
|
165 |
+
allowed_special: !ref <allowed_special>
|
166 |
+
filter: !name:cosyvoice.dataset.processor.filter
|
167 |
+
max_length: 40960
|
168 |
+
min_length: 0
|
169 |
+
token_max_length: 200
|
170 |
+
token_min_length: 1
|
171 |
+
resample: !name:cosyvoice.dataset.processor.resample
|
172 |
+
resample_rate: !ref <sample_rate>
|
173 |
+
truncate: !name:cosyvoice.dataset.processor.truncate
|
174 |
+
truncate_length: 24576 # must be a multiple of hop_size
|
175 |
+
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
176 |
+
n_fft: 1024
|
177 |
+
num_mels: 80
|
178 |
+
sampling_rate: !ref <sample_rate>
|
179 |
+
hop_size: 256
|
180 |
+
win_size: 1024
|
181 |
+
fmin: 0
|
182 |
+
fmax: 8000
|
183 |
+
center: False
|
184 |
+
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
185 |
+
feat_extractor: !ref <feat_extractor>
|
186 |
+
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
187 |
+
sample_rate: !ref <sample_rate>
|
188 |
+
hop_size: 256
|
189 |
+
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
190 |
+
normalize: True
|
191 |
+
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
192 |
+
shuffle_size: 1000
|
193 |
+
sort: !name:cosyvoice.dataset.processor.sort
|
194 |
+
sort_size: 500 # sort_size should be less than shuffle_size
|
195 |
+
batch: !name:cosyvoice.dataset.processor.batch
|
196 |
+
batch_type: 'dynamic'
|
197 |
+
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
|
198 |
+
padding: !name:cosyvoice.dataset.processor.padding
|
199 |
+
use_spk_embedding: False # change to True during sft
|
200 |
+
|
201 |
+
# dataset processor pipeline
|
202 |
+
data_pipeline: [
|
203 |
+
!ref <parquet_opener>,
|
204 |
+
!ref <tokenize>,
|
205 |
+
!ref <filter>,
|
206 |
+
!ref <resample>,
|
207 |
+
!ref <compute_fbank>,
|
208 |
+
!ref <parse_embedding>,
|
209 |
+
!ref <shuffle>,
|
210 |
+
!ref <sort>,
|
211 |
+
!ref <batch>,
|
212 |
+
!ref <padding>,
|
213 |
+
]
|
214 |
+
data_pipeline_gan: [
|
215 |
+
!ref <parquet_opener>,
|
216 |
+
!ref <tokenize>,
|
217 |
+
!ref <filter>,
|
218 |
+
!ref <resample>,
|
219 |
+
!ref <truncate>,
|
220 |
+
!ref <compute_fbank>,
|
221 |
+
!ref <compute_f0>,
|
222 |
+
!ref <parse_embedding>,
|
223 |
+
!ref <shuffle>,
|
224 |
+
!ref <sort>,
|
225 |
+
!ref <batch>,
|
226 |
+
!ref <padding>,
|
227 |
+
]
|
228 |
+
|
229 |
+
# llm flow train conf
|
230 |
+
train_conf:
|
231 |
+
optim: adam
|
232 |
+
optim_conf:
|
233 |
+
lr: 0.001 # change to 1e-5 during sft
|
234 |
+
scheduler: warmuplr # change to constantlr during sft
|
235 |
+
scheduler_conf:
|
236 |
+
warmup_steps: 2500
|
237 |
+
max_epoch: 200
|
238 |
+
grad_clip: 5
|
239 |
+
accum_grad: 2
|
240 |
+
log_interval: 100
|
241 |
+
save_per_step: -1
|
242 |
+
|
243 |
+
# gan train conf
|
244 |
+
train_conf_gan:
|
245 |
+
optim: adam
|
246 |
+
optim_conf:
|
247 |
+
lr: 0.0002 # use small lr for gan training
|
248 |
+
scheduler: constantlr
|
249 |
+
optim_d: adam
|
250 |
+
optim_conf_d:
|
251 |
+
lr: 0.0002 # use small lr for gan training
|
252 |
+
scheduler_d: constantlr
|
253 |
+
max_epoch: 200
|
254 |
+
grad_clip: 5
|
255 |
+
accum_grad: 1 # in gan training, accum_grad must be 1
|
256 |
+
log_interval: 100
|
257 |
+
save_per_step: -1
|
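The !new:/!ref/!name: tags in this config are hyperpyyaml syntax, so the yaml can be instantiated directly into model objects. A minimal sketch (an assumption for illustration, requiring hyperpyyaml installed and third_party/Matcha-TTS on PYTHONPATH; not part of the diff):

from hyperpyyaml import load_hyperpyyaml

with open('conf/cosyvoice.yaml', 'r') as f:
    configs = load_hyperpyyaml(f)

llm, flow, hift = configs['llm'], configs['flow'], configs['hift']
print(type(llm).__name__, type(flow).__name__, type(hift).__name__)
# expected under the assumptions above: TransformerLM MaskedDiffWithXvec HiFTGenerator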
examples/libritts/cosyvoice/local/download_and_untar.sh
ADDED
@@ -0,0 +1,97 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
|
4 |
+
# Apache 2.0
|
5 |
+
|
6 |
+
remove_archive=false
|
7 |
+
|
8 |
+
if [ "$1" == --remove-archive ]; then
|
9 |
+
remove_archive=true
|
10 |
+
shift
|
11 |
+
fi
|
12 |
+
|
13 |
+
if [ $# -ne 3 ]; then
|
14 |
+
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
|
15 |
+
echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
|
16 |
+
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
|
17 |
+
echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
|
18 |
+
echo " train-clean-100, train-clean-360, train-other-500."
|
19 |
+
exit 1
|
20 |
+
fi
|
21 |
+
|
22 |
+
data=$1
|
23 |
+
url=$2
|
24 |
+
part=$3
|
25 |
+
|
26 |
+
if [ ! -d "$data" ]; then
|
27 |
+
echo "$0: no such directory $data"
|
28 |
+
exit 1
|
29 |
+
fi
|
30 |
+
|
31 |
+
part_ok=false
|
32 |
+
list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
|
33 |
+
for x in $list; do
|
34 |
+
if [ "$part" == $x ]; then part_ok=true; fi
|
35 |
+
done
|
36 |
+
if ! $part_ok; then
|
37 |
+
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
|
38 |
+
exit 1
|
39 |
+
fi
|
40 |
+
|
41 |
+
if [ -z "$url" ]; then
|
42 |
+
echo "$0: empty URL base."
|
43 |
+
exit 1
|
44 |
+
fi
|
45 |
+
|
46 |
+
if [ -f $data/LibriTTS/$part/.complete ]; then
|
47 |
+
echo "$0: data part $part was already successfully extracted, nothing to do."
|
48 |
+
exit 0
|
49 |
+
fi
|
50 |
+
|
51 |
+
|
52 |
+
# sizes of the archive files in bytes. These are from some older versions.
|
53 |
+
sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
|
54 |
+
# sizes_new is the archive file sizes of the final release. Some of these sizes are of
|
55 |
+
# things we probably won't download.
|
56 |
+
sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
|
57 |
+
|
58 |
+
if [ -f $data/$part.tar.gz ]; then
|
59 |
+
size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
|
60 |
+
size_ok=false
|
61 |
+
for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
|
62 |
+
if ! $size_ok; then
|
63 |
+
echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
|
64 |
+
echo "does not equal the size of one of the archives."
|
65 |
+
rm $data/$part.tar.gz
|
66 |
+
else
|
67 |
+
echo "$data/$part.tar.gz exists and appears to be complete."
|
68 |
+
fi
|
69 |
+
fi
|
70 |
+
|
71 |
+
if [ ! -f $data/$part.tar.gz ]; then
|
72 |
+
if ! which wget >/dev/null; then
|
73 |
+
echo "$0: wget is not installed."
|
74 |
+
exit 1
|
75 |
+
fi
|
76 |
+
full_url=$url/$part.tar.gz
|
77 |
+
echo "$0: downloading data from $full_url. This may take some time, please be patient."
|
78 |
+
|
79 |
+
if ! wget -P $data --no-check-certificate $full_url; then
|
80 |
+
echo "$0: error executing wget $full_url"
|
81 |
+
exit 1
|
82 |
+
fi
|
83 |
+
fi
|
84 |
+
|
85 |
+
if ! tar -C $data -xvzf $data/$part.tar.gz; then
|
86 |
+
echo "$0: error un-tarring archive $data/$part.tar.gz"
|
87 |
+
exit 1
|
88 |
+
fi
|
89 |
+
|
90 |
+
touch $data/LibriTTS/$part/.complete
|
91 |
+
|
92 |
+
echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
|
93 |
+
|
94 |
+
if $remove_archive; then
|
95 |
+
echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
|
96 |
+
rm $data/$part.tar.gz
|
97 |
+
fi
|
examples/libritts/cosyvoice/run.sh
ADDED
@@ -0,0 +1,126 @@
1 |
+
#!/bin/bash
|
2 |
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
3 |
+
. ./path.sh || exit 1;
|
4 |
+
|
5 |
+
stage=-1
|
6 |
+
stop_stage=3
|
7 |
+
|
8 |
+
data_url=www.openslr.org/resources/60
|
9 |
+
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
|
10 |
+
pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
|
11 |
+
|
12 |
+
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
13 |
+
echo "Data Download"
|
14 |
+
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
15 |
+
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
16 |
+
done
|
17 |
+
fi
|
18 |
+
|
19 |
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
20 |
+
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
21 |
+
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
22 |
+
mkdir -p data/$x
|
23 |
+
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
|
24 |
+
done
|
25 |
+
fi
|
26 |
+
|
27 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
28 |
+
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
29 |
+
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
30 |
+
tools/extract_embedding.py --dir data/$x \
|
31 |
+
--onnx_path $pretrained_model_dir/campplus.onnx
|
32 |
+
done
|
33 |
+
fi
|
34 |
+
|
35 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
36 |
+
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
37 |
+
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
38 |
+
tools/extract_speech_token.py --dir data/$x \
|
39 |
+
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
40 |
+
done
|
41 |
+
fi
|
42 |
+
|
43 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
44 |
+
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
45 |
+
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
46 |
+
mkdir -p data/$x/parquet
|
47 |
+
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
48 |
+
--num_processes 10 \
|
49 |
+
--src_dir data/$x \
|
50 |
+
--des_dir data/$x/parquet
|
51 |
+
done
|
52 |
+
fi
|
53 |
+
|
54 |
+
# inference
|
55 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
56 |
+
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
57 |
+
for mode in sft zero_shot; do
|
58 |
+
python cosyvoice/bin/inference.py --mode $mode \
|
59 |
+
--gpu 0 \
|
60 |
+
--config conf/cosyvoice.yaml \
|
61 |
+
--prompt_data data/test-clean/parquet/data.list \
|
62 |
+
--prompt_utt2data data/test-clean/parquet/utt2data.list \
|
63 |
+
--tts_text `pwd`/tts_text.json \
|
64 |
+
--llm_model $pretrained_model_dir/llm.pt \
|
65 |
+
--flow_model $pretrained_model_dir/flow.pt \
|
66 |
+
--hifigan_model $pretrained_model_dir/hift.pt \
|
67 |
+
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
|
68 |
+
done
|
69 |
+
fi
|
70 |
+
|
71 |
+
# train llm
|
72 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
73 |
+
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
74 |
+
job_id=1986
|
75 |
+
dist_backend="nccl"
|
76 |
+
num_workers=2
|
77 |
+
prefetch=100
|
78 |
+
train_engine=torch_ddp
|
79 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
80 |
+
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
81 |
+
if [ $train_engine == 'deepspeed' ]; then
|
82 |
+
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
83 |
+
fi
|
84 |
+
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
85 |
+
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
86 |
+
for model in llm flow hifigan; do
|
87 |
+
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
88 |
+
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
89 |
+
cosyvoice/bin/train.py \
|
90 |
+
--train_engine $train_engine \
|
91 |
+
--config conf/cosyvoice.yaml \
|
92 |
+
--train_data data/train.data.list \
|
93 |
+
--cv_data data/dev.data.list \
|
94 |
+
--model $model \
|
95 |
+
--checkpoint $pretrained_model_dir/$model.pt \
|
96 |
+
--model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
|
97 |
+
--tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
|
98 |
+
--ddp.dist_backend $dist_backend \
|
99 |
+
--num_workers ${num_workers} \
|
100 |
+
--prefetch ${prefetch} \
|
101 |
+
--pin_memory \
|
102 |
+
--use_amp \
|
103 |
+
--deepspeed_config ./conf/ds_stage2.json \
|
104 |
+
--deepspeed.save_states model+optimizer
|
105 |
+
done
|
106 |
+
fi
|
107 |
+
|
108 |
+
# average model
|
109 |
+
average_num=5
|
110 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
111 |
+
for model in llm flow hifigan; do
|
112 |
+
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
113 |
+
echo "do model average and final checkpoint is $decode_checkpoint"
|
114 |
+
python cosyvoice/bin/average_model.py \
|
115 |
+
--dst_model $decode_checkpoint \
|
116 |
+
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
117 |
+
--num ${average_num} \
|
118 |
+
--val_best
|
119 |
+
done
|
120 |
+
fi
|
121 |
+
|
122 |
+
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
123 |
+
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
124 |
+
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
125 |
+
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
126 |
+
fi
|
examples/magicdata-read/cosyvoice/conf/ds_stage2.json
ADDED
@@ -0,0 +1,42 @@
1 |
+
{
|
2 |
+
"train_micro_batch_size_per_gpu": 1,
|
3 |
+
"gradient_accumulation_steps": 1,
|
4 |
+
"steps_per_print": 100,
|
5 |
+
"gradient_clipping": 5,
|
6 |
+
"fp16": {
|
7 |
+
"enabled": false,
|
8 |
+
"auto_cast": false,
|
9 |
+
"loss_scale": 0,
|
10 |
+
"initial_scale_power": 16,
|
11 |
+
"loss_scale_window": 256,
|
12 |
+
"hysteresis": 2,
|
13 |
+
"consecutive_hysteresis": false,
|
14 |
+
"min_loss_scale": 1
|
15 |
+
},
|
16 |
+
"bf16": {
|
17 |
+
"enabled": false
|
18 |
+
},
|
19 |
+
"zero_force_ds_cpu_optimizer": false,
|
20 |
+
"zero_optimization": {
|
21 |
+
"stage": 2,
|
22 |
+
"offload_optimizer": {
|
23 |
+
"device": "none",
|
24 |
+
"pin_memory": true
|
25 |
+
},
|
26 |
+
"allgather_partitions": true,
|
27 |
+
"allgather_bucket_size": 5e8,
|
28 |
+
"overlap_comm": false,
|
29 |
+
"reduce_scatter": true,
|
30 |
+
"reduce_bucket_size": 5e8,
|
31 |
+
"contiguous_gradients" : true
|
32 |
+
},
|
33 |
+
"optimizer": {
|
34 |
+
"type": "AdamW",
|
35 |
+
"params": {
|
36 |
+
"lr": 0.001,
|
37 |
+
"weight_decay": 0.0001,
|
38 |
+
"torch_adam": true,
|
39 |
+
"adam_w_mode": true
|
40 |
+
}
|
41 |
+
}
|
42 |
+
}
|
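For reference, a small sketch (an illustration, not part of the diff) of the fields that check_modify_and_save_config in cosyvoice/utils/train_utils.py reads from this file when train_engine is deepspeed:

import json

with open('conf/ds_stage2.json') as fin:
    ds = json.load(fin)

assert ds['train_micro_batch_size_per_gpu'] == 1           # required by the training code
dtype = 'fp16' if ds['fp16']['enabled'] else 'bf16' if ds['bf16']['enabled'] else 'fp32'
accum_grad = ds['gradient_accumulation_steps']              # overrides train_conf accum_grad
grad_clip = ds['gradient_clipping']                         # overrides train_conf grad_clip
log_interval = ds['steps_per_print']                        # overrides train_conf log_interval
print(dtype, accum_grad, grad_clip, log_interval)           # fp32 1 5 100 for this config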
examples/magicdata-read/cosyvoice/local/download_and_untar.sh
ADDED
@@ -0,0 +1,97 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
|
4 |
+
# Apache 2.0
|
5 |
+
|
6 |
+
remove_archive=false
|
7 |
+
|
8 |
+
if [ "$1" == --remove-archive ]; then
|
9 |
+
remove_archive=true
|
10 |
+
shift
|
11 |
+
fi
|
12 |
+
|
13 |
+
if [ $# -ne 3 ]; then
|
14 |
+
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
|
15 |
+
echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
|
16 |
+
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
|
17 |
+
echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
|
18 |
+
echo " train-clean-100, train-clean-360, train-other-500."
|
19 |
+
exit 1
|
20 |
+
fi
|
21 |
+
|
22 |
+
data=$1
|
23 |
+
url=$2
|
24 |
+
part=$3
|
25 |
+
|
26 |
+
if [ ! -d "$data" ]; then
|
27 |
+
echo "$0: no such directory $data"
|
28 |
+
exit 1
|
29 |
+
fi
|
30 |
+
|
31 |
+
part_ok=false
|
32 |
+
list="dev_set test_set train_set"
|
33 |
+
for x in $list; do
|
34 |
+
if [ "$part" == $x ]; then part_ok=true; fi
|
35 |
+
done
|
36 |
+
if ! $part_ok; then
|
37 |
+
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
|
38 |
+
exit 1
|
39 |
+
fi
|
40 |
+
|
41 |
+
if [ -z "$url" ]; then
|
42 |
+
echo "$0: empty URL base."
|
43 |
+
exit 1
|
44 |
+
fi
|
45 |
+
|
46 |
+
if [ -f $data/.$part.complete ]; then
|
47 |
+
echo "$0: data part $part was already successfully extracted, nothing to do."
|
48 |
+
exit 0
|
49 |
+
fi
|
50 |
+
|
51 |
+
|
52 |
+
# sizes of the archive files in bytes. This is some older versions.
|
53 |
+
sizes_old="1035537823 2201936013 52627842921"
|
54 |
+
# sizes_new is the archive file sizes of the final release. Some of these sizes are of
|
55 |
+
# things we probably won't download.
|
56 |
+
sizes_new="3886385"
|
57 |
+
|
58 |
+
if [ -f $data/$part.tar.gz ]; then
|
59 |
+
size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
|
60 |
+
size_ok=false
|
61 |
+
for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
|
62 |
+
if ! $size_ok; then
|
63 |
+
echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
|
64 |
+
echo "does not equal the size of one of the archives."
|
65 |
+
rm $data/$part.tar.gz
|
66 |
+
else
|
67 |
+
echo "$data/$part.tar.gz exists and appears to be complete."
|
68 |
+
fi
|
69 |
+
fi
|
70 |
+
|
71 |
+
if [ ! -f $data/$part.tar.gz ]; then
|
72 |
+
if ! which wget >/dev/null; then
|
73 |
+
echo "$0: wget is not installed."
|
74 |
+
exit 1
|
75 |
+
fi
|
76 |
+
full_url=$url/$part.tar.gz
|
77 |
+
echo "$0: downloading data from $full_url. This may take some time, please be patient."
|
78 |
+
|
79 |
+
if ! wget -P $data --no-check-certificate $full_url; then
|
80 |
+
echo "$0: error executing wget $full_url"
|
81 |
+
exit 1
|
82 |
+
fi
|
83 |
+
fi
|
84 |
+
|
85 |
+
if ! tar -C $data -xvzf $data/$part.tar.gz; then
|
86 |
+
echo "$0: error un-tarring archive $data/$part.tar.gz"
|
87 |
+
exit 1
|
88 |
+
fi
|
89 |
+
|
90 |
+
touch $data/.$part.complete
|
91 |
+
|
92 |
+
echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
|
93 |
+
|
94 |
+
if $remove_archive; then
|
95 |
+
echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
|
96 |
+
rm $data/$part.tar.gz
|
97 |
+
fi
|
examples/magicdata-read/cosyvoice/local/prepare_data.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
logger = logging.getLogger()
|
8 |
+
|
9 |
+
|
10 |
+
def main():
|
11 |
+
utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
|
12 |
+
with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
|
13 |
+
lines = f.readlines()[1:]
|
14 |
+
lines = [l.split('\t') for l in lines]
|
15 |
+
for wav, spk, content in tqdm(lines):
|
16 |
+
wav, spk, content = wav.strip(), spk.strip(), content.strip()
|
17 |
+
content = content.replace('[FIL]', '')
|
18 |
+
content = content.replace('[SPK]', '')
|
19 |
+
wav = os.path.join(args.src_dir, spk, wav)
|
20 |
+
if not os.path.exists(wav):
|
21 |
+
continue
|
22 |
+
utt = os.path.basename(wav).replace('.wav', '')
|
23 |
+
utt2wav[utt] = wav
|
24 |
+
utt2text[utt] = content
|
25 |
+
utt2spk[utt] = spk
|
26 |
+
if spk not in spk2utt:
|
27 |
+
spk2utt[spk] = []
|
28 |
+
spk2utt[spk].append(utt)
|
29 |
+
|
30 |
+
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
|
31 |
+
for k, v in utt2wav.items():
|
32 |
+
f.write('{} {}\n'.format(k, v))
|
33 |
+
with open('{}/text'.format(args.des_dir), 'w') as f:
|
34 |
+
for k, v in utt2text.items():
|
35 |
+
f.write('{} {}\n'.format(k, v))
|
36 |
+
with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
|
37 |
+
for k, v in utt2spk.items():
|
38 |
+
f.write('{} {}\n'.format(k, v))
|
39 |
+
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
|
40 |
+
for k, v in spk2utt.items():
|
41 |
+
f.write('{} {}\n'.format(k, ' '.join(v)))
|
42 |
+
return
|
43 |
+
|
44 |
+
|
45 |
+
if __name__ == "__main__":
|
46 |
+
parser = argparse.ArgumentParser()
|
47 |
+
parser.add_argument('--src_dir',
|
48 |
+
type=str)
|
49 |
+
parser.add_argument('--des_dir',
|
50 |
+
type=str)
|
51 |
+
args = parser.parse_args()
|
52 |
+
main()
|
examples/magicdata-read/cosyvoice/path.sh
ADDED
@@ -0,0 +1,3 @@
1 |
+
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
2 |
+
export PYTHONIOENCODING=UTF-8
|
3 |
+
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
|
examples/magicdata-read/cosyvoice/run.sh
ADDED
@@ -0,0 +1,111 @@
1 |
+
#!/bin/bash
|
2 |
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
3 |
+
. ./path.sh || exit 1;
|
4 |
+
|
5 |
+
stage=-1
|
6 |
+
stop_stage=3
|
7 |
+
|
8 |
+
data_url=www.openslr.org/resources/68
|
9 |
+
data_dir=/mnt/hengwu.zty/data/tts/openslr/magicdata-read
|
10 |
+
pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
|
11 |
+
|
12 |
+
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
13 |
+
echo "Data Download"
|
14 |
+
for part in dev_set test_set train_set; do
|
15 |
+
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
16 |
+
done
|
17 |
+
fi
|
18 |
+
|
19 |
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
20 |
+
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
21 |
+
for x in dev test train; do
|
22 |
+
mkdir -p data/$x
|
23 |
+
python local/prepare_data.py --src_dir $data_dir/$x --des_dir data/$x
|
24 |
+
done
|
25 |
+
fi
|
26 |
+
|
27 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
28 |
+
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
29 |
+
for x in dev test train; do
|
30 |
+
tools/extract_embedding.py --dir data/$x \
|
31 |
+
--onnx_path $pretrained_model_dir/campplus.onnx
|
32 |
+
done
|
33 |
+
fi
|
34 |
+
|
35 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
36 |
+
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
37 |
+
for x in dev test train; do
|
38 |
+
tools/extract_speech_token.py --dir data/$x \
|
39 |
+
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
40 |
+
done
|
41 |
+
fi
|
42 |
+
|
43 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
44 |
+
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
45 |
+
for x in dev test train; do
|
46 |
+
mkdir -p data/$x/parquet
|
47 |
+
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
48 |
+
--num_processes 10 \
|
49 |
+
--src_dir data/$x \
|
50 |
+
--des_dir data/$x/parquet
|
51 |
+
done
|
52 |
+
fi
|
53 |
+
|
54 |
+
# inference
|
55 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
56 |
+
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
57 |
+
for mode in sft zero_shot; do
|
58 |
+
python cosyvoice/bin/inference.py --mode $mode \
|
59 |
+
--gpu 0 \
|
60 |
+
--config conf/cosyvoice.yaml \
|
61 |
+
--prompt_data data/test/parquet/data.list \
|
62 |
+
--prompt_utt2data data/test/parquet/utt2data.list \
|
63 |
+
--tts_text `pwd`/tts_text.json \
|
64 |
+
--llm_model $pretrained_model_dir/llm.pt \
|
65 |
+
--flow_model $pretrained_model_dir/flow.pt \
|
66 |
+
--hifigan_model $pretrained_model_dir/hift.pt \
|
67 |
+
--result_dir `pwd`/exp/cosyvoice/test/$mode
|
68 |
+
done
|
69 |
+
fi
|
70 |
+
|
71 |
+
# train llm
|
72 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
73 |
+
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
74 |
+
job_id=1986
|
75 |
+
dist_backend="nccl"
|
76 |
+
num_workers=2
|
77 |
+
prefetch=100
|
78 |
+
train_engine=torch_ddp
|
79 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
80 |
+
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
81 |
+
if [ $train_engine == 'deepspeed' ]; then
|
82 |
+
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
83 |
+
fi
|
84 |
+
cp data/train/parquet/data.list data/train.data.list
|
85 |
+
cp data/dev/parquet/data.list data/dev.data.list
|
86 |
+
for model in llm flow; do
|
87 |
+
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
88 |
+
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
89 |
+
cosyvoice/bin/train.py \
|
90 |
+
--train_engine $train_engine \
|
91 |
+
--config conf/cosyvoice.yaml \
|
92 |
+
--train_data data/train.data.list \
|
93 |
+
--cv_data data/dev.data.list \
|
94 |
+
--model $model \
|
95 |
+
--checkpoint $pretrained_model_dir/$model.pt \
|
96 |
+
--model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
|
97 |
+
--tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
|
98 |
+
--ddp.dist_backend $dist_backend \
|
99 |
+
--num_workers ${num_workers} \
|
100 |
+
--prefetch ${prefetch} \
|
101 |
+
--pin_memory \
|
102 |
+
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
+
--deepspeed.save_states model+optimizer
|
104 |
+
done
|
105 |
+
fi
|
106 |
+
|
107 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
108 |
+
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
109 |
+
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
110 |
+
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
111 |
+
fi
|
examples/magicdata-read/cosyvoice/tts_text.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"38_5718_20170915093303": [
|
3 |
+
"我想这出最好歌曲把歌词发到网上请别人帮我作曲急急",
|
4 |
+
"叫他明天早上差五分儿九点去机场"
|
5 |
+
],
|
6 |
+
"38_5721_20170915091235": [
|
7 |
+
"变温室调到零下两度档",
|
8 |
+
"交谈中请勿轻信汇款信息陌生电话请勿使用外挂软件"
|
9 |
+
],
|
10 |
+
"38_5733_20170915130323": [
|
11 |
+
"这是老鹰乐队的一首经典歌曲",
|
12 |
+
"我急用这段音乐我自己找到一段但是有现场杂音"
|
13 |
+
],
|
14 |
+
"38_5836_20170916221414": [
|
15 |
+
"给我播一个陶喆的专辑",
|
16 |
+
"这套餐好贵呀我发这么多短信贵死了"
|
17 |
+
]
|
18 |
+
}
|
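Each key in tts_text.json is a prompt utterance id that must also appear in --prompt_data, and each value is the list of texts to synthesize with that prompt (see the inference stage in run.sh). A minimal sketch of iterating it (illustration only):

import json

with open('tts_text.json', encoding='utf-8') as f:
    tts_text = json.load(f)

for utt, texts in tts_text.items():
    for i, text in enumerate(texts):
        print('{}_{}: {}'.format(utt, i, text))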
runtime/python/fastapi/server.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import argparse
|
17 |
+
import logging
|
18 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
19 |
+
from fastapi import FastAPI, UploadFile, Form, File
|
20 |
+
from fastapi.responses import StreamingResponse
|
21 |
+
from fastapi.middleware.cors import CORSMiddleware
|
22 |
+
import uvicorn
|
23 |
+
import numpy as np
|
24 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
26 |
+
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
27 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
28 |
+
from cosyvoice.utils.file_utils import load_wav
|
29 |
+
|
30 |
+
app = FastAPI()
|
31 |
+
# set cross region allowance
|
32 |
+
app.add_middleware(
|
33 |
+
CORSMiddleware,
|
34 |
+
allow_origins=["*"],
|
35 |
+
allow_credentials=True,
|
36 |
+
allow_methods=["*"],
|
37 |
+
allow_headers=["*"])
|
38 |
+
|
39 |
+
|
40 |
+
def generate_data(model_output):
|
41 |
+
for i in model_output:
|
42 |
+
tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
43 |
+
yield tts_audio
|
44 |
+
|
45 |
+
|
46 |
+
@app.get("/inference_sft")
|
47 |
+
@app.post("/inference_sft")
|
48 |
+
async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
|
49 |
+
model_output = cosyvoice.inference_sft(tts_text, spk_id)
|
50 |
+
return StreamingResponse(generate_data(model_output))
|
51 |
+
|
52 |
+
|
53 |
+
@app.get("/inference_zero_shot")
|
54 |
+
@app.post("/inference_zero_shot")
|
55 |
+
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
|
56 |
+
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
57 |
+
model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
|
58 |
+
return StreamingResponse(generate_data(model_output))
|
59 |
+
|
60 |
+
|
61 |
+
@app.get("/inference_cross_lingual")
|
62 |
+
@app.post("/inference_cross_lingual")
|
63 |
+
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
|
64 |
+
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
65 |
+
model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
|
66 |
+
return StreamingResponse(generate_data(model_output))
|
67 |
+
|
68 |
+
|
69 |
+
@app.get("/inference_instruct")
|
70 |
+
@app.post("/inference_instruct")
|
71 |
+
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
|
72 |
+
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
|
73 |
+
return StreamingResponse(generate_data(model_output))
|
74 |
+
|
75 |
+
@app.get("/inference_instruct2")
|
76 |
+
@app.post("/inference_instruct2")
|
77 |
+
async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()):
|
78 |
+
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
79 |
+
model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k)
|
80 |
+
return StreamingResponse(generate_data(model_output))
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
parser = argparse.ArgumentParser()
|
86 |
+
parser.add_argument('--port',
|
87 |
+
type=int,
|
88 |
+
default=50000)
|
89 |
+
parser.add_argument('--model_dir',
|
90 |
+
type=str,
|
91 |
+
default='iic/CosyVoice-300M',
|
92 |
+
help='local path or modelscope repo id')
|
93 |
+
args = parser.parse_args()
|
94 |
+
try:
|
95 |
+
cosyvoice = CosyVoice(args.model_dir)
|
96 |
+
except Exception:
|
97 |
+
try:
|
98 |
+
cosyvoice = CosyVoice2(args.model_dir)
|
99 |
+
except Exception:
|
100 |
+
raise TypeError('no valid model_type!')
|
101 |
+
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
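The endpoints above stream raw 16-bit mono PCM (see generate_data), so a client has to add a WAV header itself. A minimal sketch of calling /inference_sft with requests, assuming the server runs locally on the default port 50000 and a 22050 Hz output rate (the value the gRPC client below uses as target_sr; other models may use a different rate):

    import wave
    import requests

    url = 'http://127.0.0.1:50000/inference_sft'
    payload = {'tts_text': '你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?', 'spk_id': '中文女'}

    pcm = b''
    with requests.post(url, data=payload, stream=True) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=16000):
            pcm += chunk

    # The response body is headerless int16 PCM; wrap it in a WAV container before saving.
    with wave.open('sft_demo.wav', 'wb') as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)      # 16-bit samples
        wav.setframerate(22050)  # assumed sample rate, adjust to the loaded model
        wav.writeframes(pcm)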
runtime/python/grpc/client.py
ADDED
@@ -0,0 +1,106 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
import logging
import argparse
import torchaudio
import cosyvoice_pb2
import cosyvoice_pb2_grpc
import grpc
import torch
import numpy as np
from cosyvoice.utils.file_utils import load_wav


def main():
    with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel:
        stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
        request = cosyvoice_pb2.Request()
        if args.mode == 'sft':
            logging.info('send sft request')
            sft_request = cosyvoice_pb2.sftRequest()
            sft_request.spk_id = args.spk_id
            sft_request.tts_text = args.tts_text
            request.sft_request.CopyFrom(sft_request)
        elif args.mode == 'zero_shot':
            logging.info('send zero_shot request')
            zero_shot_request = cosyvoice_pb2.zeroshotRequest()
            zero_shot_request.tts_text = args.tts_text
            zero_shot_request.prompt_text = args.prompt_text
            prompt_speech = load_wav(args.prompt_wav, 16000)
            zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
            request.zero_shot_request.CopyFrom(zero_shot_request)
        elif args.mode == 'cross_lingual':
            logging.info('send cross_lingual request')
            cross_lingual_request = cosyvoice_pb2.crosslingualRequest()
            cross_lingual_request.tts_text = args.tts_text
            prompt_speech = load_wav(args.prompt_wav, 16000)
            cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
            request.cross_lingual_request.CopyFrom(cross_lingual_request)
        else:
            logging.info('send instruct request')
            instruct_request = cosyvoice_pb2.instructRequest()
            instruct_request.tts_text = args.tts_text
            instruct_request.spk_id = args.spk_id
            instruct_request.instruct_text = args.instruct_text
            request.instruct_request.CopyFrom(instruct_request)

        response = stub.Inference(request)
        tts_audio = b''
        for r in response:
            tts_audio += r.tts_audio
        tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
        logging.info('save response to {}'.format(args.tts_wav))
        torchaudio.save(args.tts_wav, tts_speech, target_sr)
        logging.info('get response')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--host',
                        type=str,
                        default='0.0.0.0')
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
                        help='request mode')
    parser.add_argument('--tts_text',
                        type=str,
                        default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
    parser.add_argument('--spk_id',
                        type=str,
                        default='中文女')
    parser.add_argument('--prompt_text',
                        type=str,
                        default='希望你以后能够做的比我还好呦。')
    parser.add_argument('--prompt_wav',
                        type=str,
                        default='../../../asset/zero_shot_prompt.wav')
    parser.add_argument('--instruct_text',
                        type=str,
                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
                                 Fights with fervor for justice, but struggles with impulsiveness.')
    parser.add_argument('--tts_wav',
                        type=str,
                        default='demo.wav')
    args = parser.parse_args()
    prompt_sr, target_sr = 16000, 22050
    main()
runtime/python/grpc/server.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from concurrent import futures
import argparse
import cosyvoice_pb2
import cosyvoice_pb2_grpc
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import grpc
import torch
import numpy as np
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2

logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')


class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
    def __init__(self, args):
        try:
            self.cosyvoice = CosyVoice(args.model_dir)
        except Exception:
            try:
                self.cosyvoice = CosyVoice2(args.model_dir)
            except Exception:
                raise TypeError('no valid model_type!')
        logging.info('grpc service initialized')

    def Inference(self, request, context):
        if request.HasField('sft_request'):
            logging.info('get sft inference request')
            model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
        elif request.HasField('zero_shot_request'):
            logging.info('get zero_shot inference request')
            prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
            prompt_speech_16k = prompt_speech_16k.float() / (2**15)
            model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
                                                              request.zero_shot_request.prompt_text,
                                                              prompt_speech_16k)
        elif request.HasField('cross_lingual_request'):
            logging.info('get cross_lingual inference request')
            prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
            prompt_speech_16k = prompt_speech_16k.float() / (2**15)
            model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
        else:
            logging.info('get instruct inference request')
            model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
                                                             request.instruct_request.spk_id,
                                                             request.instruct_request.instruct_text)

        logging.info('send inference response')
        for i in model_output:
            response = cosyvoice_pb2.Response()
            response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
            yield response


def main():
    grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
    grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
    grpcServer.start()
    logging.info("server listening on 0.0.0.0:{}".format(args.port))
    grpcServer.wait_for_termination()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--max_conc',
                        type=int,
                        default=4)
    parser.add_argument('--model_dir',
                        type=str,
                        default='iic/CosyVoice-300M',
                        help='local path or modelscope repo id')
    args = parser.parse_args()
    main()
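Both client.py and server.py import cosyvoice_pb2 and cosyvoice_pb2_grpc, stubs generated from the service's protobuf definition (not shown in this diff). A hedged sketch of regenerating them with grpcio-tools, assuming the definition sits next to the scripts as cosyvoice.proto (the filename is inferred from the module names):

    # Run from runtime/python/grpc/; requires `pip install grpcio-tools`.
    from grpc_tools import protoc

    protoc.main([
        'protoc',               # argv[0] placeholder expected by protoc.main
        '-I.',                  # look for .proto files and imports in the current directory
        '--python_out=.',       # emit cosyvoice_pb2.py (messages)
        '--grpc_python_out=.',  # emit cosyvoice_pb2_grpc.py (stub and servicer)
        'cosyvoice.proto',      # assumed filename of the service definition
    ])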
third_party/Matcha-TTS/.env.example
ADDED
@@ -0,0 +1,6 @@
# example of file for storing private and user specific environment variables, like keys or system paths
# rename it to ".env" (excluded from version control by default)
# .env is loaded by train.py automatically
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}

MY_VAR="/home/user/my/system/path"
third_party/Matcha-TTS/.github/codecov.yml
ADDED
@@ -0,0 +1,15 @@
coverage:
  status:
    # measures overall project coverage
    project:
      default:
        threshold: 100% # how much decrease in coverage is needed to not consider success

    # measures PR or single commit coverage
    patch:
      default:
        threshold: 100% # how much decrease in coverage is needed to not consider success

# project: off
# patch: off
third_party/Matcha-TTS/.gitignore
ADDED
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

### VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
**/.vscode

# JetBrains
.idea/

# Data & Models
*.h5
*.tar
*.tar.gz

# Lightning-Hydra-Template
configs/local/default.yaml
/data/
/logs/
.env

# Aim logging
.aim

# Cython compiled files
matcha/utils/monotonic_align/core.c

# Ignoring hifigan checkpoint
generator_v1
g_02500000
gradio_cached_examples/
synth_output/
third_party/Matcha-TTS/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,59 @@
default_language_version:
  python: python3.10

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      # list of supported hooks: https://pre-commit.com/hooks.html
      - id: trailing-whitespace
      - id: end-of-file-fixer
      # - id: check-docstring-first
      - id: check-yaml
      - id: debug-statements
      - id: detect-private-key
      - id: check-toml
      - id: check-case-conflict
      - id: check-added-large-files

  # python code formatting
  - repo: https://github.com/psf/black
    rev: 23.12.1
    hooks:
      - id: black
        args: [--line-length, "120"]

  # python import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: ["--profile", "black", "--filter-files"]

  # python upgrading syntax to newer version
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.15.0
    hooks:
      - id: pyupgrade
        args: [--py38-plus]

  # python check (PEP8), programming errors and code complexity
  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        args:
          [
            "--max-line-length", "120",
            "--extend-ignore",
            "E203,E402,E501,F401,F841,RST2,RST301",
            "--exclude",
            "logs/*,data/*,matcha/hifigan/*",
          ]
        additional_dependencies: [flake8-rst-docstrings==0.3.0]

  # pylint
  - repo: https://github.com/pycqa/pylint
    rev: v3.0.3
    hooks:
      - id: pylint
third_party/Matcha-TTS/.project-root
ADDED
@@ -0,0 +1,2 @@
# this file is required for inferring the project root directory
# do not delete
third_party/Matcha-TTS/Makefile
ADDED
@@ -0,0 +1,42 @@
help: ## Show help
	@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

clean: ## Clean autogenerated files
	rm -rf dist
	find . -type f -name "*.DS_Store" -ls -delete
	find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
	find . | grep -E ".pytest_cache" | xargs rm -rf
	find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
	rm -f .coverage

clean-logs: ## Clean logs
	rm -rf logs/**

create-package: ## Create wheel and tar gz
	rm -rf dist/
	python setup.py bdist_wheel --plat-name=manylinux1_x86_64
	python setup.py sdist
	python -m twine upload dist/* --verbose --skip-existing

format: ## Run pre-commit hooks
	pre-commit run -a

sync: ## Merge changes from main branch to your current branch
	git pull
	git pull origin main

test: ## Run not slow tests
	pytest -k "not slow"

test-full: ## Run all tests
	pytest

train-ljspeech: ## Train the model
	python matcha/train.py experiment=ljspeech

train-ljspeech-min: ## Train the model with minimum memory
	python matcha/train.py experiment=ljspeech_min_memory

start_app: ## Start the app
	python matcha/app.py
third_party/Matcha-TTS/configs/__init__.py
ADDED
@@ -0,0 +1 @@
# this file is needed here to include configs when building project as a package
third_party/Matcha-TTS/configs/callbacks/default.yaml
ADDED
@@ -0,0 +1,5 @@
defaults:
  - model_checkpoint.yaml
  - model_summary.yaml
  - rich_progress_bar.yaml
  - _self_
third_party/Matcha-TTS/configs/callbacks/model_summary.yaml
ADDED
@@ -0,0 +1,5 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html

model_summary:
  _target_: lightning.pytorch.callbacks.RichModelSummary
  max_depth: 3 # the maximum depth of layer nesting that the summary will include
third_party/Matcha-TTS/configs/debug/fdr.yaml
ADDED
@@ -0,0 +1,9 @@
# @package _global_

# runs 1 train, 1 validation and 1 test step

defaults:
  - default

trainer:
  fast_dev_run: true
third_party/Matcha-TTS/configs/debug/limit.yaml
ADDED
@@ -0,0 +1,12 @@
# @package _global_

# uses only 1% of the training data and 5% of validation/test data

defaults:
  - default

trainer:
  max_epochs: 3
  limit_train_batches: 0.01
  limit_val_batches: 0.05
  limit_test_batches: 0.05
third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml
ADDED
@@ -0,0 +1,14 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=multispeaker

defaults:
  - override /data: hi-fi_en-US_female.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"]

run_name: hi-fi_en-US_female_piper_phonemizer
third_party/Matcha-TTS/configs/hydra/default.yaml
ADDED
@@ -0,0 +1,19 @@
# https://hydra.cc/docs/configure_hydra/intro/

# enable color logging
defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

# output directory, generated dynamically on each run
run:
  dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
  dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
  subdir: ${hydra.job.num}

job_logging:
  handlers:
    file:
      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
      filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
third_party/Matcha-TTS/configs/logger/comet.yaml
ADDED
@@ -0,0 +1,12 @@
# https://www.comet.ml

comet:
  _target_: lightning.pytorch.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
  save_dir: "${paths.output_dir}"
  project_name: "lightning-hydra-template"
  rest_api_key: null
  # experiment_name: ""
  experiment_key: null # set to resume experiment
  offline: False
  prefix: ""
third_party/Matcha-TTS/configs/logger/many_loggers.yaml
ADDED
@@ -0,0 +1,9 @@
# train with many loggers at once

defaults:
  # - comet
  - csv
  # - mlflow
  # - neptune
  - tensorboard
  - wandb
third_party/Matcha-TTS/configs/logger/mlflow.yaml
ADDED
@@ -0,0 +1,12 @@
# https://mlflow.org

mlflow:
  _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
  # experiment_name: ""
  # run_name: ""
  tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
  tags: null
  # save_dir: "./mlruns"
  prefix: ""
  artifact_location: null
  # run_id: ""
third_party/Matcha-TTS/configs/logger/neptune.yaml
ADDED
@@ -0,0 +1,9 @@
# https://neptune.ai

neptune:
  _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
  api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
  project: username/lightning-hydra-template
  # name: ""
  log_model_checkpoints: True
  prefix: ""