kevinwang676 committed on
Commit fd82c69 · verified · 1 Parent(s): 0433746

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50)
  1. .gitattributes +11 -0
  2. 00000309-00000300.wav +3 -0
  3. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.gitattributes +38 -0
  4. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.msc +0 -0
  5. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.mv +1 -0
  6. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/README.md +119 -0
  7. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/__init__.py +0 -0
  8. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/added_tokens.json +3 -0
  9. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/config.json +39 -0
  10. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/generation_config.json +11 -0
  11. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/hf_rwkv_tokenizer.py +279 -0
  12. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/modeling_rwkv7.py +4 -0
  13. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/rwkv_vocab_v20230424.txt +0 -0
  14. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/special_tokens_map.json +6 -0
  15. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/tokenizer_config.json +28 -0
  16. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/README.md +14 -0
  17. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/asset/dingding.png +0 -0
  18. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/campplus.onnx +3 -0
  19. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/configuration.json +1 -0
  20. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/cosyvoice.yaml +116 -0
  21. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/flow.encoder.fp16.zip +3 -0
  22. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/hift.pt +3 -0
  23. CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/spk2info.pt +3 -0
  24. Inference.md +98 -0
  25. LICENSE +201 -0
  26. README.md +181 -3
  27. Trump.wav +3 -0
  28. _config.yml +3 -0
  29. another.wav +3 -0
  30. badXT_71.wav +3 -0
  31. data/cosy/data/data_processor.py +128 -0
  32. data/cosy/test/test_vq.py +171 -0
  33. data/utils/convert_embeddings_2_pt.py +34 -0
  34. data/utils/create_embeddings_from_raw.py +263 -0
  35. data/utils/create_lm_corpus_from_raw.py +156 -0
  36. data/utils/llm_dataset.py +206 -0
  37. data/utils/test_utilities.py +31 -0
  38. data/utils/utilitie.py +767 -0
  39. eval/eval_seed_generate.py +66 -0
  40. gradio/tts_demo_page.py +81 -0
  41. mine.wav +0 -0
  42. new.mp3 +0 -0
  43. new.wav +3 -0
  44. run_multiple_process.sh +137 -0
  45. rwkvtts_requirements.txt +264 -0
  46. third_party/cosyvoice/dataset/processor.py +435 -0
  47. third_party/cosyvoice/flow/decoder.py +301 -0
  48. third_party/cosyvoice/flow/flow.py +239 -0
  49. third_party/cosyvoice/flow/flow_matching.py +217 -0
  50. third_party/cosyvoice/flow/length_regulator.py +69 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Trump.wav filter=lfs diff=lfs merge=lfs -text
37
+ new.wav filter=lfs diff=lfs merge=lfs -text
38
+ badXT_71.wav filter=lfs diff=lfs merge=lfs -text
39
+ zero_shot_prompt.wav filter=lfs diff=lfs merge=lfs -text
40
+ 00000309-00000300.wav filter=lfs diff=lfs merge=lfs -text
41
+ another.wav filter=lfs diff=lfs merge=lfs -text
42
+ zero_2_0.wav filter=lfs diff=lfs merge=lfs -text
43
+ zero_shot_0.wav filter=lfs diff=lfs merge=lfs -text
44
+ zero_1_0.wav filter=lfs diff=lfs merge=lfs -text
45
+ zero_3_0.wav filter=lfs diff=lfs merge=lfs -text
46
+ zero_0_0.wav filter=lfs diff=lfs merge=lfs -text
00000309-00000300.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:631608f5c8b931ece1d45adc7f40a3b3b0ae2ec056a8a08a3565b04cc5750a4b
3
+ size 243244
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.gitattributes ADDED
@@ -0,0 +1,38 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ flow.decoder.estimator.fp16.a10.plan filter=lfs diff=lfs merge=lfs -text
37
+ flow.decoder.estimator.fp16.l20.plan filter=lfs diff=lfs merge=lfs -text
38
+ flow.decoder.estimator.fp16.v100.plan filter=lfs diff=lfs merge=lfs -text
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.msc ADDED
Binary file (1.66 kB).
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.mv ADDED
@@ -0,0 +1 @@
1
+ Revision:master,CreatedAt:1736490687
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/README.md ADDED
@@ -0,0 +1,119 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - zh
6
+ - ja
7
+ - ko
8
+ - fr
9
+ - ar
10
+ - es
11
+ - pt
12
+ metrics:
13
+ - accuracy
14
+ base_model:
15
+ - BlinkDL/rwkv-7-world
16
+ pipeline_tag: text-generation
17
+ ---
18
+
19
+ # rwkv7-1.5B-world
20
+
21
+ <!-- Provide a quick summary of what the model is/does. -->
22
+
23
+ This is the RWKV-7 model in the flash-linear-attention format.
24
+
25
+ ## Model Details
26
+
27
+
28
+ ### Model Description
29
+
30
+ <!-- Provide a longer summary of what this model is. -->
31
+
32
+ - **Developed by:** Bo Peng, Yu Zhang, Songlin Yang, Ruichong Zhang
33
+ - **Funded by:** RWKV Project (Under LF AI & Data Foundation)
34
+ - **Model type:** RWKV7
35
+ - **Language(s) (NLP):** English
36
+ - **License:** Apache-2.0
37
+ - **Parameter count:** 1.52B
38
+ - **Tokenizer:** RWKV World tokenizer
39
+ - **Vocabulary size:** 65,536
40
+
41
+ ### Model Sources
42
+
43
+ <!-- Provide the basic links for the model. -->
44
+
45
+ - **Repository:** https://github.com/fla-org/flash-linear-attention ; https://github.com/BlinkDL/RWKV-LM
46
+ - **Paper:** Work in progress
47
+
48
+ ## Uses
49
+
50
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
51
+ Install `flash-linear-attention` and the latest version of `transformers` before using this model:
52
+
53
+ ```bash
54
+ pip install git+https://github.com/fla-org/flash-linear-attention
55
+ pip install 'transformers>=4.48.0'
56
+ ```
57
+
58
+ ### Direct Use
59
+
60
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
61
+ You can use this model just like any other HuggingFace model:
62
+ ```python
63
+ from transformers import AutoModelForCausalLM, AutoTokenizer
64
+ model = AutoModelForCausalLM.from_pretrained('fla-hub/rwkv7-1.5B-world', trust_remote_code=True)
65
+ tokenizer = AutoTokenizer.from_pretrained('fla-hub/rwkv7-1.5B-world', trust_remote_code=True)
66
+
67
+ model = model.cuda()
68
+ prompt = "What is a large language model?"
69
+ messages = [
70
+ {"role": "user", "content": "Who are you?"},
71
+ {"role": "assistant", "content": "I am a GPT-3 based model."},
72
+ {"role": "user", "content": prompt}
73
+ ]
74
+ text = tokenizer.apply_chat_template(
75
+ messages,
76
+ tokenize=False,
77
+ add_generation_prompt=True
78
+ )
79
+
80
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
81
+
82
+ generated_ids = model.generate(
83
+ **model_inputs,
84
+ max_new_tokens=1024,
85
+ )
86
+ generated_ids = [
87
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
88
+ ]
89
+
90
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
91
+ print(response)
92
+ ```
93
+
94
+ ## Training Details
95
+
96
+ ### Training Data
97
+
98
+ This model is trained on the World v3 dataset with a total of 3.119 trillion tokens.
99
+
100
+ #### Training Hyperparameters
101
+
102
+ - **Training regime:** bfloat16, lr 4e-4 to 1e-5 "delayed" cosine decay, wd 0.1 (with increasing batch sizes during the middle)
103
+ - **Final Loss:** 1.9965
104
+ - **Token Count:** 3.119 trillion
105
+
106
+ ## Evaluation
107
+
108
+ #### Metrics
109
+
110
+ `lambada_openai`:
111
+
112
+ before conversion: ppl 4.13 acc 69.4%
113
+
114
+ after conversion: ppl 4.26 acc 68.8% (without applying the chat template)
115
+
116
+ ## FAQ
117
+ Q: safetensors metadata is none.
118
+
119
+ A: upgrade transformers to >=4.48.0: `pip install 'transformers>=4.48.0'`
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/__init__.py ADDED
File without changes
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<|rwkv_tokenizer_end_of_text|>": 0
3
+ }
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/config.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "a_low_rank_dim": 96,
4
+ "architectures": [
5
+ "RWKV7ForCausalLM"
6
+ ],
7
+ "attn": null,
8
+ "attn_mode": "fused_recurrent",
9
+ "auto_map": {
10
+ "AutoConfig": "modeling_rwkv7.RWKV7Config",
11
+ "AutoModel": "modeling_rwkv7.RWKV7Model",
12
+ "AutoModelForCausalLM": "modeling_rwkv7.RWKV7ForCausalLM"
13
+ },
14
+ "bos_token_id": 1,
15
+ "decay_low_rank_dim": 96,
16
+ "eos_token_id": 2,
17
+ "fuse_cross_entropy": true,
18
+ "fuse_norm": false,
19
+ "gate_low_rank_dim": 256,
20
+ "head_dim": 64,
21
+ "hidden_act": "sqrelu",
22
+ "hidden_ratio": 4.0,
23
+ "hidden_size": 2048,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 8192,
26
+ "max_position_embeddings": 2048,
27
+ "model_type": "rwkv7",
28
+ "norm_bias": true,
29
+ "norm_eps": 1e-05,
30
+ "norm_first": true,
31
+ "num_heads": null,
32
+ "num_hidden_layers": 24,
33
+ "tie_word_embeddings": false,
34
+ "torch_dtype": "float32",
35
+ "transformers_version": "4.48.1",
36
+ "use_cache": true,
37
+ "v_low_rank_dim": 64,
38
+ "vocab_size": 65536
39
+ }
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "bos_token_id": 0,
3
+ "eos_token_id": 0,
4
+ "pad_token_id": 0,
5
+ "max_window_size": 2147483647,
6
+ "do_sample": true,
7
+ "top_k": 65536,
8
+ "top_p": 1.0,
9
+ "temperature": 1.0,
10
+ "transformers_version": "4.48.0"
11
+ }
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/hf_rwkv_tokenizer.py ADDED
@@ -0,0 +1,279 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for RWKV."""
16
+
17
+ import os
18
+ import re
19
+ from typing import TYPE_CHECKING, List, Optional, Tuple
20
+
21
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
22
+ from transformers.utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ pass
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ VOCAB_FILES_NAMES = {
32
+ "vocab_file": "rwkv_vocab_v20230424.txt",
33
+ }
34
+
35
+ class TRIE:
36
+ __slots__ = tuple("ch,to,values,front".split(","))
37
+ to: list
38
+ values: set
39
+
40
+ def __init__(self, front=None, ch=None):
41
+ self.ch = ch
42
+ self.to = [None for ch in range(256)]
43
+ self.values = set()
44
+ self.front = front
45
+
46
+ def __repr__(self):
47
+ fr = self
48
+ ret = []
49
+ while fr != None:
50
+ if fr.ch != None:
51
+ ret.append(fr.ch)
52
+ fr = fr.front
53
+ return "<TRIE %s %s>" % (ret[::-1], self.values)
54
+
55
+ def add(self, key: bytes, idx: int = 0, val=None):
56
+ if idx == len(key):
57
+ if val is None:
58
+ val = key
59
+ self.values.add(val)
60
+ return self
61
+ ch = key[idx]
62
+ if self.to[ch] is None:
63
+ self.to[ch] = TRIE(front=self, ch=ch)
64
+ return self.to[ch].add(key, idx=idx + 1, val=val)
65
+
66
+ def find_longest(self, key: bytes, idx: int = 0):
67
+ u: TRIE = self
68
+ ch: int = key[idx]
69
+
70
+ while u.to[ch] is not None:
71
+ u = u.to[ch]
72
+ idx += 1
73
+ if u.values:
74
+ ret = idx, u, u.values
75
+ if idx == len(key):
76
+ break
77
+ ch = key[idx]
78
+ return ret
79
+
80
+
81
+ class RWKV_TOKENIZER:
82
+ def __init__(self, file_name):
83
+ self.idx2token = {}
84
+ sorted = [] # must be already sorted
85
+ with open(file_name, "r", encoding="utf-8") as f:
86
+ lines = f.readlines()
87
+ for l in lines:
88
+ idx = int(l[: l.index(" ")])
89
+ x = eval(l[l.index(" ") : l.rindex(" ")])
90
+ x = x.encode("utf-8") if isinstance(x, str) else x
91
+ assert isinstance(x, bytes)
92
+
93
+ assert len(x) == int(l[l.rindex(" ") :])
94
+ sorted += [x]
95
+ self.idx2token[idx] = x
96
+
97
+ self.token2idx = {}
98
+ for k, v in self.idx2token.items():
99
+ self.token2idx[v] = int(k)
100
+
101
+ self.root = TRIE()
102
+ for t, i in self.token2idx.items():
103
+ _ = self.root.add(t, val=(t, i))
104
+
105
+ def encodeBytes(self, src: bytes):
106
+ idx: int = 0
107
+ tokens = []
108
+ while idx < len(src):
109
+ _idx: int = idx
110
+ idx, _, values = self.root.find_longest(src, idx)
111
+ assert idx != _idx
112
+ _, token = next(iter(values))
113
+ tokens.append(token)
114
+ return tokens
115
+
116
+ def decodeBytes(self, tokens):
117
+ return b"".join(map(lambda i: self.idx2token[i], tokens))
118
+
119
+ def encode(self, src):
120
+ if isinstance(src, str):
121
+ return [self.encodeBytes(src.encode("utf-8"))]
122
+ elif isinstance(src, list):
123
+ return [self.encodeBytes(s.encode("utf-8")) for s in src]
124
+
125
+ def decode(self, tokens):
126
+ return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
127
+ # try:
128
+ # return self.decodeBytes(tokens).decode('utf-8')
129
+ # except:
130
+ # return '\ufffd' # bad utf-8
131
+
132
+ def printTokens(self, tokens):
133
+ for i in tokens:
134
+ s = self.idx2token[i]
135
+ try:
136
+ s = s.decode("utf-8")
137
+ except:
138
+ pass
139
+ print(f"{repr(s)}{i}", end=" ")
140
+ print()
141
+
142
+
143
+ class RwkvTokenizer(PreTrainedTokenizer):
144
+ vocab_files_names = VOCAB_FILES_NAMES
145
+ model_input_names = ["input_ids", "attention_mask"]
146
+
147
+ def __init__(
148
+ self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
149
+ ):
150
+ if not os.path.isfile(vocab_file):
151
+ raise ValueError(
152
+ f"Can't find a vocabulary file at path '{vocab_file}'."
153
+ )
154
+
155
+ with open(vocab_file, "r", encoding="utf-8") as reader:
156
+ tokens = reader.readlines()
157
+
158
+ if "add_bos_token" in kwargs:
159
+ self.add_bos_token = kwargs["add_bos_token"]
160
+ else:
161
+ self.add_bos_token = False
162
+ self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
163
+ vocab = self.trie_tokenizer.token2idx
164
+ self.encoder = vocab
165
+ self.decoder = {v: k for k, v in vocab.items()}
166
+ self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
167
+ super().__init__(
168
+ bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
169
+ )
170
+
171
+ @property
172
+ def vocab_size(self):
173
+ return len(self.encoder)
174
+
175
+ def get_vocab(self):
176
+ vocab = self.encoder
177
+ vocab.update(self.added_tokens_encoder)
178
+ vocab = dict(sorted(vocab.items(), key=lambda item: item[1]))
179
+ return vocab
180
+
181
+ def _tokenize(self, text, split_special_tokens=False):
182
+ # return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
183
+ return self.trie_tokenizer.encode(text)[0]
184
+
185
+ def _convert_token_to_id(self, token):
186
+ return token
187
+
188
+ def _convert_id_to_token(self, index):
189
+ """Converts an index (integer) in a token (byte) using the vocab."""
190
+ token = self.decoder.get(index, self.unk_token)
191
+ if isinstance(token, (bytes)):
192
+ token = token.decode("utf-8", errors="replace")
193
+ return token
194
+
195
+ def convert_tokens_to_string(self, tokens):
196
+ """Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
197
+ out_string = b"".join(
198
+ [k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
199
+ ).decode("utf-8")
200
+ return out_string
201
+
202
+ def save_vocabulary(
203
+ self, save_directory: str, filename_prefix: Optional[str] = None
204
+ ) -> Tuple[str]:
205
+ index = 0
206
+ if os.path.isdir(save_directory):
207
+ vocab_file = os.path.join(
208
+ save_directory,
209
+ (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
210
+ )
211
+ else:
212
+ vocab_file = (
213
+ filename_prefix + "-" if filename_prefix else ""
214
+ ) + save_directory
215
+ with open(vocab_file, "w", encoding="utf-8") as writer:
216
+ for token, token_index in sorted(
217
+ self.encoder.items(), key=lambda kv: kv[1]
218
+ ):
219
+ if index != token_index:
220
+ logger.warning(
221
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
222
+ " Please check that the vocabulary is not corrupted!"
223
+ )
224
+ index = token_index
225
+ writer.write(str(token) + "\n")
226
+ index += 1
227
+ return (vocab_file,)
228
+
229
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
230
+ if self.add_bos_token:
231
+ bos_token_ids = [self.bos_token_id]
232
+ else:
233
+ bos_token_ids = []
234
+
235
+ output = bos_token_ids + token_ids_0
236
+
237
+ if token_ids_1 is None:
238
+ return output
239
+
240
+ return output + bos_token_ids + token_ids_1
241
+
242
+ def get_special_tokens_mask(
243
+ self,
244
+ token_ids_0: List[int],
245
+ token_ids_1: Optional[List[int]] = None,
246
+ already_has_special_tokens: bool = False,
247
+ ) -> List[int]:
248
+ """
249
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
250
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
251
+
252
+ Args:
253
+ token_ids_0 (`List[int]`):
254
+ List of IDs.
255
+ token_ids_1 (`List[int]`, *optional*):
256
+ Optional second list of IDs for sequence pairs.
257
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
258
+ Whether or not the token list is already formatted with special tokens for the model.
259
+
260
+ Returns:
261
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
262
+ """
263
+ if already_has_special_tokens:
264
+ return super().get_special_tokens_mask(
265
+ token_ids_0=token_ids_0,
266
+ token_ids_1=token_ids_1,
267
+ already_has_special_tokens=True,
268
+ )
269
+
270
+ if not self.add_bos_token:
271
+ return super().get_special_tokens_mask(
272
+ token_ids_0=token_ids_0,
273
+ token_ids_1=token_ids_1,
274
+ already_has_special_tokens=False,
275
+ )
276
+
277
+ if token_ids_1 is None:
278
+ return [1] + ([0] * len(token_ids_0))
279
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/modeling_rwkv7.py ADDED
@@ -0,0 +1,4 @@
1
+ from rwkvfla.models.rwkv7 import RWKV7ForCausalLM, RWKV7Model, RWKV7Config
2
+ RWKV7ForCausalLM = RWKV7ForCausalLM
3
+ RWKV7Model = RWKV7Model
4
+ RWKV7Config = RWKV7Config
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render.
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token": "<|rwkv_tokenizer_end_of_text|>",
3
+ "eos_token": "\n\n",
4
+ "unk_token": "<|rwkv_tokenizer_end_of_text|>",
5
+ "pad_token": "<|rwkv_tokenizer_end_of_text|>"
6
+ }
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/tokenizer_config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<|rwkv_tokenizer_end_of_text|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "auto_map": {
14
+ "AutoTokenizer": [
15
+ "hf_rwkv_tokenizer.RwkvTokenizer",
16
+ null
17
+ ]
18
+ },
19
+ "bos_token": "<|rwkv_tokenizer_end_of_text|>",
20
+ "pad_token": "<|rwkv_tokenizer_end_of_text|>",
21
+ "clean_up_tokenization_spaces": false,
22
+ "eos_token": "\n\n",
23
+ "model_max_length": 1000000000000000019884624838656,
24
+ "tokenizer_class": "RwkvTokenizer",
25
+ "unk_token": "<|rwkv_tokenizer_end_of_text|>",
26
+ "use_fast": false,
27
+ "chat_template": "{{ '<|rwkv_tokenizer_end_of_text|>' }}{% for message in messages %}{% if message['role'] == 'user' %}{{'User: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'system' %}{{'System: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'Assistant: ' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
28
+ }
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ language:
3
+ - zh
4
+ - en
5
+ - ko
6
+ - ja
7
+ base_model:
8
+ - fla-hub/rwkv7-1.5B-world
9
+ pipeline_tag: text-to-speech
10
+ ---
11
+ This is a TTS model that combines Cosy's FSQ with an RWKV language model.
12
+ Please refer to
13
+ https://github.com/yynil/RWKVTTS/blob/main/Inference.md
14
+ to use this checkpoint.
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/asset/dingding.png ADDED
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/campplus.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
3
+ size 28303423
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/configuration.json ADDED
@@ -0,0 +1 @@
1
+ {"framework":"Pytorch","task":"text-to-speech"}
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/cosyvoice.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # set random seed, so that you may reproduce your result.
2
+ __set_seed1: !apply:random.seed [1986]
3
+ __set_seed2: !apply:numpy.random.seed [1986]
4
+ __set_seed3: !apply:torch.manual_seed [1986]
5
+ __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
6
+
7
+ # fixed params
8
+ sample_rate: 24000
9
+ llm_input_size: 2048
10
+ llm_output_size: 2048
11
+ spk_embed_dim: 192
12
+ qwen_pretrain_path: ''
13
+
14
+ # model params
15
+ # for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users may find all corresponding classes/functions according to one single yaml.
16
+ # for system/third_party class/function, we do not require this.
17
+ llm: !new:model.llm.llm.RWKV7LM
18
+ llm_input_size: !ref <llm_input_size>
19
+ llm_output_size: !ref <llm_output_size>
20
+ speech_token_size: 6561
21
+ length_normalized_loss: True
22
+ lsm_weight: 0
23
+ vocab_size: 65548
24
+ llm: !ref <qwen_pretrain_path>
25
+ sampling: !name:cosyvoice.utils.common.ras_sampling
26
+ top_p: 0.8
27
+ top_k: 25
28
+ win_size: 10
29
+ tau_r: 0.1
30
+
31
+ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
32
+ input_size: 512
33
+ output_size: 80
34
+ spk_embed_dim: !ref <spk_embed_dim>
35
+ output_type: 'mel'
36
+ vocab_size: 6561
37
+ input_frame_rate: 25
38
+ only_mask_loss: True
39
+ token_mel_ratio: 2
40
+ pre_lookahead_len: 3
41
+ encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
42
+ output_size: 512
43
+ attention_heads: 8
44
+ linear_units: 2048
45
+ num_blocks: 6
46
+ dropout_rate: 0.1
47
+ positional_dropout_rate: 0.1
48
+ attention_dropout_rate: 0.1
49
+ normalize_before: True
50
+ input_layer: 'linear'
51
+ pos_enc_layer_type: 'rel_pos_espnet'
52
+ selfattention_layer_type: 'rel_selfattn'
53
+ input_size: 512
54
+ use_cnn_module: False
55
+ macaron_style: False
56
+ decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
57
+ in_channels: 240
58
+ n_spks: 1
59
+ spk_emb_dim: 80
60
+ cfm_params: !new:omegaconf.DictConfig
61
+ content:
62
+ sigma_min: 1e-06
63
+ solver: 'euler'
64
+ t_scheduler: 'cosine'
65
+ training_cfg_rate: 0.2
66
+ inference_cfg_rate: 0.7
67
+ reg_loss_type: 'l1'
68
+ estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
69
+ in_channels: 320
70
+ out_channels: 80
71
+ causal: True
72
+ channels: [256]
73
+ dropout: 0.0
74
+ attention_head_dim: 64
75
+ n_blocks: 4
76
+ num_mid_blocks: 12
77
+ num_heads: 8
78
+ act_fn: 'gelu'
79
+
80
+ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
81
+ in_channels: 80
82
+ base_channels: 512
83
+ nb_harmonics: 8
84
+ sampling_rate: !ref <sample_rate>
85
+ nsf_alpha: 0.1
86
+ nsf_sigma: 0.003
87
+ nsf_voiced_threshold: 10
88
+ upsample_rates: [8, 5, 3]
89
+ upsample_kernel_sizes: [16, 11, 7]
90
+ istft_params:
91
+ n_fft: 16
92
+ hop_len: 4
93
+ resblock_kernel_sizes: [3, 7, 11]
94
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
95
+ source_resblock_kernel_sizes: [7, 7, 11]
96
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
97
+ lrelu_slope: 0.1
98
+ audio_limit: 0.99
99
+ f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
100
+ num_class: 1
101
+ in_channels: 80
102
+ cond_channels: 512
103
+
104
+ # processor functions
105
+ get_tokenizer: !name:utils.utilities.get_tokenizer
106
+ model_dir: !ref <qwen_pretrain_path>
107
+ allowed_special: 'all'
108
+ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
109
+ n_fft: 1920
110
+ num_mels: 80
111
+ sampling_rate: !ref <sample_rate>
112
+ hop_size: 480
113
+ win_size: 1920
114
+ fmin: 0
115
+ fmax: 8000
116
+ center: False
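
A quick sanity check of the frame rates implied by the config values above (illustrative arithmetic only, derived from this yaml):

```python
# Values taken from the cosyvoice.yaml above.
sample_rate = 24000      # Hz
hop_size = 480           # audio samples per mel frame
input_frame_rate = 25    # speech tokens per second
token_mel_ratio = 2      # mel frames per speech token

mel_frames_per_second = sample_rate / hop_size                        # 24000 / 480 = 50.0
assert mel_frames_per_second == input_frame_rate * token_mel_ratio    # 25 * 2 = 50
```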
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/flow.encoder.fp16.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46d2539ad8bdb90026cd50cb42e45bd389f10108111d742b912feddca105aeb6
3
+ size 116703414
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/hift.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d4af0d661a416c69544eec83ff9c070dc80c37ee53ef44af3a37d910c95bc21
3
+ size 83364158
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/spk2info.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbc8f9064db35ee8163b538c0f6ed9fe0c3e2fe0f560cca910e578138d961285
3
+ size 3281245
Inference.md ADDED
@@ -0,0 +1,98 @@
1
+ # Install the code base and the dependencies
2
+ ```bash
3
+ git clone https://github.com/yynil/RWKVTTS
4
+ ```
5
+ Add these two directories to the PYTHONPATH
6
+ ```bash
7
+ export PYTHONPATH=$PYTHONPATH:/home/user/RWKVTTS:/home/user/RWKVTTS/third_party
8
+ ```
9
+ # Install the dependencies
10
+ ```bash
11
+ conda create -n rwkvtts-311 -y python=3.11
12
+ conda activate rwkvtts-311
13
+ conda install -y -c conda-forge pynini==2.1.6
14
+ cd RWKVTTS
15
+ pip install -r rwkvtts_requirements.txt
16
+ ```
17
+
18
+ Download the pretrained model from the following link:
19
+ https://huggingface.co/yueyulin/CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO
20
+
21
+ Place CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO in a local directory, say /home/user/CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO.
22
+
23
+ Add the two directories above to the PYTHONPATH as shown earlier.
24
+
25
+ The example code for inference is as follows:
26
+ ```python
27
+ def do_tts(tts_text,prompt_texts,cosyvoice):
28
+ import logging
29
+ for i, (prompt_audio_file, prompt_text) in enumerate(zip(prompt_audios, prompt_texts)):
30
+ logging.info(f'Processing {prompt_text}')
31
+ prompt_speech_16k = load_wav(prompt_audio_file, 16000)
32
+ with torch.no_grad():
33
+ if prompt_text is not None:
34
+ for j, k in enumerate(cosyvoice.inference_zero_shot(tts_text,prompt_text, prompt_speech_16k, stream=False,speed=1)):
35
+ torchaudio.save('zero_{}_{}.wav'.format(i, j), k['tts_speech'], cosyvoice.sample_rate)
36
+ else:
37
+ for j, k in enumerate(cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=False,speed=1)):
38
+ torchaudio.save('zero_{}_{}.wav'.format(i, j), k['tts_speech'], cosyvoice.sample_rate)
39
+ logging.info(f'Finished processing {prompt_text}')
40
+ if __name__ == '__main__':
41
+ from cosyvoice.cli.cosyvoice import CosyVoice2
42
+ import torch
43
+ import sys
44
+ # model_path = '/home/yueyulin/models/CosyVoice2-0.5B_RWKV_0.19B/'
45
+ # device = 'cuda:0'
46
+ print(sys.argv)
47
+ model_path = sys.argv[1]
48
+ device = sys.argv[2] if len(sys.argv) > 2 else 'cuda:0'
49
+ is_flow_only = sys.argv[3]=='True' if len(sys.argv) > 3 else False
50
+ print(f'is_flow_only: {is_flow_only}')
51
+ cosyvoice = CosyVoice2(model_path,device=device,fp16=False,load_jit=False)
52
+
53
+ from cosyvoice.utils.file_utils import load_wav
54
+ import torchaudio
55
+ prompt_audios = [
56
+ '/home/yueyulin/github/RWKVTTS/zero_shot_prompt.wav',
57
+ '/home/yueyulin/github/RWKVTTS/mine.wav',
58
+ '/home/yueyulin/github/RWKVTTS/new.wav',
59
+ '/home/yueyulin/github/RWKVTTS/Trump.wav',
60
+ ]
61
+
62
+ if not is_flow_only:
63
+ prompt_texts = [
64
+ '希望你以后做的比我还好呦。',
65
+ '少年强则中国强。',
66
+ '我随便说一句话,我喊开始录就开始录。',
67
+ 'numbers of Latino, African American, Asian American and native American voters.'
68
+ ]
69
+ else:
70
+ prompt_texts = [
71
+ None,
72
+ None,
73
+ None,
74
+ None
75
+ ]
76
+ do_tts('Make America great again!',prompt_texts,cosyvoice)
77
+ ```
78
+ More examples can be found in the model/test directory.
79
+
80
+ [Instruct example](model/test/test_instructed.py) shows how to use the instructed voice flow to generate audio.
81
+ [Embedded ref voice example](model/test/test_speaker_adapter.py) shows how to use the speaker adapter to generate audio.
82
+
83
+ Please refer to the [Service Call URL](service/README.md) for the instructions and reference voices.
84
+
85
+ If you pass prompt_texts as None, the engine only clones the voice flow and texture, which works well for cross-lingual voice cloning. If you pass the correct prompt texts, the engine tries to continue the audio tokens following the prompt audio you provided; this is good for continuing the prompt audio, but it can sound odd when you mix languages.
86
+
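For clarity, here is a minimal sketch of the two call modes described above, using the same `CosyVoice2` API as the example; the paths are placeholders:

```python
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio

cosyvoice = CosyVoice2('/path/to/CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO',
                       device='cuda:0', fp16=False, load_jit=False)
prompt_speech_16k = load_wav('/path/to/reference.wav', 16000)

# Cross-lingual cloning: no prompt text, only the voice of the reference audio is cloned.
for j, out in enumerate(cosyvoice.inference_cross_lingual('Make America great again!',
                                                          prompt_speech_16k, stream=False, speed=1)):
    torchaudio.save(f'cross_lingual_{j}.wav', out['tts_speech'], cosyvoice.sample_rate)

# Zero-shot continuation: the prompt text should exactly transcribe the reference audio.
for j, out in enumerate(cosyvoice.inference_zero_shot('中国在东亚,是世界上最大的国家。',
                                                      '希望你以后做的比我还好呦。',
                                                      prompt_speech_16k, stream=False, speed=1)):
    torchaudio.save(f'zero_shot_{j}.wav', out['tts_speech'], cosyvoice.sample_rate)
```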
87
+ The test source code is [test code](model/test/test_initialize.py).
88
+
89
+ Please change the paths to the correct paths in your system.
90
+
91
+ You can also use your own prompt audio and text. Since the llm module continues the audio tokens for you, please make sure the audio is clean and complete and the text is correct; otherwise, the result may not be good.
92
+
93
+ The following table shows the example results of the above code:
94
+ | Prompt Audio | Prompt Text | TTS Text | Result |
95
+ | --- | --- | --- | --- |
96
+ | https://github.com/yynil/RWKVTTS/raw/main/zero_shot_prompt.wav | 希望你以后做的比我还好呦。 | 中国在东亚,是世界上最大的国家,也是世界上人口最多的国家。 | https://github.com/yynil/RWKVTTS/raw/main/zero_0_0.wav |
97
+ | https://github.com/yynil/RWKVTTS/raw/main/mine.wav| 少年强则中国强。 | 中国在东亚,是世界上最大的国家,也是世界上人口最多的国家。 | https://github.com/yynil/RWKVTTS/raw/main/zero_1_0.wav |
98
+ | https://github.com/yynil/RWKVTTS/raw/main/new.wav | 我随便说一句话,我喊开始录就开始录。 | 中国在东亚,是世界上最大的国家,也是世界上人口最多的国家。 | https://github.com/yynil/RWKVTTS/raw/main/zero_2_0.wav |
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,181 @@
1
- ---
2
- license: mit
3
- ---
1
+ # RWKVTTS
2
+ This project trains an RWKV LLM for TTS generation that is compatible with other TTS engines (like Fish, Cosy and ChatTTS).
3
+
4
+ For most modern LLM-based TTS engines, there are two parts:
5
+ 1. VQ-VAE: this model encodes audio into audio tokens and decodes audio tokens back to audio.
6
+ 2. LLM: this model generates audio tokens from text tokens and the prompt audio tokens. The prompt audio tokens also come from the VQ-VAE.
7
+
8
+ Typically, training an LLM-based TTS system involves both VQ-VAE training and LLM training, as in CosyTTS, ChatTTS and FishTTS. Here, however, we focus on training an RWKV LLM to replace the LLM part in these TTS engines.
9
+
10
+ ```mermaid
11
+ flowchart TB
12
+ node_1[["Input Prompt Text"]]
13
+ node_2(["Text Tokenizer"])
14
+ node_3(["Audio Tokenizer(VQ)"])
15
+ node_4[["Input Reference Audio"]]
16
+ node_5[["Text Tokens"]]
17
+ node_6[["Audio Tokens"]]
18
+ node_7(["Text Embedder"])
19
+ node_8(["Audio Embedder"])
20
+ node_9[["Text Embeddings"]]
21
+ node_10[["Audio Embeddings"]]
22
+ node_11(["Concatenate Embeddings"])
23
+ node_12[["Input Embeddings"]]
24
+ node_13{{"Language Model"}}
25
+ node_14[["Hidden States"]]
26
+ node_15(["Audio Head"])
27
+ node_16{"Continue to decode?"}
28
+ node_17(["Next Step Input"])
29
+ node_18(["Finish Decode"])
30
+ node_1 --> node_2
31
+ node_4 --> node_3
32
+ node_2 --> node_5
33
+ node_3 --> node_6
34
+ node_5 --> node_7
35
+ node_6 --> node_8
36
+ node_7 --> node_9
37
+ node_8 --> node_10
38
+ node_9 --> node_11
39
+ node_10 --> node_11
40
+ node_11 --> node_12
41
+ node_12 --> node_13
42
+ node_13 --> node_14
43
+ node_14 --> node_15
44
+ node_15 --> node_16
45
+ node_16 --"Yes"--> node_17
46
+ node_17 --> node_13
47
+ node_16 --"No"--> node_18
48
+ ```
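
The sketch below illustrates the decode loop in the flowchart above; the module and function names (`text_embedder`, `audio_embedder`, `lm`, `audio_head`, `generate_audio_tokens`) are hypothetical placeholders, not the actual RWKVTTS API:

```python
import torch
import torch.nn as nn

def generate_audio_tokens(text_tokens, prompt_audio_tokens,
                          text_embedder: nn.Embedding, audio_embedder: nn.Embedding,
                          lm: nn.Module, audio_head: nn.Linear,
                          eos_id: int, max_steps: int = 2048):
    """Greedy decode loop corresponding to the flowchart above (illustrative only)."""
    # Concatenate text embeddings and prompt-audio embeddings into one input sequence.
    inputs = torch.cat([text_embedder(text_tokens), audio_embedder(prompt_audio_tokens)], dim=0)
    generated = []
    for _ in range(max_steps):
        hidden = lm(inputs.unsqueeze(0))[0, -1]        # hidden state at the last position
        next_token = audio_head(hidden).argmax(-1)     # pick the next audio token
        if next_token.item() == eos_id:                # "Continue to decode?" -> No
            break
        generated.append(next_token.item())
        # Next-step input: append the embedding of the token just generated.
        inputs = torch.cat([inputs, audio_embedder(next_token.view(1))], dim=0)
    return generated  # audio tokens, to be decoded back to a waveform by the VQ-VAE
```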
49
+
50
+ Different TTS engines may use different data layouts and different special control tokens, so we need to prepare different data and train an RWKV LLM for each TTS engine.
51
+
52
+ # Process to train LLM for different TTS engine
53
+
54
+ ## Cosy 2.0
55
+
56
+ ### Cosy 2.0 Data Layout
57
+
58
+ The layout of Cosy 2.0 for LLM:
59
+
60
+ ```mermaid
61
+
62
+ flowchart LR
63
+ node_1[["SOS Embeddings"]]
64
+ node_2[["Text Embeddings"]]
65
+ node_3[["Task ID Embeddings"]]
66
+ node_4[["Audio Embeddings"]]
67
+ node_5[["Last Audio Embeddings"]]
68
+ node_1 --- node_2
69
+ node_2 --- node_3
70
+ node_3 --> node_4
71
+ node_4 --> node_5
72
+
73
+ ```
74
+
75
+ The forward pass of the LLM for Cosy 2.0:
76
+ ```mermaid
77
+ graph TD
78
+ A[Input: batch] --> B[Extract tokens and lengths]
79
+ B --> C1[Prepare LLM Target]
80
+ B --> C2[Encode Text Tokens]
81
+ B --> C3[Generate SOS/EOS and Task ID Embeddings]
82
+ B --> C4[Encode Speech Tokens]
83
+
84
+ C1[Prepare LLM Target] --> D1["Create target sequence for each sample<br>[IGNORE_ID, ..., speech_tokens, EOS]"]
85
+ D1 --> D2[Pad and move target to device]
86
+
87
+ C2[Encode Text Tokens] --> E1[Apply text_embedding layer]
88
+
89
+ C3[Generate SOS/EOS and Task ID Embeddings] --> F1[Get SOS/EOS embeddings from llm_embedding]
90
+ C3 --> F2[Get task_id embeddings from llm_embedding]
91
+
92
+ C4[Encode Speech Tokens] --> G1[Apply speech_embedding layer]
93
+
94
+ E1 --> H[Unpad and pad sequence]
95
+ F1 --> H
96
+ F2 --> H
97
+ G1 --> H
98
+
99
+ H --> I1[Generate LM input]
100
+ H --> I2[Create attention mask]
101
+
102
+ I1 --> J[Run LLM forward pass]
103
+ I2 --> J
104
+
105
+ J --> K[Extract hidden states]
106
+ K --> L[Generate logits through llm_decoder]
107
+
108
+ D2 --> M[Compute loss and accuracy]
109
+ L --> M
110
+
111
+ M --> N[Return loss and accuracy]
112
+ ```
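
Below is a minimal sketch of the input/target construction described by the two diagrams above; `IGNORE_ID`, the `[SOS] text [task_id] audio` layout and the EOS target follow the diagrams, while the helper itself and its argument names are illustrative:

```python
import torch

IGNORE_ID = -100  # target positions with this id are excluded from the cross-entropy loss

def build_lm_input_and_target(sos_emb, text_emb, task_id_emb, speech_emb,
                              speech_tokens, eos_token_id):
    # LM input: [SOS] + text embeddings + task-id embedding + audio embeddings.
    lm_input = torch.cat([sos_emb, text_emb, task_id_emb, speech_emb], dim=0)
    # LM target: ignore the prefix, then predict the speech tokens followed by EOS.
    prefix_len = sos_emb.size(0) + text_emb.size(0) + task_id_emb.size(0)
    lm_target = torch.cat([
        torch.full((prefix_len - 1,), IGNORE_ID, dtype=torch.long),
        speech_tokens.long(),
        torch.tensor([eos_token_id], dtype=torch.long),
    ])  # aligned so that lm_target[t] is the token predicted from position t of lm_input
    return lm_input, lm_target
```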
113
+
114
+ There are some points to note for Cosy 2.0:
115
+ 1. The prompt audio tokens act as the reference audio; the LLM will generate audio tokens that mimic the reference audio.
116
+ 2. '<|endofprompt|>' is used in the prompt text; it is a special token indicating that the prompt is an instruction.
117
+
118
+ ### Cosy 2.0 Data Preparation
119
+
120
+ 1. Download the reference audio files from https://huggingface.co/datasets/yueyulin/TTS_Reference and put them in the folder $REF_AUDIO_DIR. These audio files are used to generate audio tokens.
121
+ 2. Download the Cosy 2.0-0.5B model from https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B and put it in the folder $MODEL_DIR.
122
+ 3. Clone the Cosy 2.0 repo from https://github.com/yynil/CosyVoice and follow its instructions to install the environment. In this repository, I changed the code to let the user specify the CUDA device for multi-process generation. If you have installed torch 2.6, remember to force triton to downgrade to 3.1.0.
123
+ 4. Prepare the text data for the audio-token training dataset. Currently we support parquet and jsonl files; the text field is the only required field in the data file (see the jsonl sketch below). I downloaded the Chinese and English parquet files from [wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).
124
+ 5. Generate the audio tokens using the following command:
125
+ ```bash
126
+ bash run_multiple_process.sh --parquet_files /home/yueyulin/data/wiki/zh/train-00000-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00001-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00002-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00003-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00004-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00005-of-00006.parquet --language zh --prompts_dir extract_data/prompts/zh --device cuda:0 --output_dir /home/yueyulin/data/speech_corpus
127
+ ```
128
+ The prompts_dir is the $REF_AUDIO_DIR, the parquet_files are the files downloaded from wikimedia, and each file is processed by one process. In my experience, one 4090 can process 6 files at the same time. The output_dir is the directory where the generated audio tokens are saved.
129
+
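As a reminder of the expected input format for step 4 above, here is a minimal sketch of building a jsonl input file; only the `text` field is required, and the file name is a placeholder:

```python
import json

samples = [
    {"text": "中国在东亚,是世界上最大的国家,也是世界上人口最多的国家。"},
    {"text": "The quick brown fox jumps over the lazy dog."},
]
with open("my_corpus.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:  # one JSON object per line
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```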
130
+
131
+ ### Cosy 2.0 LLM Training
132
+ After the data is generated and saved, we will get JSONL files like:
133
+ ```json
134
+ {"text": "甄别重点监测企业是确保监测数据全面性和代表性的基础。首先,需要根据预警机制的覆盖范围和目标,明确监测企业的选择标准。选择标准可以包括企业规模、市场份额、行业影响力等。其次,通过企业调查、行业协会推荐等方式,初步筛选出符合条件的潜在监测企业。", "tts_speech_tokens": [2031, 4137, 6405, 6405, 6405, 6405, 6405, 6324, 6324, 6324, 6324, 6324, 6324, 4218, 1761, 4509, 2333, 4483, 5934, 6258, 1929, 3482, 314, 2300, 957, 5163, 6309, 5064, 6425, 3992, 1932, 80, 305, 734, 1479, 5650, 2472, 4778, 4487, 6175, 5667, 5373, 2187, 4851, 137, 141, 4919, 4407, 2436, 1295, 2024, 1294, 4940, 4778, 2330, 764, 1762, 2031, 1788, 5943, 5319, 5238, 5338, 3872, 1614, 4920, 6055, 6027, 3084, 5343, 4605, 2330, 218, 2172, 572, 1949, 1331, 865, 4921, 2472, 4688, 4379, 5850, 6342, 6373, 2997, 2529, 5087, 623, 3700, 6292, 6291, 5823, 5830, 2102, 1041, 6225, 6316, 3887, 889, 5487, 3813, 1626, 953, 734, 909, 4314, 4804, 4821, 4463, 23, 4683, 4678, 2724, 4832, 992, 1238, 2673, 324, 2099, 2486, 135, 2001, 4537, 5271, 2519, 957, 1699, 953, 1304, 1028, 4752, 2553, 5560, 4154, 1287, 59, 879, 4921, 2499, 5748, 5019, 240, 5889, 6264, 4293, 2186, 2105, 2005, 6405, 6405, 6324, 6324, 6324, 4137, 4218, 3651, 6048, 3132, 1433, 1457, 3962, 4515, 2482, 4490, 4561, 4669, 6054, 6270, 6316, 4615, 4781, 575, 632, 2031, 183, 4598, 4479, 6181, 5496, 4128, 3887, 1943, 1861, 6288, 5343, 6072, 3319, 2733, 322, 1187, 1727, 1807, 4921, 4677, 5668, 5019, 2427, 2976, 6066, 5332, 63, 73, 380, 4239, 6534, 6543, 5101, 1452, 213, 5921, 2273, 6453, 4347, 4537, 4459, 11, 2124, 866, 386, 485, 2511, 333, 632, 4317, 5772, 5803, 1457, 2163, 889, 5021, 2381, 5675, 5056, 5092, 1951, 3888, 3645, 4218, 6405, 6324, 4137, 1884, 1646, 2726, 377, 3992, 5529, 2481, 6054, 3822, 5340, 2330, 71, 2733, 2499, 5012, 4463, 5850, 6342, 6373, 2268, 4851, 137, 151, 4921, 4435, 4650, 528, 1295, 1295, 2023, 2753, 4850, 4570, 2243, 1047, 56, 113, 4512, 5568, 1662, 971, 5, 1480, 6387, 1045, 65, 460, 2160, 5102, 4568, 5056, 5098, 1602, 6048, 4367, 956, 59, 1524, 6405, 6405, 6324, 6324, 6324, 6324, 6324, 4137, 2031, 2706, 5325, 1653, 3887, 2219, 3667, 5664, 803, 4592, 2163, 5587, 4598, 5026, 5089, 1692, 5976, 1937, 146, 41, 1507, 1950, 2031, 0, 2349, 343, 4607, 5019, 566, 1683, 2166, 5051, 5678, 5057, 5830, 573, 2835, 2856, 5099, 707, 947, 1113, 4675, 4408, 4623, 1294, 2024, 2023, 3481, 4778, 2411, 1208, 1302, 660, 5827, 5345, 5074, 4560, 6501, 1403, 635, 716, 680, 5057, 4970, 1947, 3645, 1458, 1707, 6024, 6049, 5238, 5340, 1696, 5244, 1468, 1946, 509, 1318, 6534, 2800, 4510, 2234, 1991, 2017, 2018, 1370, 470, 2891, 4997, 1972, 1701, 5832, 1458, 1950, 4860, 5589, 1946, 1949, 509, 5369, 4966, 5019, 4849, 2411, 314, 1293, 1267, 377, 6421, 4800, 4416, 4893, 8, 1946, 1967, 1584, 4615, 5019, 2510, 867, 63, 245, 533, 1991, 4218, 6405, 6405, 6324, 6324, 6324, 6324, 6324, 4137, 1950, 4920, 4516, 276, 2024, 4777, 4194, 6373, 5643, 4851, 4448, 65, 1517, 1978, 4218, 6405, 4218, 2112, 1350, 4860, 5074, 5772, 6262, 672, 5097, 5090, 221, 1032, 4675, 4408, 285, 1295, 1294, 557, 4490, 228, 276, 4858, 4807, 2870, 1675, 6051, 1539, 4141, 1946, 4133, 6320, 4699, 982, 1950, 5832, 5835, 3645, 1947, 5589, 5589, 4136, 1946, 1235, 4642, 4993, 4857, 4598, 62, 4431, 4675, 285, 1043, 314, 2414, 2760, 2850, 5094, 3158, 1214, 1032, 2997, 2763, 5345, 5100, 402, 4677, 4857, 4543, 5, 1482, 2004, 56, 515, 1970, 2077, 6534, 3488, 5591, 5690, 5869, 5319, 2331, 5342, 1688, 1679, 1735, 4218, 6324, 6324, 6405, 4218, 2031, 5886, 6291, 6480, 2883, 5829, 5826, 2175, 5799, 5826, 2186, 2183, 5940, 5322, 120, 5918, 4571, 4687, 3813, 962, 737, 1561, 5886, 4077, 1429, 5831, 6560, 3644, 6429, 6507, 6534, 2101, 2186, 5097, 2682, 2673, 2017, 
2576, 4594, 1005, 4785, 2760, 854, 1946, 683, 4844, 2733, 4695, 4840, 2192, 1482, 72, 29, 788, 1761, 4921, 4408, 2517, 566, 35, 2192, 5934, 4209, 5652, 4537, 5920, 278, 160, 3462, 4686, 5021, 4490, 5853, 3912, 6374, 2997, 4716, 2567, 140, 3462, 4435, 2436, 1295, 1295, 2023, 3482, 4769, 4598, 89, 1736, 4218, 6405, 6405, 6324, 6324, 4137], "prompt_text": "那么就在两侧的象限同时忙碌。", "llm_prompt_speech_token": [3686, 6324, 4137, 1959, 3666, 4376, 2836, 2127, 578, 2441, 1041, 2337, 6073, 3560, 1369, 5650, 4691, 5192, 2924, 89, 1687, 1539, 4218, 1848, 160, 4760, 2825, 1463, 1946, 1223, 1313, 2067, 5648, 2997, 2268, 2277, 4842, 4763, 308, 1038, 140, 842, 2983, 4672, 4650, 4696, 5995, 5603, 1238, 1238, 4672, 4650, 4777, 2474, 8, 767, 1731, 4299, 2079, 4941, 4947, 665, 719, 4319, 6424, 5067, 5967, 6048, 5967, 5238, 1523, 3875, 3872, 4314, 661, 1946, 1217, 500, 6422, 1506, 4852, 5831, 1457, 1448]}
135
+ {"text": "Once all the Cabinet and Cabinet-level officers have been invested, the act of their investiture usually ends with a \"family photo\" of the new Administration around the new president and vice-president. For this photo, the new ministers' alignment and proximity to the president is dictated by the order of precedence, with the ministers who head older departments standing in the first row, and the heads of the newer departments standing in the back rows. Some departments, such as the Department of Defence, take precedence from prior departments now abolished.", "tts_speech_tokens": [764, 35, 1896, 4299, 6486, 4299, 4299, 4299, 4218, 651, 2112, 2131, 1403, 2792, 2207, 1725, 5401, 281, 575, 683, 4997, 3474, 4492, 195, 87, 5109, 5846, 6077, 2270, 2172, 3828, 4424, 4543, 1520, 1753, 6258, 4075, 141, 5109, 5845, 3647, 1188, 3987, 3750, 4414, 1516, 4180, 5014, 5348, 1441, 6534, 5075, 5100, 1274, 1301, 3569, 3488, 3996, 6183, 4752, 4919, 2328, 3158, 6071, 5264, 5482, 5403, 5844, 5837, 191, 2139, 1839, 2255, 831, 4508, 4576, 6255, 1857, 29, 2, 2228, 5482, 6459, 2004, 2253, 2267, 2255, 885, 2112, 1788, 5916, 5835, 5919, 5919, 5919, 4056, 4299, 2058, 2982, 1295, 305, 1463, 3647, 2383, 2112, 3054, 4603, 3043, 4272, 2260, 4841, 6029, 6062, 5329, 6256, 6465, 2386, 2921, 2204, 4429, 5647, 2085, 2490, 809, 159, 546, 5325, 5298, 917, 1688, 3863, 3872, 3884, 3481, 3480, 4130, 5993, 5979, 5322, 5257, 5634, 4691, 4533, 5100, 1277, 764, 5111, 5, 47, 3748, 4929, 2376, 3583, 2990, 6456, 2232, 2306, 6507, 6210, 4463, 5840, 2270, 4071, 5693, 4663, 5100, 5226, 6510, 6534, 2900, 2567, 137, 882, 1199, 2831, 632, 389, 4251, 4191, 73, 49, 3831, 404, 971, 4853, 4613, 4074, 4314, 2417, 3750, 4507, 4416, 4594, 3624, 5325, 962, 224, 404, 5295, 4596, 2238, 3670, 3848, 4339, 1676, 812, 2441, 6097, 3934, 2261, 3750, 1564, 3401, 6074, 5823, 1383, 4293, 3816, 3734, 2219, 4450, 5482, 2996, 150, 3063, 143, 3019, 3667, 149, 3748, 4278, 4347, 3485, 5270, 4858, 5239, 2568, 2028, 4050, 3011, 32, 2264, 4672, 2991, 888, 804, 149, 2234, 5934, 1744, 2112, 3975, 5916, 5943, 5919, 5943, 5919, 5946, 5916, 3972, 4299, 6402, 6534, 1927, 140, 1038, 2263, 4567, 4413, 5563, 4672, 3999, 6264, 4826, 2810, 2567, 228, 227, 2324, 2504, 1773, 6375, 77, 3831, 754, 3401, 4612, 6498, 4311, 2411, 831, 2255, 4414, 5320, 4920, 2328, 5345, 5169, 4752, 4763, 5014, 6449, 2687, 3413, 3647, 2276, 3670, 4069, 1883, 2330, 4499, 1525, 1762, 1490, 2921, 1639, 2166, 4050, 4304, 2837, 732, 6049, 5405, 2266, 910, 4315, 2399, 798, 4859, 4857, 1923, 4434, 4485, 5152, 4206, 4447, 1917, 2136, 3807, 3740, 5, 2264, 5166, 5409, 806, 2982, 878, 2258, 860, 1525, 1762, 3320, 5169, 2166, 546, 2994, 4526, 4056, 2112, 60, 2274, 2528, 5084, 231, 4450, 4597, 1938, 2163, 650, 5108, 2335, 4188, 4859, 1760, 2096, 2903, 4349, 1684, 873, 3872, 6059, 6058, 5976, 4299, 2136, 4050, 3740, 2, 4432, 6455, 2226, 886, 3063, 881, 71, 2234, 5937, 5650, 5238, 4296, 1422, 2342, 2139, 3462, 2261, 1641, 4314, 230, 186, 2965, 4523, 4509, 4999, 4839, 5345, 6070, 5263, 4839, 3813, 3018, 5825, 2926, 5106, 2924, 194, 147, 1433, 728, 2915, 477, 2325, 5330, 6070, 1527, 2421, 2166, 3564, 6166, 1865, 1676, 2092, 4068, 2255, 1483, 5658, 5726, 2085, 3219, 71, 35, 2219, 3828, 2210, 5047, 6100, 4526, 2934, 3909, 4511, 6453, 6534, 3367, 3863, 3146, 5241, 5323, 6054, 1872, 3881, 947, 380, 632, 2909, 2884, 4296, 5913, 5835, 5919, 5919, 5919, 5838, 3975, 2112, 3648, 2192, 831, 3906, 2222, 5118, 5111, 4487, 879, 5650, 4422, 5256, 6465, 4446, 4522, 3831, 2294, 5588, 5825, 3377, 6050, 1698, 147, 1920, 1404, 
6328, 1622, 1676, 2083, 2124, 2336, 3669, 5402, 4269, 2490, 71, 8, 113, 1563, 395, 4238, 2510, 3016, 3936, 4430, 2163, 461, 5192, 5998, 5272, 1869, 651, 4302, 1685, 221, 380, 389, 803, 5412, 4753, 2244, 2028, 3648, 3729, 5916, 5919, 5916, 3732, 3975, 2112, 3894, 5239, 5648, 2250, 2918, 4807, 6258, 879, 4600, 2166, 3483, 6327, 6239, 1652, 1757, 1881, 128, 2264, 5935, 5631, 5729, 5482, 2198, 2309, 1329, 4756, 2263, 4448, 4437, 6454, 4272, 3465, 157, 66, 954, 2166, 5598, 3980, 3836, 1838, 2064, 4069, 2371, 2938, 4565, 4356, 789, 4612, 5940, 6510, 3270, 5, 737, 8, 2234, 3747, 5650, 5482, 4269, 303, 2193, 2447, 4849, 2112, 2085, 4050, 3739, 2192, 4428, 5486, 2253, 885, 2992, 2249, 5205, 3453, 4672, 6186, 6534, 6059, 4068, 2184, 4320, 3978, 4052, 1622, 926, 3140, 231, 157, 2160, 1404, 6084, 3809, 1598, 2092, 6255, 2234, 3750, 5405, 3459, 3669, 23, 1463, 974, 2675, 2891, 2166, 712, 5030, 5023, 5080, 2741, 308, 32, 2203, 5217, 4593, 1437, 303, 2112, 3975], "prompt_text": " So I am gonna do this right now. So let's do it.", "llm_prompt_speech_token": [1822, 5727, 5000, 930, 5015, 2912, 3616, 692, 1250, 1978, 4214, 3485, 2036, 1298, 2918, 5192, 5056, 5074, 5065, 4813, 3005, 3002, 3313, 4238, 795, 4523, 4520, 3038, 4496, 859, 1887, 2490, 3309, 6235, 5264, 6074, 6047, 5339, 5474, 4291, 2915, 2666, 3759, 4056, 4299, 3975, 6159, 6186, 6186, 6186, 5838, 5109, 3732, 2112, 2139, 3945, 4534, 4569, 4575, 6453, 5405, 4461, 4338, 5572, 3809, 2411, 1214, 1205, 3805, 4526, 4379, 2189, 3890, 3242, 1418, 2876, 5828, 2799, 5133, 5563, 5481, 2325, 155, 533, 2801, 3617, 725, 56, 4385, 834, 3444, 5482, 3273, 2166, 2328, 1908, 1372, 868]}
136
+ ```
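+ 
+ Each line above is a standalone JSON record with `text`, `tts_speech_tokens`, `prompt_text`, and `llm_prompt_speech_token` fields. As a minimal sketch of how one record is consumed (mirroring the concatenation done in `data/utils/llm_dataset.py`; the file path below is a placeholder):
+ 
+ ```python
+ import json
+ 
+ # Read a single JSONL record and assemble the text / speech token sequences.
+ with open('speech_corpus/sample_tokens.jsonl', 'r', encoding='utf-8') as f:  # placeholder path
+     record = json.loads(f.readline())
+ 
+ text = record['prompt_text'] + record['text']                               # prompt text followed by target text
+ speech = record['llm_prompt_speech_token'] + record['tts_speech_tokens']    # prompt speech tokens followed by target speech tokens
+ print(len(speech), 'speech tokens for:', text[:30])
+ ```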
137
+
138
+ We use DeepSpeed to train the model:
139
+ ```bash
140
+ deepspeed --num_nodes 1 --num_gpus 4 train_scripts/train_llm.py --data_file /external_data/yueyudata/speech_corpus/ --model_name /external_data/models/rwkv7-1.5B-world/ --output_dir /external_data/yueyudata/cosy_voice_llm --max_length 2048 --wandb_project toy_cosy_llm --wandb_run_name server2_rwkv_7_1.5B --ds_param_offload True --ds_optimizer_offload True --ds_stage 2 --gradient_checkpointing True --logging_steps 10 --per_device_train_batch_size 8
141
+ ```
142
+ The base models can be downloaded from https://huggingface.co/collections/fla-hub/rwkv7-6790fd37b4b6137b088a0d8a ; choose the model size that fits your training setup.
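+ 
+ For example, a checkpoint can be fetched locally with `huggingface_hub` (a sketch; the repo id below is an assumption, pick whichever size you need from the collection):
+ 
+ ```python
+ from huggingface_hub import snapshot_download
+ 
+ # Repo id is illustrative; browse the fla-hub rwkv7 collection and pick the size you need.
+ snapshot_download(repo_id='fla-hub/rwkv7-1.5B-world', local_dir='/external_data/models/rwkv7-1.5B-world')
+ ```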
143
+
144
+
145
+ ### Cosy 2.0 LLM Inference
146
+
147
+ ### Some samples
148
+
149
+ #### Zero-shot inference
151
+ prompt audio:
151
+ [prompt audio](mine.wav)
152
+
153
+ prompt text: "今天天气挺不错的。"
154
+
155
+ tts text: "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
156
+
157
+ tts audio:
158
+ [tts audio](zero_shot_0.wav)
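+ 
+ For reference, a minimal sketch of how a zero-shot sample like the one above can be produced with the `CosyVoice2` API used elsewhere in this repo (the model path is a placeholder):
+ 
+ ```python
+ import torchaudio
+ from cosyvoice.cli.cosyvoice import CosyVoice2
+ from cosyvoice.utils.file_utils import load_wav
+ 
+ cosyvoice = CosyVoice2('/path/to/CosyVoice2-0.5B-RWKV-7-1.5B', load_jit=False, load_trt=False, fp16=False)
+ prompt_speech_16k = load_wav('mine.wav', 16000)  # the prompt audio above
+ prompt_text = '今天天气挺不错的。'
+ tts_text = '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。'
+ 
+ for i, out in enumerate(cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=False)):
+     torchaudio.save(f'zero_shot_{i}.wav', out['tts_speech'], cosyvoice.sample_rate)
+ ```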
159
+
160
+
161
+
162
+ ### TODO:
163
+ 0. Randomly drop prompt audio tokens during training so the model also learns unconditioned (prompt-free) generation.
164
+ 1. Add Cosy 2.0's special control tokens to the RWKV tokenizer and regenerate the audio token corpus with them (a registration sketch follows the token list below):
165
+ ```python
166
+ special_tokens = {
167
+ 'eos_token': '<|endoftext|>',
168
+ 'pad_token': '<|endoftext|>',
169
+ 'additional_special_tokens': [
170
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
171
+ '[breath]', '<strong>', '</strong>', '[noise]',
172
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
173
+ '[quick_breath]',
174
+ "<laughter>", "</laughter>",
175
+ "[hissing]", "[sigh]", "[vocalized-noise]",
176
+ "[lipsmack]", "[mn]"
177
+ ]
178
+ }
179
+ ```
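+ 
+ A minimal registration sketch (assuming the tokenizer is loaded through `transformers.AutoTokenizer` with `trust_remote_code=True`, as in the training scripts; resizing the model's embeddings afterwards may work differently for the RWKV-7 model class):
+ 
+ ```python
+ from transformers import AutoTokenizer
+ 
+ tokenizer = AutoTokenizer.from_pretrained('/external_data/models/rwkv7-1.5B-world/', trust_remote_code=True)
+ num_added = tokenizer.add_special_tokens(special_tokens)  # the dict defined above
+ print(f'Added {num_added} special tokens, new vocab size: {len(tokenizer)}')
+ # If tokens were added, resize the model's input embeddings to len(tokenizer) before training.
+ ```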
180
+ 2. Add special control tokens (e.g., dialect markers) to RWKV7LM and generate audio tokens for training.
181
+ 3. Implement streaming generation for Cosy 2.0 in RWKV7LM.
Trump.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:296432bb06954080b77c04a88841d61928d936077f5162947359520fa17836be
3
+ size 342108
_config.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ markdown: kramdown
2
+ kramdown:
3
+ parse_block_html: true
another.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4d103efaf538db967559861dbcf9995b60eca582360a6add5cf27c3faf3a49e
3
+ size 199724
badXT_71.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c5e28420eb8c4506a1988d484fe9270b8422161d733c567abfccd74c106ceb9
3
+ size 794726
data/cosy/data/data_processor.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
+ import torchaudio
3
+ from hyperpyyaml import load_hyperpyyaml
4
+ import os
5
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
6
+ from cosyvoice.cli.cosyvoice import CosyVoice2
7
+ import json
8
+ import torch
9
+
10
+ def load_from_configuration(model_dir):
11
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
12
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
13
+ return configs
14
+ def init_process(model_dir,device):
15
+ cosyvoice = CosyVoice2(model_dir, load_jit=False, load_trt=False, fp16=True,device=device)
16
+ # configs = load_from_configuration(model_dir)
17
+ # frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
18
+ # configs['feat_extractor'],
19
+ # '{}/campplus.onnx'.format(model_dir),
20
+ # '{}/speech_tokenizer_v2.onnx'.format(model_dir),
21
+ # '{}/spk2info.pt'.format(model_dir),
22
+ # configs['allowed_special'],
23
+ # device)
24
+ frontend = cosyvoice.frontend
25
+ llm = cosyvoice.model.llm
26
+ return frontend,llm,cosyvoice
27
+
28
+
29
+ def preprocess_prompts(frontend,prompts_dir):
30
+ language_results = {}
31
+ final_rate = 24000
32
+ for root, dirs, files in os.walk(prompts_dir):
33
+ for file in files:
34
+ if file.endswith('.json'):
35
+ json_file = os.path.join(root, file)
36
+ print(f"处理文件 {json_file}")
37
+ language = json_file.split('/')[-2]
38
+ if language not in language_results:
39
+ language_results[language] = []
40
+
41
+ # Try reading the file with different encodings
42
+ try:
43
+ with open(json_file, 'r', encoding='utf-8') as f:
44
+ json_data = json.load(f)
45
+ except UnicodeDecodeError:
46
+ try:
47
+ # Try GBK/GB2312 encoding (commonly used for Chinese)
48
+ with open(json_file, 'r', encoding='gbk') as f:
49
+ json_data = json.load(f)
50
+ except UnicodeDecodeError:
51
+ try:
52
+ # Try GB18030 encoding (an extended Chinese encoding)
53
+ with open(json_file, 'r', encoding='gb18030') as f:
54
+ json_data = json.load(f)
55
+ except Exception as e:
56
+ print(f"无法读取文件 {json_file}: {e}")
57
+ continue
58
+
59
+ wav_file = json_file.replace('.json', '.wav')
60
+ prompt_text = json_data['text']
61
+ prompt_speech = torchaudio.load(wav_file, backend='soundfile')[0]
62
+ fake_tts_text = "a"
63
+ with torch.no_grad():
64
+ model_input = frontend.frontend_zero_shot(fake_tts_text, prompt_text, prompt_speech,final_rate)
65
+ language_results[language].append((model_input,prompt_text))
66
+ return language_results
67
+
68
+ def generate_speech_tokens(llm,frontend,tts_text,model_input,device):
69
+ tts_text = frontend.text_normalize(tts_text,split=False, text_frontend=True)
70
+ tts_text_token, tts_text_token_len = frontend._extract_text_token(tts_text)
71
+ tts_text_token_len = torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(device)
72
+ prompt_text = model_input['prompt_text'].to(device)
73
+ prompt_text_len = torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(device)
74
+ llm_prompt_speech_token = model_input['llm_prompt_speech_token'].to(device)
75
+ prompt_speech_token_len = torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(device)
76
+ flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
77
+ prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
78
+ llm_embedding = model_input['llm_embedding'].to(device)
79
+ flow_embedding = model_input['flow_embedding'].to(device)
80
+ speech_tokens = []
81
+ for i in llm.inference(text = tts_text_token,
82
+ text_len = tts_text_token_len,
83
+ prompt_text = prompt_text,
84
+ prompt_text_len = prompt_text_len,
85
+ prompt_speech_token = llm_prompt_speech_token,
86
+ prompt_speech_token_len = prompt_speech_token_len,
87
+ embedding=llm_embedding
88
+ ):
89
+ speech_tokens.append(i)
90
+ tts_speech_tokens = torch.tensor(speech_tokens).unsqueeze(dim=0).to(device)
91
+ return tts_speech_tokens
92
+
93
+ if __name__ == '__main__':
94
+ model_dir = '/data/yueyu/models/CosyVoice2-0.5B'
95
+ prompts_dir = 'extract_data/prompts'
96
+
97
+ device = 'cuda:0'
98
+ frontend,llm,cosyvoice = init_process(model_dir
99
+ ,device)
100
+ prompts = preprocess_prompts(frontend,prompts_dir)
101
+ print(prompts)
102
+ model_input = prompts['zh'][0][0]
103
+ prompt_text = prompts['zh'][0][1]
104
+ tts_text = '扫一扫,立即体验中国银行信用卡好礼、绑卡立减等热门活动,实时掌握更多优惠信息。'
105
+ tts_text = '在中国的一个偏远山区,有一位名叫李远的年轻人,他对集群通信系统有着浓厚的兴趣。每天晚上,他都会在自己的小屋里研究各种关于集群通信系统的资料,试图弄懂其中的原理和运作机制。他对这个领域的研究不仅仅停留在理论层面,还亲手制作了一些模型,试图通过实践来加深理解。'
106
+ tts_text = "歷史(现代汉语词汇,古典文言文称之为史),指人类社会过去的事件和行动,以及对这些事件行为有系统的记录、诠释和研究。歷史可提供今人理解過去,作為未來行事的參考依據,与伦理、哲学和艺术同属人类精神文明的重要成果。历史的第二个含义,即对过去事件的记录和研究,又称历史学”,或简称“史学”。隶属于历史学或与其密切相关的学科有年代学、编纂学、家谱学、古文字学、计量历史学、考古学、社会学和新闻学等,参见历史学。记录和研究历史的人称为历史学家,简称“史学家”,中国古代称为史官。记录历史的书籍称为史书,如《史記》、《汉书》等,粗分為「官修」與「民載」兩類。"
107
+ tts_text = "### 如何提高花样游泳水平"
108
+ tts_speech_tokens = generate_speech_tokens(llm,frontend,tts_text,model_input,device)
109
+ print(tts_speech_tokens)
110
+
111
+
112
+ flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
113
+ prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
114
+ llm_embedding = model_input['llm_embedding'].to(device)
115
+ flow_embedding = model_input['flow_embedding'].to(device)
116
+ cosyvoice.model.hift_cache_dict['xxxx'] = None
117
+ tts_speech = cosyvoice.model.token2wav(token=tts_speech_tokens,
118
+ prompt_token=flow_prompt_speech_token,
119
+ prompt_feat=prompt_speech_feat,
120
+ embedding=flow_embedding,
121
+ uuid='xxxx',
122
+ token_offset=0,
123
+ finalize=True,
124
+ speed=1.0)
125
+ print(f'tts_speech shape:{tts_speech.shape}')
126
+ tts_speech = tts_speech.cpu()
127
+ torchaudio.save('zh_tts_S.wav', tts_speech, 24000)
128
+ print(model_input)
data/cosy/test/test_vq.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
+ import torch
4
+ from cosyvoice.cli.cosyvoice import CosyVoice2
5
+ print(torch.cuda.is_available())
6
+ print(torch.cuda.current_device())
7
+ print(torch.cuda.device(0))
8
+ print(torch.cuda.device_count())
9
+ model_path = '/data/yueyu/models/CosyVoice2-0.5B'
10
+ # cosyvoice = CosyVoice2(model_path, load_jit=False, load_trt=False, fp16=False)
11
+ # print(cosyvoice)
12
+ # from cosyvoice.utils.file_utils import load_wav
13
+ # import torchaudio
14
+ # prompt_speech_16k = load_wav('/home/yueyulin/github/CosyVoice/asset/zero_shot_prompt.wav', 16000)
15
+ # # prompt_speech_16k = torch.rand((1, 16000))
16
+ # for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
17
+ # torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
18
+
19
+ # for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
20
+ # torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
21
+ # # instruct usage
22
+ # for i, j in enumerate(cosyvoice.inference_instruct2('吾今朝早上去外婆家吃饭。', '用上海话说这句话', prompt_speech_16k, stream=False)):
23
+ # torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
24
+
25
+ from hyperpyyaml import load_hyperpyyaml
26
+ import os
27
+ def load_from_configuration(model_dir):
28
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
29
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
30
+ return configs
31
+
32
+ configs = load_from_configuration(model_path)
33
+ print(configs)
34
+
35
+ import torchaudio
36
+ def load_wav(wav, target_sr):
37
+ speech, sample_rate = torchaudio.load(wav, backend='soundfile')
38
+ speech = speech.mean(dim=0, keepdim=True)
39
+ if sample_rate != target_sr:
40
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
41
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
42
+ return speech
43
+
44
+ zh_prompt_tar_file="/data/yueyu/data/Emilia-Dataset/Emilia/ZH/ZH-B000000.tar"
45
+ en_prompt_tar_file="/data/yueyu/data/Emilia-Dataset/Emilia/EN/EN-B000000.tar"
46
+
47
+
48
+ def load_file_list(tar_file):
49
+ #the files are FILE_NAME.mp3/FILE_NAME.json
50
+ #return all FILE_NAME as a list which has a mp3 and json
51
+ import tarfile
52
+ with tarfile.open(tar_file, 'r') as f:
53
+ file_names = f.getnames()
54
+ mp3_files = [i for i in file_names if i.endswith('.mp3')]
55
+ json_files = [i for i in file_names if i.endswith('.json')]
56
+
57
+ #filter mp3_files without corresponded json
58
+ mp3_files = [i for i in mp3_files if i.replace('.mp3', '.json') in json_files]
59
+ return mp3_files
60
+
61
+ zh_files = load_file_list(zh_prompt_tar_file)
62
+ print(zh_files[:10])
63
+ en_files = load_file_list(en_prompt_tar_file)
64
+ print(en_files[:10])
65
+ import io
66
+
67
+ def load_random_samples_from_tar(tar_file, files, num_samples,target_sr,max_duration=10):
68
+ import random
69
+ import tarfile
70
+ import json
71
+ samples = []
72
+ with tarfile.open(tar_file, 'r') as f:
73
+ for i in random.sample(files, len(files)):
74
+ mp3 = f.extractfile(i)
75
+ mp3_bytes = io.BytesIO(mp3.read())
76
+ speech, sample_rate = torchaudio.load(mp3_bytes,backend='soundfile')
77
+ json_file = f.extractfile(i.replace('.mp3', '.json'))
78
+ json_data = json.load(json_file)
79
+ duration = json_data['duration']
80
+ if duration > max_duration:
81
+ continue
82
+ speech = speech.mean(dim=0, keepdim=True)
83
+ if sample_rate != target_sr:
84
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
85
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
86
+ samples.append((speech, json_data,sample_rate))
87
+ if len(samples) == num_samples:
88
+ break
89
+ return samples
90
+ target_sr = 16000
91
+ zh_samples = load_random_samples_from_tar(zh_prompt_tar_file, zh_files, 10, target_sr)
92
+
93
+ one_sample,one_json,sample_rate = zh_samples[0]
94
+ print(one_json)
95
+ print(sample_rate)
96
+ torchaudio.save('zh_sample.wav', one_sample, target_sr)
97
+ print(len(zh_samples))
98
+
99
+ en_samples = load_random_samples_from_tar(en_prompt_tar_file, en_files, 10, target_sr)
100
+ one_sample,one_json,sample_rate = en_samples[0]
101
+ print(one_json)
102
+ print(sample_rate)
103
+ torchaudio.save('en_sample.wav', one_sample, target_sr)
104
+ print(len(en_samples))
105
+
106
+ def resample_audio(samples, target_sr):
107
+ resampled_samples = []
108
+ for i in samples:
109
+ speech, sample_rate = i
110
+ if sample_rate != target_sr:
111
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
112
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
113
+ resampled_samples.append((speech, sample_rate))
114
+ return resampled_samples
115
+
116
+ prompt_text = zh_samples[0][1]['text']
117
+ prompt_speech = zh_samples[0][0]
118
+ print(prompt_text)
119
+ print(prompt_speech)
120
+ from cosyvoice.cli.cosyvoice import CosyVoice2
121
+ cosyvoice = CosyVoice2(model_path, load_jit=False, load_trt=False, fp16=True)
122
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
123
+ frontend = cosyvoice.frontend
124
+ prompt_text = frontend.text_normalize(prompt_text,split=False, text_frontend=True)
125
+ print(f'normalized prompt_text:{prompt_text}')
126
+ tts_text = '扫一扫,立即体验中国银行信用卡好礼、绑卡立减等热门活动,实时掌握更多优惠信息。'
127
+ tts_text = "在中国的一个偏远山区,有一位名叫李远的年轻人,他对集群通信系统有着浓厚的兴趣。每天晚上,他都会在自己的小屋里研究各种关于集群通信系统的资料,试图弄懂其中的原理和运作机制。他对这个领域的研究不仅仅停留在理论层面,还亲手制作了一些模型,试图通过实践来加深理解。"
128
+ tts_text = "歷史(现代汉语词汇,古典文言文称之为史),指人类社会过去的事件和行动,以及对这些事件行为有系统的记录、诠释和研究。歷史可提供今人理解過去,作為未來行事的參考依據,与伦理、哲学和艺术同属人类精神文明的重要成果。历史的第二个含义,即对过去事件的记录和研究,又称历史学”,或简称“史学”。隶属于历史学或与其密切相关的学科有年代学、编纂学、家谱学、古文字学、计量历史学、考古学、社会学和新闻学等,参见历史学。记录和研究历史的人称为历史学家,简称“史学家”,中国古代称为史官。记录历史的书籍称为史书,如《史記》、《汉书》等,粗分為「官修」與「民載」兩類。"
129
+ tts_text = frontend.text_normalize(tts_text,split=False, text_frontend=True)
130
+ print(f'normalized tts_text:{tts_text}')
131
+ final_rate = 24000
132
+ model_input = frontend.frontend_zero_shot(tts_text, prompt_text, prompt_speech,final_rate)
133
+ print(model_input)
134
+ llm = cosyvoice.model.llm
135
+ device = cosyvoice.model.device
136
+ text = model_input['text'].to(device)
137
+ text_len = torch.tensor([text.shape[1]], dtype=torch.int32).to(device)
138
+ prompt_text = model_input['prompt_text'].to(device)
139
+ prompt_text_len = torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(device)
140
+ llm_prompt_speech_token = model_input['llm_prompt_speech_token'].to(device)
141
+ prompt_speech_token_len = torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(device)
142
+ flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
143
+ prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
144
+ llm_embedding = model_input['llm_embedding'].to(device)
145
+ flow_embedding = model_input['flow_embedding'].to(device)
146
+ speech_tokens = []
147
+ for i in llm.inference(text = text,
148
+ text_len = text_len,
149
+ prompt_text = prompt_text,
150
+ prompt_text_len = prompt_text_len,
151
+ prompt_speech_token = llm_prompt_speech_token,
152
+ prompt_speech_token_len = prompt_speech_token_len,
153
+ embedding=llm_embedding
154
+ ):
155
+ speech_tokens.append(i)
156
+ print(speech_tokens)
157
+
158
+ tts_speech_tokens = torch.tensor(speech_tokens).unsqueeze(dim=0).to(device)
159
+ print(f'tts_speech_tokens shape:{tts_speech_tokens.shape}')
160
+ cosyvoice.model.hift_cache_dict['xxxx'] = None
161
+ tts_speech = cosyvoice.model.token2wav(token=tts_speech_tokens,
162
+ prompt_token=flow_prompt_speech_token,
163
+ prompt_feat=prompt_speech_feat,
164
+ embedding=flow_embedding,
165
+ uuid='xxxx',
166
+ token_offset=0,
167
+ finalize=True,
168
+ speed=1.0)
169
+ print(f'tts_speech shape:{tts_speech.shape}')
170
+ tts_speech = tts_speech.cpu()
171
+ torchaudio.save('zh_tts.wav', tts_speech, final_rate)
data/utils/convert_embeddings_2_pt.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import sys
4
+ import os
5
+ import json
6
+ from sklearn.cluster import KMeans
7
+ jsonl_dir = sys.argv[1]
8
+ output_file_name = sys.argv[2]
9
+
10
+ # Load the embeddings from jsonl files the key is the name of the file
11
+ embeddings = {}
12
+ for file in os.listdir(jsonl_dir):
13
+ print("Processing", file)
14
+ if file.endswith("_embeddings.json"):
15
+ with open(os.path.join(jsonl_dir, file), "r") as f:
16
+ print("Loading", file)
17
+ data = json.load(f)
18
+ key_name = os.path.basename(file).replace("_embeddings.json", "")
19
+ np_array = np.array(data)
20
+ if np_array.shape[0] == 1:
21
+ np_array = np_array[0]
22
+ else:
23
+ #find the cluster center of the embeddings using kmeans
24
+ kmeans = KMeans(n_clusters=1, random_state=0, n_init = 'auto').fit(np_array)
25
+ np_array = kmeans.cluster_centers_[0]
26
+
27
+ embeddings[key_name]= {'embedding' : torch.tensor(np_array, dtype=torch.float32).unsqueeze(0)}
28
+ torch.save(embeddings, output_file_name)
29
+ print("Embeddings saved to", output_file_name)
30
+
31
+ state_dict = torch.load(output_file_name)
32
+ print("Loaded embeddings from", output_file_name)
33
+ for key in state_dict:
34
+ print(key, state_dict[key]['embedding'].shape)
data/utils/create_embeddings_from_raw.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
3
+ import whisper
4
+ from librosa import resample
5
+ import multiprocessing
6
+ from tqdm import tqdm
7
+ import onnxruntime
8
+ from onnxruntime import InferenceSession
9
+ import torch
10
+ import pyarrow.parquet as pq
11
+ import numpy as np
12
+ import json
13
+ import io
14
+ import soundfile as sf
15
+ import torchaudio
16
+ import torchaudio.compliance.kaldi as kaldi
17
+ import mmap
26
+
27
+ def process_file(file_info):
28
+ """Process a single parquet file; called once per worker process."""
29
+ parquet_file, output_path, speaker_extractor, device = file_info
30
+
31
+ # Create an independent ONNX session for each worker process
32
+ option = onnxruntime.SessionOptions()
33
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
34
+ option.intra_op_num_threads = 1
35
+ ort_session = onnxruntime.InferenceSession(speaker_extractor, sess_options=option,
36
+ providers=["CPUExecutionProvider"])
37
+ results = {}
38
+ try:
39
+ # 创建目标文件名
40
+ base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
41
+ output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")
42
+
43
+ # 使用PyArrow读取parquet文件的元数据,获取总行数
44
+ parquet_metadata = pq.read_metadata(parquet_file)
45
+ total_rows = parquet_metadata.num_rows
46
+ batch_size = 100
47
+
48
+ # 使用 mmap 读取 parquet 文件
49
+ with open(parquet_file, 'rb') as f:
50
+ mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
51
+
52
+ # 使用 io.BytesIO 将 mmap 对象包装成文件对象
53
+ buffer = io.BytesIO(mm)
54
+
55
+ pf = pq.ParquetFile(buffer) # 使用 mmap 包装的 buffer
56
+
57
+ progress = tqdm(total=total_rows,
58
+ desc=f"Processing {os.path.basename(parquet_file)}",
59
+ position=multiprocessing.current_process()._identity[0] % 10)
60
+
61
+ current_row = 0
62
+ idx = 0
63
+ for batch in pf.iter_batches(batch_size=batch_size):
64
+ df_batch = batch.to_pandas()
65
+
66
+ # 处理当前批次中的每一行
67
+ for _, row in df_batch.iterrows():
68
+ current_row += 1
69
+ audio_obj = row['audio']
70
+ audio_data = audio_obj['bytes']
71
+ transcription = row['transcription']
72
+ language = row['language']
73
+ speaker = row['speaker']
74
+ if speaker not in results:
75
+ results[speaker] = {}
76
+ if language not in results[speaker]:
77
+ results[speaker][language] = []
78
+ if len(results[speaker][language]) >= 10:
79
+ progress.update(1)
80
+ continue
81
+
82
+ with io.BytesIO(audio_data) as audio_buffer:
83
+ prompt_data, sample_rate = sf.read(audio_buffer)
84
+ # 确保是单声道,并转换为float32
85
+ if len(prompt_data.shape) > 1:
86
+ prompt_data = prompt_data[:, 0]
87
+ prompt_data = prompt_data.astype(np.float32)
88
+
89
+ # 重采样到16kHz (如果需要)
90
+ if sample_rate != 16000:
91
+ prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)
92
+
93
+ prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
94
+
95
+ feat = kaldi.fbank(prompt_speech_16k,
96
+ num_mel_bins=80,
97
+ dither=0,
98
+ sample_frequency=16000)
99
+ feat = feat - feat.mean(dim=0,keepdim=True)
100
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
101
+
102
+ results[speaker][language].append(embedding)
103
+
104
+ progress.update(1)
105
+
106
+ # 关闭 mmap 对象
107
+ mm.close()
108
+
109
+
110
+
111
+ print(f'All speakers {results.keys()}')
112
+ for speaker in results:
113
+ print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
114
+ return results
115
+ except Exception as e:
116
+ import traceback
117
+ traceback.print_exc()
118
+ return f"Error processing {parquet_file}: {str(e)}"
119
+ def process_file_x(file_info):
120
+ """Process a single parquet file; called once per worker process."""
121
+ parquet_file, output_path, speaker_extractor, device = file_info
122
+
123
+ # Create an independent ONNX session for each worker process
124
+ option = onnxruntime.SessionOptions()
125
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
126
+ option.intra_op_num_threads = 1
127
+ ort_session = InferenceSession(speaker_extractor, sess_options=option,
128
+ providers=["CPUExecutionProvider"])
129
+ results = {}
130
+ try:
131
+ # 创建目标文件名
132
+ base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
133
+ output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")
134
+
135
+ # 使用PyArrow读取parquet文件的元数据,获取总行数
136
+ parquet_metadata = pq.read_metadata(parquet_file)
137
+ total_rows = parquet_metadata.num_rows
138
+ batch_size = 100
139
+
140
+ pf = pq.ParquetFile(parquet_file)
141
+
142
+ progress = tqdm(total=total_rows,
143
+ desc=f"Processing {os.path.basename(parquet_file)}",
144
+ position=multiprocessing.current_process()._identity[0] % 10)
145
+
146
+ current_row = 0
147
+ idx = 0
148
+ for batch in pf.iter_batches(batch_size=batch_size):
149
+ df_batch = batch.to_pandas()
150
+
151
+ # 处理当前批次中的每一行
152
+ for _, row in df_batch.iterrows():
153
+ current_row += 1
154
+ audio_obj = row['audio']
155
+ audio_data = audio_obj['bytes']
156
+ transcription = row['transcription']
157
+ language = row['language']
158
+ speaker = row['speaker']
159
+ if speaker not in results:
160
+ results[speaker] = {}
161
+ if language not in results[speaker]:
162
+ results[speaker][language] = []
163
+ if len(results[speaker][language]) >= 10:
164
+ progress.update(1)
165
+ continue
166
+
167
+ with io.BytesIO(audio_data) as buffer:
168
+ prompt_data, sample_rate = sf.read(buffer)
169
+ # 确保是单声道,并转换为float32
170
+ if len(prompt_data.shape) > 1:
171
+ prompt_data = prompt_data[:, 0]
172
+ prompt_data = prompt_data.astype(np.float32)
173
+
174
+ # 重采样到16kHz (如果需要)
175
+ if sample_rate != 16000:
176
+ prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)
177
+
178
+ prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
179
+
180
+ feat = kaldi.fbank(prompt_speech_16k,
181
+ num_mel_bins=80,
182
+ dither=0,
183
+ sample_frequency=16000)
184
+ feat = feat - feat.mean(dim=0,keepdim=True)
185
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
186
+
187
+ results[speaker][language].append(embedding)
188
+
189
+ progress.update(1)
190
+
191
+
192
+
193
+
194
+
195
+ print(f'All speakers {results.keys()}')
196
+ for speaker in results:
197
+ print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
198
+ return results
199
+ except Exception as e:
200
+ import traceback
201
+ traceback.print_exc()
202
+ return f"Error processing {parquet_file}: {str(e)}"
203
+ if __name__ == '__main__':
204
+ import argparse
205
+ parser = argparse.ArgumentParser()
206
+ parser.add_argument('--data_path', type=str, default='/external_data/yueyudata/starrail-voice')
207
+ parser.add_argument('--output_path',type=str,default='/external_data/yueyudata/starrail-voice-speaker-embeddings')
208
+ parser.add_argument('--speaker_extractor',type=str,default='/external_data/models/CosyVoice2-0.5B_RWKV_1.5B/campplus.onnx')
209
+ parser.add_argument('--device',type=str,default='cuda:0')
210
+ parser.add_argument('--num_processes',type=int,default=4)
211
+ args = parser.parse_args()
212
+
213
+ print(args)
214
+ data_path = args.data_path
215
+ output_path = args.output_path
216
+ device = args.device
217
+ speaker_extractor = args.speaker_extractor
218
+ num_processes = args.num_processes
219
+
220
+ # 确保输出目录存在
221
+ os.makedirs(output_path, exist_ok=True)
222
+
223
+ # 找到所有parquet文件
224
+ parquet_files = []
225
+ for root, dirs, files in os.walk(data_path):
226
+ for file in files:
227
+ if file.endswith('.parquet'):
228
+ parquet_files.append(os.path.join(root, file))
229
+ print(f'Found {len(parquet_files)} parquet files in {data_path}')
230
+
231
+ # 准备多进程参数
232
+ file_info_list = [(file, output_path, speaker_extractor, device) for file in parquet_files]
233
+
234
+ # 使用进程池处理文件
235
+ print(f"Starting processing with {num_processes} processes")
236
+
239
+ with multiprocessing.Pool(processes=num_processes) as pool:
240
+ results = pool.map(process_file, file_info_list)
241
+
242
+ # 输出处理结果
243
+ print('Processing complete,merge results')
244
+ final_results = {}
245
+ for result in results:
246
+ if isinstance(result, dict):
247
+ for speaker in result:
248
+ if speaker not in final_results:
249
+ final_results[speaker] = {}
250
+ for language in result[speaker]:
251
+ if language not in final_results[speaker]:
252
+ final_results[speaker][language] = []
253
+ final_results[speaker][language].extend(result[speaker][language])
254
+ else:
255
+ print(result)
256
+
257
+ # 输出结果
258
+ for speaker in final_results:
259
+ for language in final_results[speaker]:
260
+ output_file = os.path.join(output_path, f"{speaker}_{language}_embeddings.json")
261
+ print(f"Writing embeddings for {speaker} ({language}) to {output_file}")
262
+ with open(output_file, 'w', encoding='utf-8') as f_out:
263
+ json.dump(final_results[speaker][language], f_out)
data/utils/create_lm_corpus_from_raw.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import json
5
+ import io
6
+ import torch
7
+ import soundfile as sf
8
+ import pyarrow.parquet as pq
9
+ import whisper
10
+ from librosa import resample
11
+ import multiprocessing
12
+ from tqdm import tqdm
13
+ import onnxruntime
14
+ from onnxruntime import InferenceSession
15
+
16
+ def process_file(file_info):
17
+ """Process a single parquet file; called once per worker process."""
18
+ parquet_file, output_path, speech_tokenizer_model, device = file_info
19
+
20
+ # 为每个进程创建独立的speech_tokenizer_session
21
+ option = onnxruntime.SessionOptions()
22
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
23
+ option.intra_op_num_threads = 1
24
+ cuda_idx = int(device.split(':')[-1] if device is not None and 'cuda' in device else '0')
25
+ speech_tokenizer_session = InferenceSession(speech_tokenizer_model, sess_options=option,
26
+ providers=[("CUDAExecutionProvider", {"device_id": cuda_idx})
27
+ if torch.cuda.is_available() else "CPUExecutionProvider"])
28
+
29
+ try:
30
+ # 创建目标文件名
31
+ base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
32
+ output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")
33
+
34
+ # 使用PyArrow读取parquet文件的元数据,获取总行数
35
+ parquet_metadata = pq.read_metadata(parquet_file)
36
+ total_rows = parquet_metadata.num_rows
37
+ batch_size = 1000
38
+
39
+ # 检查是否有已经处理过的文件,计算已处理的行数
40
+ processed_rows = 0
41
+ if os.path.exists(output_file):
42
+ with open(output_file, 'r', encoding='utf-8') as f_check:
43
+ for _ in f_check:
44
+ processed_rows += 1
45
+ print(f"Found existing file {output_file} with {processed_rows} processed rows")
46
+
47
+ # 如果已经处理完所有行,跳过此文件
48
+ if processed_rows >= total_rows:
49
+ return f"Skipped {parquet_file}: all {total_rows} rows already processed"
50
+
51
+ # 逐批处理数据,以追加方式打开输出文件
52
+ with open(output_file, 'a' if processed_rows > 0 else 'w', encoding='utf-8') as f_out:
53
+ pf = pq.ParquetFile(parquet_file)
54
+ progress = tqdm(total=total_rows, initial=processed_rows,
55
+ desc=f"Processing {os.path.basename(parquet_file)}",
56
+ position=multiprocessing.current_process()._identity[0] % 10)
57
+
58
+ skip_rows = processed_rows
59
+ current_row = 0
60
+
61
+ for batch in pf.iter_batches(batch_size=batch_size):
62
+ df_batch = batch.to_pandas()
63
+
64
+ # 处理当前批次中的每一行
65
+ for _, row in df_batch.iterrows():
66
+ current_row += 1
67
+
68
+ # 跳过已处理的行
69
+ if current_row <= skip_rows:
70
+ continue
71
+
72
+ audio_obj = row['audio']
73
+ audio_data = audio_obj['bytes']
74
+ transcription = row['transcription']
75
+ language = row['language']
76
+ speaker = row['speaker']
77
+
78
+ with io.BytesIO(audio_data) as buffer:
79
+ prompt_data, sample_rate = sf.read(buffer)
80
+ # 确保是单声道,并转换为float32
81
+ if len(prompt_data.shape) > 1:
82
+ prompt_data = prompt_data[:, 0]
83
+ prompt_data = prompt_data.astype(np.float32)
84
+
85
+ # 重采样到16kHz (如果需要)
86
+ if sample_rate != 16000:
87
+ prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)
88
+
89
+ prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
90
+
91
+ feat = whisper.log_mel_spectrogram(prompt_speech_16k, n_mels=128)
92
+ speech_token = speech_tokenizer_session.run(None,
93
+ {speech_tokenizer_session.get_inputs()[0].name:
94
+ feat.detach().cpu().numpy(),
95
+ speech_tokenizer_session.get_inputs()[1].name:
96
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
97
+
98
+ # 写入结果
99
+ f_out.write(json.dumps({'tts_speech_tokens':speech_token,
100
+ 'text':transcription,
101
+ 'language':language,
102
+ 'speaker':speaker,
103
+ "prompt_text":"",
104
+ "llm_prompt_speech_token":[]},
105
+ ensure_ascii=False)+'\n')
106
+ progress.update(1)
107
+
108
+ # 释放内存
109
+ del df_batch
110
+ import gc
111
+ gc.collect()
112
+
113
+ return f"Successfully processed {parquet_file}: {total_rows-processed_rows} new rows processed"
114
+ except Exception as e:
115
+ return f"Error processing {parquet_file}: {str(e)}"
116
+
117
+ if __name__ == '__main__':
118
+ import argparse
119
+ parser = argparse.ArgumentParser()
120
+ parser.add_argument('--data_path', type=str, default='/external_data/yueyudata/starrail-voice')
121
+ parser.add_argument('--output_path',type=str,default='/external_data/yueyudata/starrail-voice-voice_tokens')
122
+ parser.add_argument('--speech_tokenizer_model',type=str,default='/external_data/models/CosyVoice2-0.5B_RWKV_1.5B/speech_tokenizer_v2.onnx')
123
+ parser.add_argument('--device',type=str,default='cuda:0')
124
+ parser.add_argument('--num_processes',type=int,default=4)
125
+ args = parser.parse_args()
126
+
127
+ data_path = args.data_path
128
+ output_path = args.output_path
129
+ device = args.device
130
+ speech_tokenizer_model = args.speech_tokenizer_model
131
+ num_processes = args.num_processes
132
+
133
+ # 确保输出目录存在
134
+ os.makedirs(output_path, exist_ok=True)
135
+
136
+ # 找到所有parquet文件
137
+ parquet_files = []
138
+ for root, dirs, files in os.walk(data_path):
139
+ for file in files:
140
+ if file.endswith('.parquet'):
141
+ parquet_files.append(os.path.join(root, file))
142
+ print(f'Found {len(parquet_files)} parquet files in {data_path}')
143
+
144
+ # 准备多进程参数
145
+ file_info_list = [(file, output_path, speech_tokenizer_model, device) for file in parquet_files]
146
+
147
+ # 使用进程池处理文件
148
+ print(f"Starting processing with {num_processes} processes")
149
+ with multiprocessing.Pool(processes=num_processes) as pool:
150
+ results = pool.map(process_file, file_info_list)
151
+
152
+ # 输出处理结果
153
+ for result in results:
154
+ print(result)
155
+
156
+ print("All files processed successfully!")
data/utils/llm_dataset.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+ import time
7
+ random.seed(time.time())
8
+ import logging
9
+ from tqdm import tqdm
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
+
12
+ def verify_jsonl_files(data_files):
13
+ """Check that each JSONL file contains valid JSON."""
14
+ invalid_files = []
15
+
16
+ for file_path in tqdm(data_files, desc="验证文件"):
17
+ try:
18
+ with open(file_path, 'r', encoding='utf-8') as f:
19
+ for i, line in enumerate(f):
20
+ try:
21
+ json.loads(line)
22
+ except json.JSONDecodeError:
23
+ invalid_files.append((file_path, i+1))
24
+ logging.error(f"文件 {file_path} 在第 {i+1} 行有无效的 JSON")
25
+ break
26
+ except Exception as e:
27
+ invalid_files.append((file_path, f"读取错误: {str(e)}"))
28
+ logging.error(f"无法读取文件 {file_path}: {str(e)}")
29
+
30
+ return invalid_files
31
+ def load_jsonl_dataset(directory,tokenizer):
32
+ '''
33
+ load jsonl files in a directory recursively
34
+ '''
35
+ data_files = []
36
+ for root, dirs, files in os.walk(directory):
37
+ for file in files:
38
+ if file.endswith('.jsonl'):
39
+ data_files.append(os.path.join(root, file))
40
+
41
+ logging.info(f"找到 {len(data_files)} 个 JSONL 文件")
42
+ # 验证文件
43
+ invalid_files = verify_jsonl_files(data_files)
44
+ if invalid_files:
45
+ logging.error(f"发现 {len(invalid_files)} 个无效文件:")
46
+ for file_info in invalid_files:
47
+ if isinstance(file_info[1], int):
48
+ logging.error(f" - {file_info[0]} (错误在第 {file_info[1]} 行)")
49
+ else:
50
+ logging.error(f" - {file_info[0]} ({file_info[1]})")
51
+
52
+ # 移除无效文件
53
+ valid_files = [f for f in data_files if f not in [info[0] for info in invalid_files]]
54
+ logging.info(f"继续处理剩余的 {len(valid_files)} 个有效文件")
55
+ data_files = valid_files
56
+ # 手动收集所有样本,确保特征一致性
57
+ all_samples = []
58
+
59
+ for file_path in tqdm(data_files, desc="加载数据集"):
60
+ try:
61
+ # 手动解析JSONL文件,避免datasets加载时的类型推断问题
62
+ with open(file_path, 'r', encoding='utf-8') as f:
63
+ for line in f:
64
+ try:
65
+ data = json.loads(line)
66
+ # 确保所有字段存在且类型一致
67
+ llm_prompt_speech_token = data.get('llm_prompt_speech_token', [])
68
+ tts_speech_tokens = data.get('tts_speech_tokens', [])
69
+ text = str(data.get('text', ""))
70
+ prompt_text = str(data.get('prompt_text', ""))
71
+
72
+ # 确保列表类型
73
+ if not isinstance(llm_prompt_speech_token, list):
74
+ llm_prompt_speech_token = []
75
+ if not isinstance(tts_speech_tokens, list):
76
+ tts_speech_tokens = []
77
+
78
+ # 添加处理后的样本
79
+ all_samples.append({
80
+ 'llm_prompt_speech_token': llm_prompt_speech_token,
81
+ 'tts_speech_tokens': tts_speech_tokens,
82
+ 'text': text,
83
+ 'prompt_text': prompt_text
84
+ })
85
+ except json.JSONDecodeError:
86
+ continue # 跳过无效的JSON行
87
+ except Exception as e:
88
+ logging.error(f"处理样本时出错: {str(e)}")
89
+ except Exception as e:
90
+ logging.error(f"打开文件 {file_path} 时出错: {str(e)}")
91
+
92
+ if not all_samples:
93
+ raise ValueError("没有成功加载任何样本")
94
+
95
+ # 创建数据集
96
+ logging.info(f"手动创建数据集,包含 {len(all_samples)} 个样本")
97
+ dataset = datasets.Dataset.from_list(all_samples)
98
+
99
+ logging.info(f"成功加载 {len(dataset)} 个样本")
100
+
101
+ #1. concatenate llm_prompt_speech_token and tts_speech_tokens (list of int)
102
+ #delay the concatenation to collate_fn since sometimes we want to drop the prompt
103
+ # dataset = dataset.map(lambda x: {'speech_token': x['llm_prompt_speech_token'] + x['tts_speech_tokens']},remove_columns=['tts_speech_tokens','llm_prompt_speech_token'])
104
+ #2. Filter the data either :
105
+ # 1. the length of the speech_token is less than 1
106
+ # 2. the length of the speech_token is greater than 1000
107
+ # 3. the length of the text is greater than 500
108
+ # 4. the length of the prompt_text is greater than 500
109
+ # 5. the length of the text_token is less than 1
110
+ # 6. the length of the prompt_text_token is less than 1
111
+ dataset = dataset.filter(lambda x:len(x['llm_prompt_speech_token']) < 2048 and len(x['tts_speech_tokens']) < 2048
112
+ and len(tokenizer.encode(x['text'])) < 2048 and len(tokenizer.encode(x['prompt_text'])) < 2048 )
113
+ logging.info(f"过滤后剩余 {len(dataset)} 个样本")
114
+ #2. tokenize the text to text_tokens and prompt_text to prompt_text_tokens
115
+ # dataset = dataset.map(lambda x: {'text_tokens': tokenizer.encode(x['text']), 'prompt_text_tokens': tokenizer.encode(x['prompt_text'])},remove_columns=['text','prompt_text'])
116
+ return dataset
117
+
118
+ def collate_fn(batch, tokenizer, pad_to_max_length=True, max_length=2048, drop_prompt_audio_rate=-0.1):
119
+ '''
120
+ convert the data to torch tensors
121
+ 1. call tokenizer.encode('text') and tokenizer.encode('prompt_text'), concatenate them to get the text_token, record each sample's length to text_token_len
122
+ 2. convert the text_tokens and text_token_len to torch tensor
123
+ 3. record each sample's speech_token length to speech_token_len
124
+ 4. convert the speech_token and speech_token_len to torch tensor
125
+ 5. We drop the prompt with probability drop_prompt_audio_rate so the model learns to generate audio without prompt guidance
126
+ By default we won't drop anything
127
+ '''
128
+ all_text_tokens = []
129
+ all_speech_tokens = []
130
+ speech_token_len = []
131
+ text_token_len = []
132
+ my_max_length = 0
133
+ is_drop_prompt = random.random() < drop_prompt_audio_rate
134
+
135
+ for sample in batch:
136
+ tts_speech_tokens = sample['tts_speech_tokens']
137
+ llm_prompt_speech_token = sample['llm_prompt_speech_token']
138
+
139
+ if is_drop_prompt:
140
+ # 只使用文本部分,不使用提示
141
+ text_tokens = tokenizer.encode(sample['text'])
142
+ all_text_tokens.append(torch.tensor(text_tokens, dtype=torch.int32))
143
+ text_token_len.append(len(text_tokens))
144
+
145
+ # 只使用语音部分,不使用提示语音
146
+ current_speech_tokens = tts_speech_tokens
147
+ all_speech_tokens.append(torch.tensor(current_speech_tokens, dtype=torch.int32))
148
+ speech_token_len.append(len(current_speech_tokens))
149
+
150
+ total_length = len(text_tokens) + len(current_speech_tokens)
151
+ else:
152
+ # 使用提示+文本
153
+ text_tokens = tokenizer.encode(sample['text'])
154
+ prompt_tokens = tokenizer.encode(sample['prompt_text'])
155
+ combined_text_tokens = prompt_tokens + text_tokens
156
+ all_text_tokens.append(torch.tensor(combined_text_tokens, dtype=torch.int32))
157
+ text_token_len.append(len(combined_text_tokens))
158
+
159
+ # 使用提示语音+语音
160
+ current_speech_tokens = llm_prompt_speech_token + tts_speech_tokens
161
+ all_speech_tokens.append(torch.tensor(current_speech_tokens, dtype=torch.int32))
162
+ speech_token_len.append(len(current_speech_tokens))
163
+
164
+ total_length = len(combined_text_tokens) + len(current_speech_tokens)
165
+
166
+ if total_length > my_max_length:
167
+ my_max_length = total_length
168
+
169
+ # 检查长度是否超出最大长度
170
+ skip = my_max_length > max_length
171
+
172
+ # 将列表转换为填充后的张量
173
+ all_text_tokens = torch.nn.utils.rnn.pad_sequence(all_text_tokens, batch_first=True, padding_value=0)
174
+ all_speech_tokens = torch.nn.utils.rnn.pad_sequence(all_speech_tokens, batch_first=True, padding_value=0)
175
+
176
+ # 如果需要填充到最大长度
177
+ if pad_to_max_length and not skip:
178
+ pad_length = max_length - my_max_length
179
+ if pad_length > 0:
180
+ all_speech_tokens = torch.nn.functional.pad(all_speech_tokens, (0, pad_length), value=0)
181
+
182
+ return {
183
+ 'text_token': all_text_tokens,
184
+ 'text_token_len': torch.tensor(text_token_len, dtype=torch.int32),
185
+ 'speech_token': all_speech_tokens, # 确保命名一致
186
+ 'speech_token_len': torch.tensor(speech_token_len, dtype=torch.int32),
187
+ 'skip': skip
188
+ }
189
+
190
+
191
+ if __name__ == '__main__':
192
+ from transformers import AutoTokenizer
193
+ model_path = "/external_data/models/rwkv7-2.9B-world"
194
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
195
+ directory = '/external_data/yueyudata/speech_corpus'
196
+ dataset = load_jsonl_dataset(directory,tokenizer)
197
+ print(dataset)
198
+ print(dataset[0])
199
+ from functools import partial
200
+ collate_fn = partial(collate_fn,tokenizer=tokenizer,pad_to_max_length=False)
201
+ dataloader = torch.utils.data.DataLoader(dataset,batch_size=1,collate_fn=collate_fn)
202
+ for data in dataloader:
203
+ print(data)
204
+ print(data['speech_token'].shape)
205
+ print(data['text_token'].shape)
206
+ break
data/utils/test_utilities.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.utils.utilitie import generate_mixed_instructions
2
+ if __name__ == '__main__':
3
+ print(generate_mixed_instructions('我来自中国。'))
4
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
5
+ print(generate_mixed_instructions('I am from China.',language='en'))
6
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
7
+ print(generate_mixed_instructions('我来自中国。'))
8
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
9
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
10
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
11
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
12
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
13
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
14
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
15
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
16
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
17
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
18
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
19
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
20
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
21
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
22
+ print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
23
+ print(generate_mixed_instructions('I am from China.',language='en'))
24
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
25
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
26
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
27
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
28
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
29
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
30
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
31
+ print(generate_mixed_instructions('This is a city with a long history.',language='en'))
data/utils/utilitie.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
+ from torch import device
6
+ from tqdm import tqdm
7
+ import tarfile
8
+ import random
9
+ import time
10
+ import io
11
+ import torchaudio
12
+ import json
13
+ import os
14
+ import multiprocessing
15
+ import torch
16
+ from data.cosy.data.data_processor import init_process, preprocess_prompts
18
+ from typing import List
22
+
23
+ '''
24
+ Natural Language Instruction
25
+ Emotion: 高兴(Happy), 悲伤(Sad), 惊讶(Surprised), 愤怒(Angry), 恐惧(Fearful), 厌恶(Disgusted), 冷
26
+ 静(Calm), 严肃(Serious)
27
+ Speaking Rate: 快速(Fast), 非常快速(Very Fast), 慢速(Slow), 非常慢速(Very Slow)
28
+ Dialect: 粤语, 四川话, 上海话, 郑州话, 长沙话, 天津话
29
+ Role-playing: 神秘(Mysterious), 凶猛(Fierce), 好奇(Curious), 优雅(Elegant), 孤独(Lonely), 机器
30
+ 人(Robot), 小猪佩奇(Peppa), etc.
31
+ Fine-grained Instruction
32
+ Vocal Bursts: [laughter], [breath], etc.
33
+ Vocal Features: <laughter></laughter>, <strong></strong>
34
+ Examples
35
+ - 你能用高兴的情感说吗?< |endofprompt| >今天真是太开心了,马上要放假了!I’m so happy,
36
+ Spring Festival is coming!
37
+ - Please speaking very fast.< |endofprompt| >Today is a happy day, full of laughter and joy.
38
+ - 请问你能模仿粤语的口音吗?< |endofprompt| >多保重,早休息。
39
+ - 尝试一下以机器人的角色和我交流。< |endofprompt| >接收知识光波!
40
+ - [laughter]有时候,看着小孩子们的天真行为[laughter],我们总会会心一笑。
41
+ - She pursued her dreams with <strong>enthusiasm</strong> and <strong>grit</strong>.
42
+ '''
43
+
44
+ emotions = ['高兴', '悲伤', '惊讶', '愤怒', '恐惧', '厌恶', '冷静', '严肃']
45
+ emotions_in_english = ['Happy', 'Sad', 'Surprised', 'Angry', 'Fearful', 'Disgusted', 'Calm', 'Serious']
46
+ speaking_rates = ['快速', '非常快速', '慢速', '非常慢速']
47
+ speaking_rates_in_english = ['Fast', 'Very Fast', 'Slow', 'Very Slow']
48
+ dialects = ['普通话','粤语', '四川话', '上海话', '郑州话', '长沙话', '天津话']
49
+ dialects_in_english = ['Mandarin','Cantonese', 'Sichuanese', 'Shanghainese', 'Zhengzhou Dialect', 'Changsha Dialect', 'Tianjin Dialect']
50
+ role_playings = ['神秘', '凶猛', '好奇', '优雅', '孤独', '机器人', '小猪佩奇']
51
+ role_playings_in_english = ['Mysterious', 'Fierce', 'Curious', 'Elegant', 'Lonely', 'Robot', 'Peppa']
52
+ vocal_bursts = ['[laughter]', '[breath]']
53
+ vocal_features = ['<laughter></laughter>', '<strong></strong>']
54
+ end_of_prompt = '<|endofprompt|>'
55
+
56
+ def generate_in_emotion_in_chinese(text :str):
57
+ templates = [
58
+ '你能用{}的情感说吗?{}{}',
59
+ '请用{}的情感说。{}{}',
60
+ '请用{}的情感表达。{}{}',
61
+ '请用{}的情感说一下。{}{}',
62
+ '请用{}的情感说一句。{}{}'
63
+ ]
64
+ select_emotion = random.choice(emotions)
65
+ return random.choice(templates).format(select_emotion,end_of_prompt,text)
66
+
67
+ def generate_in_emotion_in_english(text :str):
68
+ templates = [
69
+ 'Can you say it with {} emotion?{}{}',
70
+ 'Please say it with {} emotion.{}{}',
71
+ 'Please express it with {} emotion.{}{}',
72
+ 'Please say it with {} emotion.{}{}',
73
+ 'Please say a sentence with {} emotion.{}{}'
74
+ ]
75
+ select_emotion = random.choice(emotions_in_english)
76
+ return random.choice(templates).format(select_emotion,end_of_prompt,text)
77
+
78
+ def generate_speaking_rate_in_chinese(text :str):
79
+ templates = [
80
+ '请用{}的语速说。{}{}',
81
+ '请用{}的语速说一下。{}{}',
82
+ '请用{}的语速说一句。{}{}',
83
+ '请用{}的语速表达。{}{}',
84
+ '请用{}的语速说。{}{}',
85
+ '请{}地说。{}{}',
86
+ '请{}地说一下。{}{}',
87
+ '请{}地说一句。{}{}',
88
+ '{}的说。{}{}',
89
+ '{}的说一下。{}{}',
90
+ '{}的说一句。{}{}',
91
+ '{}的表达。{}{}'
92
+
93
+ ]
94
+ select_rate = random.choice(speaking_rates)
95
+ template = random.choice(templates)
96
+ return template.format(select_rate,end_of_prompt,text)
97
+
98
+ def generate_speaking_rate_in_english(text :str):
99
+ templates = [
100
+ 'Please say it with {} speaking rate.{}{}',
101
+ 'Say it with {} speaking rate.{}{}',
102
+ 'Please say a sentence with {} speaking rate.{}{}',
103
+ 'Please express it with {} speaking rate.{}{}',
104
+ 'Please speak {}ly.{}{}',
105
+ 'Speak {}ly.{}{}',
106
+ 'Please say it {}ly.{}{}',
107
+ 'Say it {}ly.{}{}'
108
+ ]
109
+ select_rate = random.choice(speaking_rates_in_english)
110
+ template = random.choice(templates)
111
+ return template.format(select_rate,end_of_prompt,text)
112
+
113
+
114
+ def load_file_list(tar_file):
115
+ #the files are FILE_NAME.mp3/FILE_NAME.json
116
+ #return all FILE_NAME as a list which has a mp3 and json
117
+ import tarfile
118
+ with tarfile.open(tar_file, 'r') as f:
119
+ file_names = f.getnames()
120
+ mp3_files = [i for i in file_names if i.endswith('.mp3')]
121
+ json_files = [i for i in file_names if i.endswith('.json')]
122
+
123
+ #filter mp3_files without corresponded json
124
+ mp3_files = [i for i in mp3_files if i.replace('.mp3', '.json') in json_files]
125
+ return mp3_files
126
+
127
+ def extract_prompt(input_tar_files, input_tar_languages, max_duration=5, num_samples=10, target_sr=16000, output_dir=None):
128
+ """
129
+ Extract prompt from tar files
130
+ Args:
131
+ input_tar_files: list of str, input tar files
132
+ input_tar_languages: list of str, input tar languages for each tar file, must be the same length as input_tar_files
133
+ max_duration: float, max duration of audio
134
+ num_samples: int, number of samples to extract
135
+ target_sr: int, target sample rate
136
+ output_dir: str, output directory
137
+ """
138
+ for tar_file, language in zip(input_tar_files, input_tar_languages):
139
+ print(f'Extracting prompt from {tar_file}...with language {language}')
140
+ random.seed(time.time())
141
+ samples = []
142
+ mp3_files = load_file_list(tar_file)
143
+ with tarfile.open(tar_file, 'r') as f:
144
+ progress_bar = tqdm(total=num_samples,desc=f'Extracting prompt from {tar_file}')
145
+ for i in random.sample(mp3_files, len(mp3_files)):
146
+ mp3 = f.extractfile(i)
147
+ mp3_bytes = io.BytesIO(mp3.read())
148
+ speech, sample_rate = torchaudio.load(mp3_bytes,backend='soundfile')
149
+ json_file = f.extractfile(i.replace('.mp3', '.json'))
150
+ json_data = json.load(json_file)
151
+ duration = json_data['duration']
152
+ if duration > max_duration:
153
+ continue
154
+ speech = speech.mean(dim=0, keepdim=True)
155
+ if sample_rate != target_sr:
156
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
157
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
158
+ samples.append((speech, json_data,sample_rate))
159
+ progress_bar.update(1)
160
+ if len(samples) == num_samples:
161
+ break
162
+ if output_dir is not None:
163
+ """
164
+ json looks like:
165
+ {'id': 'ZH_B00000_S01450_W000017', 'wav': 'ZH_B00000/ZH_B00000_S01450/mp3/ZH_B00000_S01450_W000017.mp3', 'text': '因此,我们认为流通性具有更广泛的含义。', 'duration': 4.193, 'speaker': 'ZH_B00000_S01450', 'language': 'zh', 'dnsmos': 3.3709}
166
+ """
167
+ output_dir_lang = os.path.join(output_dir, language)
168
+ os.makedirs(output_dir_lang, exist_ok=True)
169
+ progress_bar = tqdm(total=len(samples), desc=f'Saving samples to {output_dir_lang}')
170
+ for i, (speech, json_data, sample_rate) in enumerate(samples):
171
+ id = json_data['id']
172
+ wave_file = os.path.join(output_dir_lang, f'{id}.wav')
173
+ json_file = os.path.join(output_dir_lang, f'{id}.json')
174
+ torchaudio.save(wave_file, speech, target_sr)
175
+ with open(json_file, 'w') as f:
176
+ json.dump(json_data, f,ensure_ascii=False)
177
+ progress_bar.update(1)
178
+ print(f'Extracted {len(samples)} samples from {tar_file} with language {language}')
179
+
180
+ def generate_dialect_in_chinese(text: str):
181
+ templates = [
182
+ '请问你能模仿{}的口音吗?{}{}',
183
+ '请用{}的口音说一下。{}{}',
184
+ '用{}的口音说一句。{}{}',
185
+ '能用{}的口音读一下吗?{}{}',
186
+ '请尝试用{}的口音说这段话。{}{}',
187
+ '请以{}的口音表达。{}{}',
188
+ '请用{}的语调说。{}{}',
189
+ '试试用{}的方言说。{}{}',
190
+ '能否用{}的语调读出来?{}{}',
191
+ '请说一段{}。{}{}'
192
+ ]
193
+ select_dialect = random.choice(dialects)
194
+ return random.choice(templates).format(select_dialect, end_of_prompt, text)
195
+
196
+ def generate_dialect_in_english(text: str):
197
+ templates = [
198
+ 'Can you mimic the {} accent?{}{}',
199
+ 'Please speak with a {} accent.{}{}',
200
+ 'Say it with a {} accent.{}{}',
201
+ 'Could you read this with a {} accent?{}{}',
202
+ 'Please try to speak this with a {} accent.{}{}',
203
+ 'Please express it with a {} accent.{}{}',
204
+ 'Please use {} intonation.{}{}',
205
+ 'Try speaking in {}.{}{}',
206
+ 'Could you read this in {}?{}{}',
207
+ 'Please say a passage in {}.{}{}'
208
+ ]
209
+ select_dialect = random.choice(dialects_in_english)
210
+ return random.choice(templates).format(select_dialect, end_of_prompt, text)
211
+
212
+ def generate_role_playing_in_chinese(text: str):
213
+ templates = [
214
+ '尝试一下以{}的角色和我交流。{}{}',
215
+ '请以{}的角色说这句话。{}{}',
216
+ '假装你是{},说一下这句话。{}{}',
217
+ '扮演{}来说这段话。{}{}',
218
+ '请用{}的语气说。{}{}',
219
+ '以{}的形象来表达。{}{}',
220
+ '你能用{}的方式说吗?{}{}',
221
+ '模仿{}说话。{}{}',
222
+ '请用{}的口吻说一下。{}{}',
223
+ '像{}一样说这句话。{}{}'
224
+ ]
225
+ select_role = random.choice(role_playings)
226
+ return random.choice(templates).format(select_role, end_of_prompt, text)
227
+
228
+ def generate_role_playing_in_english(text: str):
229
+ templates = [
230
+ 'Try to communicate with me as a {} character.{}{}',
231
+ 'Please say this as a {} character.{}{}',
232
+ 'Pretend you are {}, say this sentence.{}{}',
233
+ 'Act as {} to say this passage.{}{}',
234
+ 'Please speak with a {} tone.{}{}',
235
+ 'Express this with a {} image.{}{}',
236
+ 'Can you say this in a {} way?{}{}',
237
+ 'Mimic {} speaking.{}{}',
238
+ 'Please say this in the manner of {}.{}{}',
239
+ 'Say this like {}.{}{}'
240
+ ]
241
+ select_role = random.choice(role_playings_in_english)
242
+ return random.choice(templates).format(select_role, end_of_prompt, text)
243
+
244
+ def generate_vocal_bursts(text: str):
245
+ """
246
+ 在文本中随机添加声音爆发标记,如[laughter]、[breath]等
247
+ """
248
+ positions = [
249
+ 'start', # 在句首添加
250
+ 'middle', # 在句中添加
251
+ 'end' # 在句末添加
252
+ ]
253
+
254
+ burst = random.choice(vocal_bursts)
255
+ position = random.choice(positions)
256
+
257
+ if position == 'start': # 句首
258
+ return burst + text
259
+ elif position == 'middle': # 句中
260
+ words = text.split()
261
+ if len(words) <= 3: # 文本太短不分割
262
+ return burst + text
263
+ split_point = random.randint(1, len(words) - 1)
264
+ return ' '.join(words[:split_point]) + ' ' + burst + ' ' + ' '.join(words[split_point:])
265
+ else: # 句末
266
+ return text + ' ' + burst
267
+
268
+ def generate_vocal_features(text: str):
269
+ """
270
+ 在文本中随机添加声音特征标记,如<laughter></laughter>、<strong></strong>等
271
+ 支持中文和英文文本
272
+ """
273
+ feature = random.choice(vocal_features)
274
+ feature_start, feature_end = feature.split('><')
275
+ feature_start += '>'
276
+ feature_end = '<' + feature_end
277
+
278
+ # 检查是否为中文文本
279
+ has_chinese = any('\u4e00' <= char <= '\u9fff' for char in text)
280
+
281
+ if has_chinese:
282
+ # 处理中文文本
283
+ if len(text) <= 10: # 文本太短,整个加强
284
+ return feature_start + text + feature_end
285
+
286
+ # 对中文处理,随机选择一个字符范围
287
+ text_len = len(text)
288
+ # 随机选择一个起始位置和一个范围长度
289
+ start_pos = random.randint(1, max(1, text_len // 2)) # 避免总是从句首开始
290
+ span_length = random.randint(1, min(5, text_len - start_pos))
291
+ end_pos = start_pos + span_length - 1
292
+
293
+ # 在选定位置插入标记
294
+ result = text[:start_pos] + feature_start + text[start_pos:end_pos+1] + feature_end + text[end_pos+1:]
295
+ return result
296
+ else:
297
+ # 处理英文文本
298
+ words = text.split()
299
+ if len(words) <= 3: # 文本太短,整个加强
300
+ return feature_start + text + feature_end
301
+
302
+ # 随机选择一个词或短语来添加特征
303
+ start_idx = random.randint(0, len(words) - 1)
304
+ span_length = random.randint(1, min(3, len(words) - start_idx)) # 最多3个词
305
+
306
+ result = []
307
+ for i, word in enumerate(words):
308
+ if i == start_idx:
309
+ word = feature_start + word
310
+ if i == start_idx + span_length - 1: # not elif: a single-word span needs the closing tag too
311
+ word = word + feature_end
312
+ result.append(word)
314
+
315
+ return ' '.join(result)
316
+
317
+ def generate_mixed_instructions(text: str, language="zh"):
318
+ """
319
+ 混合多种指令类型,可以同时包含情感、语速、方言、角色扮演等
320
+ """
321
+ instruction_generators = []
322
+
323
+ if language == "zh":
324
+ instruction_generators = [
325
+ generate_in_emotion_in_chinese,
326
+ generate_speaking_rate_in_chinese,
327
+ generate_dialect_in_chinese,
328
+ generate_role_playing_in_chinese
329
+ ]
330
+ else: # 英文
331
+ instruction_generators = [
332
+ generate_in_emotion_in_english,
333
+ generate_speaking_rate_in_english,
334
+ generate_dialect_in_english,
335
+ generate_role_playing_in_english
336
+ ]
337
+
338
+ # 随机选择1个generator
339
+ selected_generator = random.choice(instruction_generators)
340
+
341
+ # 可能会添加声音特征
342
+ text_with_features = text
343
+ if random.random() < 0.3: # 30%的概率添加声音特征
344
+ text_with_features = generate_vocal_features(text)
345
+
346
+ # 可能会添加声音爆发
347
+ if random.random() < 0.2: # 20%的概率添加声音爆发
348
+ text_with_features = generate_vocal_bursts(text_with_features)
349
+
350
+ # 应用选择的指令生成器
351
+ result = text_with_features
352
+ result = selected_generator(result)
353
+
354
+ return result
355
+
356
+ frontend = None
357
+ llm = None
358
+ cosyvoice = None
359
+ output_fp = None
360
+ prompts = None
361
+ global_device = None
362
+ processed_count = 0
363
+ def initialize_process(model_dir,prompts_dir,output_dir,device):
364
+ current_process = multiprocessing.current_process()
365
+ file_name = f'{output_dir}/{current_process.pid}.jsonl'
366
+ global frontend,llm,cosyvoice,output_fp,prompts,global_device
367
+ global_device = device
368
+ output_fp = open(file_name, 'w')
369
+ print(f'Initializing process with device {device} and output file {file_name}')
370
+ frontend,llm,cosyvoice = init_process(model_dir,device)
371
+ prompts = preprocess_prompts(frontend,prompts_dir)
372
+ print(f'load prompts {prompts.keys()}')
373
+ return frontend,llm,cosyvoice
374
+
375
+ def generate_speech_tokens(llm,frontend,tts_text,model_input,device):
376
+ tts_text = frontend.text_normalize(tts_text,split=False, text_frontend=True)
377
+ tts_text_token, tts_text_token_len = frontend._extract_text_token(tts_text)
378
+ tts_text_token_len = torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(device)
379
+ prompt_text = model_input['prompt_text'].to(device) if 'prompt_text' in model_input else torch.zeros(1, 0, dtype=torch.int32).to(device)
380
+ prompt_text_len = torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(device) if prompt_text is not None else torch.zeros(1, 0, dtype=torch.int32).to(device)
381
+ llm_prompt_speech_token = model_input['llm_prompt_speech_token'].to(device) if 'llm_prompt_speech_token' in model_input else torch.zeros(1, 0, dtype=torch.int32).to(device)
382
+ prompt_speech_token_len = torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(device) if llm_prompt_speech_token is not None else None
383
+ flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
384
+ prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
385
+ llm_embedding = model_input['llm_embedding'].to(device)
386
+ flow_embedding = model_input['flow_embedding'].to(device)
387
+ speech_tokens = []
388
+ with torch.no_grad():
389
+ for i in llm.inference(text = tts_text_token,
390
+ text_len = tts_text_token_len,
391
+ prompt_text = prompt_text,
392
+ prompt_text_len = prompt_text_len,
393
+ prompt_speech_token = llm_prompt_speech_token,
394
+ prompt_speech_token_len = prompt_speech_token_len,
395
+ embedding=llm_embedding
396
+ ):
397
+ speech_tokens.append(i)
398
+ return speech_tokens
399
+
400
+ def process_text(text,language):
401
+ global frontend,llm,cosyvoice,output_fp,prompts,processed_count,global_device
402
+ processed_count += 1
403
+ if processed_count % 100 == 0:
404
+ print(f'Processed {processed_count} samples')
405
+ tts_text = text
406
+ splits_txt_by_lines = tts_text.split('\n')
407
+ # keep only sentences longer than 10 characters
408
+ splits_txt_by_lines = [i.strip() for i in splits_txt_by_lines if len(i.strip()) > 10]
409
+ random.seed(time.time())
410
+ model_input,prompt_text = random.choice(prompts[language])
411
+ llm_prompt_speech_token = model_input['llm_prompt_speech_token'].cpu().tolist()
412
+ for tts_text in splits_txt_by_lines:
413
+ tts_speech_tokens = generate_speech_tokens(llm,frontend,tts_text,model_input,cosyvoice.device)
414
+ output_data = {
415
+ 'text': tts_text,
416
+ 'tts_speech_tokens': tts_speech_tokens,
417
+ 'prompt_text': prompt_text,
418
+ 'llm_prompt_speech_token': llm_prompt_speech_token[0]
419
+ }
420
+ output_fp.write(json.dumps(output_data,ensure_ascii=False)+'\n')
421
+ output_fp.flush()
422
+ return processed_count
423
+ def process_jsonl_file(jsonl_file,language,process_pool):
424
+ print(f'Processing {jsonl_file}...')
425
+ count = 0
426
+ import json
427
+ with open(jsonl_file, 'r') as f:
428
+ for line in f:
429
+ line = line.strip()
430
+ if len(line) == 0:
431
+ continue
432
+ data = json.loads(line)
433
+ text = data['text']
434
+ count += 1
435
+ future = process_pool.submit(process_text,text,language)
436
+ print(f'processed {future.result()} requests')
437
+ print(f'Processed {count} samples from {jsonl_file}')
438
+ return count
439
+
440
+ def process_parquet_file(parquet_file,language,process_pool):
441
+ print(f'Processing {parquet_file}...')
442
+ import pandas as pd
443
+ df = pd.read_parquet(parquet_file)
444
+ count = 0
445
+ for i in range(len(df)):
446
+ text = df.iloc[i]['text']
447
+ count += 1
448
+ future = process_pool.submit(process_text,text,language)
449
+ print(f'processed {future.result()} requests')
450
+ print(f'Processed {count} samples from {parquet_file}')
451
+ return count
452
+
453
+ def generate_speech_tokens_single_process(cosy_model_dir, prompts_dir, output_dir, language, jsonl_files=None, parquet_files=None, device="cuda:0",is_cross_lingual=False,is_instructed=False):
454
+ """
455
+ 单进程单线程版本的语音标记生成函数
456
+ """
457
+ import torch
458
+ import json
459
+ import os
460
+ import random
461
+ import time
462
+ import traceback
463
+ import logging
464
+ import sys
465
+ from datetime import datetime
466
+ from data.cosy.data.data_processor import init_process, preprocess_prompts
467
+
468
+ # 设置日志
469
+ output_dir_lang = os.path.join(output_dir, language)
470
+ os.makedirs(output_dir_lang, exist_ok=True)
471
+ process_id = os.getpid()
472
+ log_file = os.path.join(output_dir_lang, f'process_{process_id}_log.txt')
473
+
474
+ # 配置日志输出到文件和控制台
475
+ logging.basicConfig(
476
+ level=logging.INFO,
477
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
478
+ handlers=[
479
+ logging.FileHandler(log_file),
480
+ logging.StreamHandler(sys.stdout)
481
+ ]
482
+ )
483
+ logger = logging.getLogger(f'process_{process_id}')
484
+
485
+ # 记录启动信息
486
+ logger.info(f"='='='='='='='='='='='Instructed={is_instructed}'='='='='='='='='='='='='='='='='='")
487
+ logger.info(f"启动时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
488
+ logger.info(f"进程ID: {process_id}")
489
+ logger.info(f"设备: {device}")
490
+ logger.info(f"模型目录: {cosy_model_dir}")
491
+ logger.info(f"提示词目录: {prompts_dir}")
492
+ logger.info(f"输出目录: {output_dir_lang}")
493
+ if jsonl_files:
494
+ logger.info(f"JSONL文件: {jsonl_files}")
495
+ if parquet_files:
496
+ logger.info(f"Parquet文件: {parquet_files}")
497
+ logger.info(f"='='='='='='='='='='='='='='='='='='='='='='='='='='='='='")
498
+
499
+ output_fp = None
500
+ frontend = None
501
+ llm = None
502
+ cosyvoice = None
503
+ total_processed = 0
504
+
505
+ try:
506
+ # 初始化模型
507
+ logger.info(f'初始化模型,使用设备: {device}')
508
+ frontend, llm, cosyvoice = init_process(cosy_model_dir, device)
509
+
510
+ # 预处理提示
511
+ logger.info(f'开始预处理提示词')
512
+ prompts = preprocess_prompts(frontend, prompts_dir)
513
+ logger.info(f'加载提示完成: {prompts.keys()}')
514
+
515
+ output_file = os.path.join(output_dir_lang, f'{process_id}.jsonl')
516
+ output_fp = open(output_file, 'w')
517
+
518
+ # 处理函数
519
+ def process_single_text(text):
520
+ try:
521
+ tts_text = text
522
+ splits_txt_by_lines = tts_text.split('\n')
523
+ # 删除长度小于10的句子
524
+ splits_txt_by_lines = [i.strip() for i in splits_txt_by_lines if len(i.strip()) > 10]
525
+
526
+ if not splits_txt_by_lines:
527
+ logger.warning(f"文本没有有效句子: '{text[:100]}...'")
528
+ return 0
529
+
530
+ random.seed(time.time())
531
+ cross_linguals_map = {
532
+ 'zh': 'en',
533
+ 'en': 'zh'
534
+ }
535
+ try:
536
+ model_input, prompt_text = random.choice(prompts[language if not is_cross_lingual else cross_linguals_map[language]])
537
+ except KeyError:
538
+ logger.error(f"语言 '{language}' 在提示词中不存在! 可用语言: {list(prompts.keys())}")
539
+ return 0
540
+
541
+ llm_prompt_speech_token = model_input['llm_prompt_speech_token'].cpu().tolist() if 'llm_prompt_speech_token' in model_input else []
542
+
543
+ processed_count = 0
544
+ for tts_text in splits_txt_by_lines:
545
+ try:
546
+ if is_instructed:
547
+ tts_text = generate_mixed_instructions(tts_text, language)
548
+ prompt_text = ""
549
+ llm_prompt_speech_token[0]=[]
550
+ if 'prompt_text' in model_input:
551
+ del model_input['prompt_text']
552
+ if 'prompt_text_len' in model_input:
553
+ del model_input['prompt_text_len']
554
+ if 'llm_prompt_speech_token' in model_input:
555
+ del model_input['llm_prompt_speech_token']
556
+ if 'llm_prompt_speech_token_len' in model_input:
557
+ del model_input['llm_prompt_speech_token_len']
558
+ # 生成语音标记
559
+ tts_speech_tokens = generate_speech_tokens(llm, frontend, tts_text, model_input, device)
560
+ output_data = {
561
+ 'text': tts_text,
562
+ 'tts_speech_tokens': tts_speech_tokens,
563
+ 'prompt_text': prompt_text,
564
+ 'llm_prompt_speech_token': llm_prompt_speech_token[0]
565
+ }
566
+ output_fp.write(json.dumps(output_data, ensure_ascii=False) + '\n')
567
+ output_fp.flush()
568
+ processed_count += 1
569
+ except Exception as e:
570
+ logger.error(f"处理单个句子时出错: '{tts_text[:100]}...'")
571
+ logger.error(f"错误信息: {str(e)}")
572
+ logger.error(traceback.format_exc())
573
+
574
+ return processed_count
575
+ except Exception as e:
576
+ logger.error(f"处理文本块时出错")
577
+ logger.error(f"错误信息: {str(e)}")
578
+ logger.error(traceback.format_exc())
579
+ return 0
580
+
581
+ # 收集要处理的文件
582
+ files_to_process = []
583
+
584
+ # 处理JSONL文件
585
+ if jsonl_files is not None:
586
+ logger.info(f"处理指定的JSONL文件")
587
+ for file in jsonl_files:
588
+ if file.endswith('.jsonl'):
589
+ files_to_process.append(('jsonl', file))
590
+ logger.info(f"共有 {len([f for t, f in files_to_process if t == 'jsonl'])} 个JSONL文件需要处理")
591
+
592
+ # 处理Parquet文件
593
+ if parquet_files is not None:
594
+ logger.info(f"处理指定的Parquet文件")
595
+ for file in parquet_files:
596
+ if file.endswith('.parquet'):
597
+ files_to_process.append(('parquet', file))
598
+ logger.info(f"共有 {len([f for t, f in files_to_process if t == 'parquet'])} 个Parquet文件需要处理")
599
+
600
+ # 顺序处理所有文件
601
+ for file_type, file_path in files_to_process:
602
+ logger.info(f'开始处理文件: {file_path}')
603
+ try:
604
+ if file_type == 'jsonl':
605
+ # 处理JSONL文件
606
+ # 首先计算文件总行数,用于进度条
607
+ total_lines = 0
608
+ with open(file_path, 'r') as f:
609
+ for line in f:
610
+ if line.strip(): # 只计算非空行
611
+ total_lines += 1
612
+
613
+ logger.info(f"JSONL文件 {file_path} 共有 {total_lines} 行")
614
+ # 使用进度条处理文件
615
+ with open(file_path, 'r') as f:
616
+ from tqdm import tqdm
617
+ progress_bar = tqdm(total=total_lines, desc=f'处理JSONL文件: {os.path.basename(file_path)}')
618
+ file_processed = 0
619
+ for line in f:
620
+ line = line.strip()
621
+ if len(line) == 0:
622
+ continue
623
+ try:
624
+ data = json.loads(line)
625
+ text = data['text']
626
+ processed = process_single_text(text)
627
+ total_processed += processed
628
+ file_processed += processed
629
+ progress_bar.update(1)
630
+ progress_bar.set_postfix(total=total_processed)
631
+ except Exception as e:
632
+ logger.error(f"处理JSONL行时出错: {line[:100]}...")
633
+ logger.error(f"错误信息: {str(e)}")
634
+ logger.error(traceback.format_exc())
635
+ progress_bar.close()
636
+ logger.info(f"JSONL文件 {file_path} 完成处理,成功处理 {file_processed} 条记录")
637
+
638
+ elif file_type == 'parquet':
639
+ # 处理Parquet文件
640
+ try:
641
+ import pandas as pd
642
+ logger.info(f"加载Parquet文件: {file_path}")
643
+ df = pd.read_parquet(file_path)
644
+ logger.info(f"Parquet文件 {file_path} 共有 {len(df)} 行")
645
+
646
+ from tqdm import tqdm
647
+ progress_bar = tqdm(total=len(df), desc=f'处理Parquet文件: {os.path.basename(file_path)}')
648
+ file_processed = 0
649
+ for i in range(len(df)):
650
+ try:
651
+ text = df.iloc[i]['text']
652
+ processed = process_single_text(text)
653
+ total_processed += processed
654
+ file_processed += processed
655
+ progress_bar.update(1)
656
+ progress_bar.set_postfix(total=total_processed)
657
+ except Exception as e:
658
+ logger.error(f"处理Parquet行 {i} 时出错")
659
+ logger.error(f"错误信息: {str(e)}")
660
+ logger.error(traceback.format_exc())
661
+ progress_bar.close()
662
+ logger.info(f"Parquet文件 {file_path} 完成处理,成功处理 {file_processed} 条记录")
663
+ except ImportError:
664
+ logger.error("处理Parquet文件需要pandas库,请安装: pip install pandas")
665
+ except Exception as e:
666
+ logger.error(f"处理Parquet文件 {file_path} 时出现错误")
667
+ logger.error(f"错误信息: {str(e)}")
668
+ logger.error(traceback.format_exc())
669
+ except Exception as e:
670
+ logger.error(f"处理文件 {file_path} 时出现错误")
671
+ logger.error(f"错误信息: {str(e)}")
672
+ logger.error(traceback.format_exc())
673
+
674
+ logger.info(f'总共成功处理 {total_processed} 个样本,结果保存到 {output_file}')
675
+
676
+ except Exception as e:
677
+ logger.error("处理过程中出现全局错误")
678
+ logger.error(f"错误信息: {str(e)}")
679
+ logger.error(traceback.format_exc())
680
+
681
+ finally:
682
+ # 确保资源正确关闭
683
+ logger.info("清理资源...")
684
+ if output_fp is not None:
685
+ try:
686
+ output_fp.close()
687
+ logger.info(f"关闭输出文件")
688
+ except Exception as e:
689
+ logger.error(f"关闭输出文件时出错: {str(e)}")
690
+
691
+ # 释放GPU资源
692
+ if torch.cuda.is_available():
693
+ try:
694
+ torch.cuda.empty_cache()
695
+ logger.info("已清理GPU缓存")
696
+ except Exception as e:
697
+ logger.error(f"清理GPU缓存时出错: {str(e)}")
698
+
699
+ logger.info(f"处理结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
700
+ logger.info(f"='='='='='='='='='='='='='='='='='='='='='='='='='='='='='")
701
+
702
+ if __name__ == '__main__':
703
+ import argparse
704
+ """
705
+ Parse arguments
706
+ task: str, including 'extract_prompt'
707
+ input_tar_files: list of str, input tar files
708
+ input_tar_languages: list of str, input tar languages for each tar file, must be the same length as input_tar_files
709
+ max_duration: float, max duration of audio
710
+ num_samples: int, number of samples to extract
711
+ target_sr: int, target sample rate
712
+ output_dir: str, output directory
713
+ num_processes: int, number of processes to use
714
+ prompt_dir: str, prompt directory which contains prompt jsonl files and audio files
715
+ language: str, language, zh or en
716
+ cosy_model_dir: str, cosy model directory
717
+ device: str, cuda device used to extract speech tokens
718
+ jsonl_files: list of str, jsonl files
719
+ parquet_files: list of str, parquet files
720
+ """
721
+ parser = argparse.ArgumentParser()
722
+ parser.add_argument('--task', type=str, help='task')
723
+ parser.add_argument('--input_tar_files', nargs='+', type=str, help='input tar files')
724
+ parser.add_argument('--input_tar_languages', nargs='+', type=str, help='input tar languages for each tar file')
725
+ parser.add_argument('--output_dir', type=str, help='output directory',required=True)
726
+ parser.add_argument('--max_duration', type=float, default=5, help='max duration of audio')
727
+ parser.add_argument('--num_samples', type=int, default=10, help='number of samples to extract')
728
+ parser.add_argument('--target_sr', type=int, default=16000, help='target sample rate')
729
+ parser.add_argument('--num_processes', type=int, default=1, help='number of processes to use')
730
+ parser.add_argument('--prompts_dir', type=str, help='prompt directory which contains prompt jsonl files and audio files')
731
+ parser.add_argument('--language', type=str, help='language')
732
+ parser.add_argument('--cosy_model_dir', type=str, help='cosy model directory')
733
+ parser.add_argument('--device', type=str, help='cuda device used to extract speech tokens')
734
+ parser.add_argument('--jsonl_files', nargs='+', type=str, help='jsonl files')
735
+ parser.add_argument('--parquet_files', nargs='+', type=str, help='parquet files')
736
+ parser.add_argument('--is_cross_lingual', action='store_true', help='is cross lingual')
737
+ parser.add_argument('--is_instructed', action='store_true', help='is instructed')
738
+ args = parser.parse_args()
739
+ task = args.task
740
+ if task == 'extract_prompt':
741
+ input_tar_files = args.input_tar_files
742
+ input_tar_languages = args.input_tar_languages
743
+ output_dir = args.output_dir
744
+ assert len(input_tar_files) == len(input_tar_languages), 'input_tar_files and input_tar_languages must have the same length'
745
+ extract_prompt(input_tar_files, input_tar_languages, args.max_duration, args.num_samples, args.target_sr, output_dir)
746
+ elif task == 'generate_speech_tokens':
747
+ prompts_dir = args.prompts_dir
748
+ language = args.language
749
+ cosy_model_dir = args.cosy_model_dir
750
+ jsonl_files = args.jsonl_files
751
+ parquet_files = args.parquet_files
752
+ device = args.device
753
+ is_cross_lingual = args.is_cross_lingual
754
+ is_instructed = args.is_instructed
755
+ # 使用单进程单线程版本替代多进程版本
756
+ generate_speech_tokens_single_process(
757
+ cosy_model_dir=cosy_model_dir,
758
+ prompts_dir=prompts_dir,
759
+ output_dir=args.output_dir,
760
+ language=language,
761
+ jsonl_files=jsonl_files,
762
+ parquet_files=parquet_files,
763
+ device=device,
764
+ is_cross_lingual=is_cross_lingual,
765
+ is_instructed=is_instructed,
766
+ )
767
+
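For quick reference, here is a minimal sketch of what the instruction prompts built by the helpers above look like. It assumes data/utils/utilitie.py is importable (e.g. as data.utils.utilitie with the repository root on PYTHONPATH, as run_multiple_process.sh sets it) and that its heavy dependencies (torch, torchaudio) are installed; the printed strings are only one possible outcome, since templates are sampled at random.

import random
from data.utils.utilitie import generate_in_emotion_in_chinese, generate_mixed_instructions

random.seed(0)  # only to make the sampled templates repeatable in this demo

# Emotion-only instruction, format: "<instruction><|endofprompt|><text>"
print(generate_in_emotion_in_chinese('今天是个开心的日子。'))
# one possible output: 请用高兴的情感说。<|endofprompt|>今天是个开心的日子。

# Mixed instruction for English text; may also inject [laughter]/<strong> marks
print(generate_mixed_instructions('Today is a happy day, full of laughter and joy.', language='en'))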
eval/eval_seed_generate.py ADDED
@@ -0,0 +1,66 @@
1
+ # Download the evaluation file from: https://drive.google.com/file/d/1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP/edit
2
+ import os
3
+ voice_engine = None
4
+ def init_process_func(model_path,device):
5
+ global voice_engine
6
+ from cosyvoice.cli.cosyvoice import CosyVoice2
7
+ voice_engine = CosyVoice2(model_path,device=device,fp16=False,load_jit=False)
8
+ print(f'Finish loading cosyvoice model from {model_path} in process {os.getpid()}')
9
+ def do_tts(ID,tts_text,prompt_text,prompt_audio_file,output_dir):
10
+ from cosyvoice.utils.file_utils import load_wav
11
+ import torchaudio
12
+ global voice_engine
13
+ try:
14
+ final_output_file = os.path.join(output_dir,f'{ID}.wav')
15
+ prompt_speech_16k = load_wav(prompt_audio_file, 16000)
16
+ for output in voice_engine.inference_zero_shot(tts_text,prompt_text, prompt_speech_16k, stream=False,speed=1):
17
+ torchaudio.save(final_output_file, output['tts_speech'], voice_engine.sample_rate)
18
+ break # only save the first output
19
+ print(f'TTS {tts_text} and Save to {final_output_file} at process {os.getpid()}')
20
+ except Exception as e:
21
+ print(f'Error: {e}')
22
+ print(f'Error processing {ID} at process {os.getpid()}')
23
+ import traceback
24
+ traceback.print_exc()
25
+ return
26
+ if __name__ == '__main__':
27
+ import argparse
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--eval_dir", type=str, default='eval_data/seedtts_testset')
30
+ parser.add_argument("--language", type=str, default='zh',choices=['zh','en'])
31
+ parser.add_argument("--model_path", type=str, default='/home/yueyulin/models/CosyVoice2-0.5B_RWKV_1.5B/')
32
+ parser.add_argument("--device", type=str, default='cuda:0')
33
+ parser.add_argument("--num_processes", type=int, default=2)
34
+ parser.add_argument("--output_dir", type=str, default='generated')
35
+ parser.add_argument("--list_file", type=str, default='meta.lst')
36
+
37
+
38
+ args = parser.parse_args()
39
+ print(args)
40
+ output_dir = os.path.join(args.eval_dir,args.language,args.output_dir)
41
+ #first delete the output_dir
42
+ if os.path.exists(output_dir):
43
+ import shutil
44
+ shutil.rmtree(output_dir)
45
+ os.makedirs(output_dir)
46
+ list_file = os.path.join(args.eval_dir,args.language,args.list_file)
47
+ with open(list_file) as f:
48
+ lines = f.readlines()
49
+ lines = [line.strip() for line in lines]
50
+ print(f'Processing {len(lines)} lines')
51
+
52
+ from multiprocessing import Pool
53
+ from functools import partial
54
+ import time
55
+ with Pool(args.num_processes,init_process_func,(args.model_path,args.device)) as p:
56
+ for line in lines:
57
+ # 10002287-00000095|在此奉劝大家别乱打美白针。|prompt-wavs/10002287-00000094.wav|简单地说,这相当于惠普把消费领域市场拱手相让了。
58
+ parts = line.split('|')
59
+ ID = parts[0]
60
+ tts_text = parts[3]
61
+ prompt_text = parts[1]
62
+ prompt_audio_file = os.path.join(args.eval_dir,args.language,parts[2])
63
+ p.apply_async(do_tts,(ID,tts_text,prompt_text,prompt_audio_file,output_dir))
64
+ p.close()
65
+ p.join()
66
+ print('All done')
gradio/tts_demo_page.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ import tempfile
3
+ import torch
4
+ import torchaudio
5
+ import gradio as gr
6
+ from cosyvoice.cli.cosyvoice import CosyVoice2
7
+ from cosyvoice.utils.file_utils import load_wav
8
+
9
+ # 全局变量
10
+ model_path = '/external_data/models/CosyVoice2-0.5B_RWKV_0.19B/'
11
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
12
+
13
+ # 在应用启动时初始化模型(全局共享)
14
+ print("正在初始化 CosyVoice2 模型...")
15
+ cosyvoice = CosyVoice2(model_path, device=device, fp16=True)
16
+ # 预热模型
17
+ cosyvoice.model.llm.dummy_forward()
18
+ print("模型初始化完成!")
19
+
20
+ def synthesize_speech(audio_file, prompt_text, tts_text):
21
+ """合成语音"""
22
+ global cosyvoice
23
+
24
+ if not audio_file or not prompt_text or not tts_text:
25
+ return None, "请提供所有必需的输入(提示音频、提示文本和要合成的文本)"
26
+
27
+ try:
28
+ # 加载提示音频
29
+ prompt_speech_16k = load_wav(audio_file, 16000)
30
+
31
+ # 执行推理
32
+ result = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=False)
33
+
34
+ # 获取合成的语音
35
+ output_speech = result[0]['tts_speech']
36
+
37
+ # 保存临时文件
38
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
39
+ temp_file.close()
40
+ torchaudio.save(temp_file.name, output_speech, cosyvoice.sample_rate)
41
+
42
+ return temp_file.name, f"语音合成成功!"
43
+ except Exception as e:
44
+ return None, f"合成过程中出错:{str(e)}"
45
+
46
+ # 创建 Gradio 界面
47
+ with gr.Blocks(title="RWKV TTS 演示") as demo:
48
+ gr.Markdown("# RWKV 语音合成演示")
49
+ gr.Markdown("### 语音合成系统已准备就绪,可直接使用")
50
+
51
+ with gr.Row():
52
+ with gr.Column():
53
+ audio_input = gr.Audio(type="filepath", label="上传提示音频文件(WAV 格式)")
54
+ prompt_text = gr.Textbox(label="提示文本(与提示音频对应的文字内容)", placeholder="例如:今天天气挺不错的。")
55
+ tts_text = gr.Textbox(label="要合成的文本", placeholder="例如:收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。")
56
+ synthesize_button = gr.Button("生成语音")
57
+
58
+ with gr.Column():
59
+ audio_output = gr.Audio(label="合成的语音")
60
+ output_message = gr.Textbox(label="状态信息")
61
+
62
+ synthesize_button.click(
63
+ fn=synthesize_speech,
64
+ inputs=[audio_input, prompt_text, tts_text],
65
+ outputs=[audio_output, output_message]
66
+ )
67
+
68
+ gr.Markdown("""
69
+ ## 使用说明
70
+
71
+ 1. 上传一个WAV格式的提示音频文件
72
+ 2. 输入与提示音频对应的文本内容
73
+ 3. 输入希望合成的文本
74
+ 4. 点击"生成语音"按钮进行语音合成
75
+
76
+ 注意:模型已在服务启动时预加载,所有用户共享同一个模型实例。
77
+ """)
78
+
79
+ # 启动应用
80
+ if __name__ == "__main__":
81
+ demo.launch()
mine.wav ADDED
Binary file (97 kB). View file
 
new.mp3 ADDED
Binary file (25.7 kB). View file
 
new.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e62a130a15a7560ebf8c1bd73212a9d6410a50e595de9a809bc64775a4a6f07
3
+ size 141964
run_multiple_process.sh ADDED
@@ -0,0 +1,137 @@
1
+ export PYTHONPATH=/home/yueyulin/github/CosyVoice:/home/yueyulin/github/CosyVoice/third_party/Matcha-TTS/:/home/yueyulin/github/RWKVTTS
2
+
3
+ # 设置默认参数
4
+ LANGUAGE="zh"
5
+ OUTPUT_DIR="/home/yueyulin/data/speech_corpus"
6
+ COSY_MODEL_DIR="/home/yueyulin/models/CosyVoice2-0.5B/"
7
+ PROMPTS_DIR="extract_data/prompts/zh"
8
+ DEVICE="cuda:0"
9
+ PARQUET_FILES=()
10
+ JSONL_FILES=()
11
+ FILE_TYPE="" # 用于标记文件类型
12
+ is_cross_lingual=""
13
+ is_instructed=""
14
+
15
+ # 解析命令行参数
16
+ while [[ $# -gt 0 ]]; do
17
+ case $1 in
18
+ --language)
19
+ LANGUAGE="$2"
20
+ shift 2
21
+ ;;
22
+ --output_dir)
23
+ OUTPUT_DIR="$2"
24
+ shift 2
25
+ ;;
26
+ --cosy_model_dir)
27
+ COSY_MODEL_DIR="$2"
28
+ shift 2
29
+ ;;
30
+ --prompts_dir)
31
+ PROMPTS_DIR="$2"
32
+ shift 2
33
+ ;;
34
+ --parquet_files)
35
+ # 接收多个parquet文件路径
36
+ shift
37
+ while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
38
+ PARQUET_FILES+=("$1")
39
+ shift
40
+ done
41
+ FILE_TYPE="parquet"
42
+ ;;
43
+ --jsonl_files)
44
+ # 接收多个jsonl文件路径
45
+ shift
46
+ while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
47
+ JSONL_FILES+=("$1")
48
+ shift
49
+ done
50
+ FILE_TYPE="jsonl"
51
+ ;;
52
+ --device)
53
+ DEVICE="$2"
54
+ shift 2
55
+ ;;
56
+ --cross_lingual)
57
+ is_cross_lingual="--is_cross_lingual"
58
+ shift
59
+ ;;
60
+ --instructed)
61
+ is_instructed="--is_instructed"
62
+ shift
63
+ ;;
64
+ *)
65
+ echo "未知参数: $1"
66
+ exit 1
67
+ ;;
68
+ esac
69
+ done
70
+
71
+ # 检查是否提供了文件
72
+ if [ "$FILE_TYPE" == "parquet" ]; then
73
+ if [ ${#PARQUET_FILES[@]} -eq 0 ]; then
74
+ echo "错误: 未指定parquet文件,请使用 --parquet_files 参数"
75
+ exit 1
76
+ fi
77
+ FILES=("${PARQUET_FILES[@]}")
78
+ FILE_ARG="--parquet_files"
79
+ echo "将处理 ${#FILES[@]} 个parquet文件"
80
+ elif [ "$FILE_TYPE" == "jsonl" ]; then
81
+ if [ ${#JSONL_FILES[@]} -eq 0 ]; then
82
+ echo "错误: 未指定jsonl文件,请使用 --jsonl_files 参数"
83
+ exit 1
84
+ fi
85
+ FILES=("${JSONL_FILES[@]}")
86
+ FILE_ARG="--jsonl_files"
87
+ echo "将处理 ${#FILES[@]} 个jsonl文件"
88
+ else
89
+ echo "错误: 请使用 --parquet_files 或 --jsonl_files 参数指定输入文件"
90
+ exit 1
91
+ fi
92
+
93
+ echo "运行参数:"
94
+ echo "语言: $LANGUAGE"
95
+ echo "输出目录: $OUTPUT_DIR"
96
+ echo "模型目录: $COSY_MODEL_DIR"
97
+ echo "提示词目录: $PROMPTS_DIR"
98
+ echo "设备: $DEVICE"
99
+ echo "文件类型: $FILE_TYPE"
100
+
101
+ # 确保输出目录存在
102
+ mkdir -p $OUTPUT_DIR
103
+
104
+ # 启动处理进程,每个文件一个进程
105
+ for ((i=0; i<${#FILES[@]}; i++)); do
106
+ FILE="${FILES[$i]}"
107
+ FILENAME=$(basename "$FILE")
108
+
109
+ echo "处理文件 $FILENAME 使用 $DEVICE"
110
+
111
+ # 在后台启动进程
112
+ nohup python data/utils/utilitie.py \
113
+ --task generate_speech_tokens \
114
+ --language $LANGUAGE \
115
+ $is_cross_lingual \
116
+ $FILE_ARG "$FILE" \
117
+ --output_dir $OUTPUT_DIR \
118
+ --cosy_model_dir $COSY_MODEL_DIR \
119
+ --prompts_dir $PROMPTS_DIR \
120
+ $is_instructed \
121
+ --device "$DEVICE" > "$OUTPUT_DIR/log_${FILENAME%.*}.log" 2>&1 &
122
+
123
+ # 记录进程ID
124
+ PID=$!
125
+ echo "启动进程 PID: $PID 处理文件: $FILENAME 使用 $DEVICE"
126
+
127
+ # 等待一点时间确保进程启动
128
+ sleep 5
129
+ done
130
+
131
+ echo "所有处理进程已启动,日志文件保存在 $OUTPUT_DIR 目录"
132
+ echo "使用 'ps aux | grep utilitie.py' 命令查看运行状态"
133
+ echo "使用 'nvidia-smi' 命令监控GPU使用情况"
134
+
135
+ # 等待所有后台进程完成
136
+ wait
137
+ echo "所有处理已完成"
rwkvtts_requirements.txt ADDED
@@ -0,0 +1,264 @@
1
+ absl-py==2.1.0
2
+ aiofiles==23.2.1
3
+ aiohappyeyeballs==2.4.8
4
+ aiohttp==3.11.13
5
+ aiosignal==1.3.2
6
+ alembic==1.15.1
7
+ altair==5.5.0
8
+ annotated-types==0.7.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==4.8.0
11
+ argon2-cffi==23.1.0
12
+ argon2-cffi-bindings==21.2.0
13
+ arrow==1.3.0
14
+ asttokens==3.0.0
15
+ async-lru==2.0.4
16
+ attrs==25.1.0
17
+ audioread==3.0.1
18
+ autopage==0.5.2
19
+ babel==2.17.0
20
+ beautifulsoup4==4.13.3
21
+ bleach==6.2.0
22
+ certifi==2025.1.31
23
+ cffi==1.17.1
24
+ cfgv==3.4.0
25
+ charset-normalizer==3.4.1
26
+ click==8.1.8
27
+ cliff==4.9.1
28
+ cmaes==0.11.1
29
+ cmd2==2.5.11
30
+ colorama==0.4.6
31
+ coloredlogs==15.0.1
32
+ colorlog==6.9.0
33
+ comm==0.2.2
34
+ conformer==0.3.2
35
+ contourpy==1.3.1
36
+ csvw==3.5.1
37
+ cycler==0.12.1
38
+ Cython==3.0.12
39
+ datasets==3.3.2
40
+ debugpy==1.8.13
41
+ decorator==5.2.1
42
+ deepspeed==0.16.4
43
+ defusedxml==0.7.1
44
+ diffusers==0.32.2
45
+ dill==0.3.8
46
+ distlib==0.3.9
47
+ dlinfo==2.0.0
48
+ einops==0.8.1
49
+ executing==2.2.0
50
+ fastapi==0.115.11
51
+ fastjsonschema==2.21.1
52
+ ffmpy==0.5.0
53
+ filelock==3.17.0
54
+ flatbuffers==25.2.10
55
+ fonttools==4.56.0
56
+ fqdn==1.5.1
57
+ frozenlist==1.5.0
58
+ fsspec==2024.12.0
59
+ gdown==5.2.0
60
+ gradio==3.43.2
61
+ gradio_client==0.5.0
62
+ greenlet==3.1.1
63
+ grpcio==1.70.0
64
+ h11==0.14.0
65
+ hjson==3.1.0
66
+ httpcore==1.0.7
67
+ httpx==0.28.1
68
+ huggingface-hub==0.29.1
69
+ humanfriendly==10.0
70
+ hydra-colorlog==1.2.0
71
+ hydra-core==1.3.2
72
+ hydra-optuna-sweeper==1.2.0
73
+ HyperPyYAML==1.2.2
74
+ identify==2.6.8
75
+ idna==3.10
76
+ importlib_metadata==8.6.1
77
+ importlib_resources==6.5.2
78
+ inflect==7.5.0
79
+ iniconfig==2.0.0
80
+ ipykernel==6.29.5
81
+ ipython==9.0.1
82
+ ipython_pygments_lexers==1.1.1
83
+ ipywidgets==8.1.5
84
+ isodate==0.7.2
85
+ isoduration==20.11.0
86
+ jedi==0.19.2
87
+ Jinja2==3.1.5
88
+ joblib==1.4.2
89
+ json5==0.10.0
90
+ jsonpointer==3.0.0
91
+ jsonschema==4.23.0
92
+ jsonschema-specifications==2024.10.1
93
+ jupyter-events==0.12.0
94
+ jupyter-lsp==2.2.5
95
+ jupyter_client==8.6.3
96
+ jupyter_core==5.7.2
97
+ jupyter_server==2.15.0
98
+ jupyter_server_terminals==0.5.3
99
+ jupyterlab==4.3.5
100
+ jupyterlab_pygments==0.3.0
101
+ jupyterlab_server==2.27.3
102
+ jupyterlab_widgets==3.0.13
103
+ kiwisolver==1.4.8
104
+ language-tags==1.2.0
105
+ lazy_loader==0.4
106
+ librosa==0.10.2.post1
107
+ lightning==2.5.0.post0
108
+ lightning-utilities==0.13.1
109
+ llvmlite==0.44.0
110
+ Mako==1.3.9
111
+ Markdown==3.7
112
+ markdown-it-py==3.0.0
113
+ MarkupSafe==2.1.5
114
+ matcha-tts==0.0.7.2
115
+ matplotlib==3.10.1
116
+ matplotlib-inline==0.1.7
117
+ mdurl==0.1.2
118
+ mistune==3.1.2
119
+ modelscope==1.23.2
120
+ more-itertools==10.6.0
121
+ mpmath==1.3.0
122
+ msgpack==1.1.0
123
+ multidict==6.1.0
124
+ multiprocess==0.70.16
125
+ narwhals==1.29.0
126
+ nbclient==0.10.2
127
+ nbconvert==7.16.6
128
+ nbformat==5.10.4
129
+ nest-asyncio==1.6.0
130
+ networkx==3.4.2
131
+ ninja==1.11.1.3
132
+ nodeenv==1.9.1
133
+ notebook==7.3.2
134
+ notebook_shim==0.2.4
135
+ numba==0.61.0
136
+ numpy==1.26.4
137
+ nvidia-cublas-cu12==12.4.5.8
138
+ nvidia-cuda-cupti-cu12==12.4.127
139
+ nvidia-cuda-nvrtc-cu12==12.4.127
140
+ nvidia-cuda-runtime-cu12==12.4.127
141
+ nvidia-cudnn-cu12==9.1.0.70
142
+ nvidia-cufft-cu12==11.2.1.3
143
+ nvidia-curand-cu12==10.3.5.147
144
+ nvidia-cusolver-cu12==11.6.1.9
145
+ nvidia-cusparse-cu12==12.3.1.170
146
+ nvidia-cusparselt-cu12==0.6.2
147
+ nvidia-nccl-cu12==2.21.5
148
+ nvidia-nvjitlink-cu12==12.4.127
149
+ nvidia-nvtx-cu12==12.4.127
150
+ omegaconf==2.3.0
151
+ onnx==1.17.0
152
+ onnxruntime-gpu==1.20.1
153
+ openai-whisper==20240930
154
+ optuna==2.10.1
155
+ orjson==3.10.15
156
+ overrides==7.7.0
157
+ packaging==24.2
158
+ pandas==2.2.3
159
+ pandocfilters==1.5.1
160
+ parso==0.8.4
161
+ pbr==6.1.1
162
+ pexpect==4.9.0
163
+ phonemizer==3.3.0
164
+ pillow==10.4.0
165
+ platformdirs==4.3.6
166
+ pluggy==1.5.0
167
+ pooch==1.8.2
168
+ pre_commit==4.1.0
169
+ prettytable==3.15.1
170
+ prometheus_client==0.21.1
171
+ prompt_toolkit==3.0.50
172
+ propcache==0.3.0
173
+ protobuf==6.30.0
174
+ psutil==7.0.0
175
+ ptyprocess==0.7.0
176
+ pure_eval==0.2.3
177
+ py-cpuinfo==9.0.0
178
+ pyarrow==19.0.1
179
+ pycparser==2.22
180
+ pydantic==2.10.6
181
+ pydantic_core==2.27.2
182
+ pydub==0.25.1
183
+ Pygments==2.19.1
184
+ pyparsing==3.2.1
185
+ pyperclip==1.9.0
186
+ PySocks==1.7.1
187
+ pytest==8.3.5
188
+ python-dateutil==2.9.0.post0
189
+ python-dotenv==1.0.1
190
+ python-json-logger==3.2.1
191
+ python-multipart==0.0.20
192
+ pytorch-lightning==2.5.0.post0
193
+ pytz==2025.1
194
+ pyworld==0.3.5
195
+ PyYAML==6.0.2
196
+ pyzmq==26.2.1
197
+ rdflib==7.1.3
198
+ referencing==0.36.2
199
+ regex==2024.11.6
200
+ requests==2.32.3
201
+ rfc3339-validator==0.1.4
202
+ rfc3986==1.5.0
203
+ rfc3986-validator==0.1.1
204
+ rich==13.9.4
205
+ rootutils==1.0.7
206
+ rpds-py==0.23.1
207
+ ruamel.yaml==0.18.10
208
+ ruamel.yaml.clib==0.2.12
209
+ rwkv-fla==0.7.202503020902
210
+ safetensors==0.5.3
211
+ scikit-learn==1.6.1
212
+ scipy==1.15.2
213
+ seaborn==0.13.2
214
+ segments==2.3.0
215
+ semantic-version==2.10.0
216
+ Send2Trash==1.8.3
217
+ six==1.17.0
218
+ sniffio==1.3.1
219
+ soundfile==0.13.1
220
+ soupsieve==2.6
221
+ soxr==0.5.0.post1
222
+ SQLAlchemy==2.0.38
223
+ stack-data==0.6.3
224
+ starlette==0.46.0
225
+ stevedore==5.4.1
226
+ sympy==1.13.1
227
+ tensorboard==2.19.0
228
+ tensorboard-data-server==0.7.2
229
+ terminado==0.18.1
230
+ threadpoolctl==3.5.0
231
+ tiktoken==0.9.0
232
+ tinycss2==1.4.0
233
+ tokenizers==0.21.0
234
+ torch==2.6.0
235
+ torchaudio==2.6.0
236
+ torchmetrics==1.6.2
237
+ torchvision==0.21.0
238
+ tornado==6.4.2
239
+ tqdm==4.67.1
240
+ traitlets==5.14.3
241
+ transformers==4.49.0
242
+ triton==3.2.0
243
+ typeguard==4.4.2
244
+ types-python-dateutil==2.9.0.20241206
245
+ typing_extensions==4.12.2
246
+ tzdata==2025.1
247
+ Unidecode==1.3.8
248
+ uri-template==1.3.0
249
+ uritemplate==4.1.1
250
+ urllib3==2.3.0
251
+ uvicorn==0.34.0
252
+ virtualenv==20.29.2
253
+ wcwidth==0.2.13
254
+ webcolors==24.11.1
255
+ webencodings==0.5.1
256
+ websocket-client==1.8.0
257
+ websockets==11.0.3
258
+ Werkzeug==3.1.3
259
+ WeTextProcessing==1.0.4.1
260
+ wget==3.2
261
+ widgetsnbextension==4.0.13
262
+ xxhash==3.5.0
263
+ yarl==1.18.3
264
+ zipp==3.21.0
third_party/cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+
17
+ import pyarrow.parquet as pq
18
+ from io import BytesIO
19
+ import torch
20
+ import torchaudio
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ import torch.nn.functional as F
23
+ import pyworld as pw
24
+
25
+
26
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27
+
28
+
29
+ def parquet_opener(data, mode='train', tts_data={}):
30
+ """ Give url or local file, return file descriptor
31
+ Inplace operation.
32
+
33
+ Args:
34
+ data(Iterable[str]): url or local file list
35
+
36
+ Returns:
37
+ Iterable[{src, stream}]
38
+ """
39
+ for sample in data:
40
+ assert 'src' in sample
41
+ url = sample['src']
42
+ try:
43
+ for df in pq.ParquetFile(url).iter_batches(batch_size=64):
44
+ df = df.to_pandas()
45
+ for i in range(len(df)):
46
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
47
+ continue
48
+ sample.update(dict(df.loc[i]))
49
+ if mode == 'train':
50
+ # NOTE do not return sample directly, must initialize a new dict
51
+ yield {**sample}
52
+ else:
53
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
54
+ yield {**sample, 'tts_index': index, 'tts_text': text}
55
+ except Exception as ex:
56
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
57
+
58
+
59
+ def filter(data,
60
+ max_length=10240,
61
+ min_length=10,
62
+ token_max_length=200,
63
+ token_min_length=1,
64
+ min_output_input_ratio=0.0005,
65
+ max_output_input_ratio=1,
66
+ mode='train'):
67
+ """ Filter sample according to feature and label length
68
+ Inplace operation.
69
+
70
+ Args:
71
+ data: Iterable[{key, wav, label, sample_rate}]
72
+ max_length: drop utterance which is greater than max_length(10ms)
73
+ min_length: drop utterance which is less than min_length(10ms)
74
+ token_max_length: drop utterance which is greater than
75
+ token_max_length, especially when using char units for
76
+ English modeling
77
+ token_min_length: drop utterance which is
78
+ less than token_min_length
79
+ min_output_input_ratio: minimal ratio of
80
+ token_length / feats_length(10ms)
81
+ max_output_input_ratio: maximum ratio of
82
+ token_length / feats_length(10ms)
83
+
84
+ Returns:
85
+ Iterable[{key, wav, label, sample_rate}]
86
+ """
87
+ for sample in data:
88
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
89
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
90
+ del sample['audio_data']
91
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
92
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
93
+ if num_frames < min_length:
94
+ continue
95
+ if num_frames > max_length:
96
+ continue
97
+ if len(sample['text_token']) < token_min_length:
98
+ continue
99
+ if len(sample['text_token']) > token_max_length:
100
+ continue
101
+ if len(sample['speech_token']) == 0:
102
+ continue
103
+ if num_frames != 0:
104
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
105
+ continue
106
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
107
+ continue
108
+ yield sample
109
+
110
+
111
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
112
+ """ Resample data.
113
+ Inplace operation.
114
+
115
+ Args:
116
+ data: Iterable[{key, wav, label, sample_rate}]
117
+ resample_rate: target resample rate
118
+
119
+ Returns:
120
+ Iterable[{key, wav, label, sample_rate}]
121
+ """
122
+ for sample in data:
123
+ assert 'sample_rate' in sample
124
+ assert 'speech' in sample
125
+ sample_rate = sample['sample_rate']
126
+ waveform = sample['speech']
127
+ if sample_rate != resample_rate:
128
+ if sample_rate < min_sample_rate:
129
+ continue
130
+ sample['sample_rate'] = resample_rate
131
+ sample['speech'] = torchaudio.transforms.Resample(
132
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
133
+ max_val = sample['speech'].abs().max()
134
+ if max_val > 1:
135
+ sample['speech'] /= max_val
136
+ yield sample
137
+
138
+
139
+ def truncate(data, truncate_length=24576, mode='train'):
140
+ """ Truncate data.
141
+
142
+ Args:
143
+ data: Iterable[{key, wav, label, sample_rate}]
144
+ truncate_length: truncate length
145
+
146
+ Returns:
147
+ Iterable[{key, wav, label, sample_rate}]
148
+ """
149
+ for sample in data:
150
+ waveform = sample['speech']
151
+ if waveform.shape[1] > truncate_length:
152
+ start = random.randint(0, waveform.shape[1] - truncate_length)
153
+ waveform = waveform[:, start: start + truncate_length]
154
+ else:
155
+ waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
156
+ sample['speech'] = waveform
157
+ yield sample
158
+
159
+
160
+ def compute_fbank(data,
161
+ feat_extractor,
162
+ mode='train'):
163
+ """ Extract fbank
164
+
165
+ Args:
166
+ data: Iterable[{key, wav, label, sample_rate}]
167
+
168
+ Returns:
169
+ Iterable[{key, feat, label}]
170
+ """
171
+ for sample in data:
172
+ assert 'sample_rate' in sample
173
+ assert 'speech' in sample
174
+ assert 'utt' in sample
175
+ assert 'text_token' in sample
176
+ waveform = sample['speech']
177
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
178
+ sample['speech_feat'] = mat
179
+ yield sample
180
+
181
+
182
+ def compute_f0(data, sample_rate, hop_size, mode='train'):
183
+ """ Extract f0
184
+
185
+ Args:
186
+ data: Iterable[{key, wav, label, sample_rate}]
187
+
188
+ Returns:
189
+ Iterable[{key, feat, label}]
190
+ """
191
+ frame_period = hop_size * 1000 / sample_rate
192
+ for sample in data:
193
+ assert 'sample_rate' in sample
194
+ assert 'speech' in sample
195
+ assert 'utt' in sample
196
+ assert 'text_token' in sample
197
+ waveform = sample['speech']
198
+ _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
199
+ if sum(_f0 != 0) < 5: # this happens when the algorithm fails
200
+ _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
201
+ f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
202
+ f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
203
+ sample['pitch_feat'] = f0
204
+ yield sample
205
+
206
+
207
+ def parse_embedding(data, normalize, mode='train'):
208
+ """ Parse utt_embedding/spk_embedding
209
+
210
+ Args:
211
+ data: Iterable[{key, wav, label, sample_rate}]
212
+
213
+ Returns:
214
+ Iterable[{key, feat, label}]
215
+ """
216
+ for sample in data:
217
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
218
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
219
+ if normalize:
220
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
221
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
222
+ yield sample
223
+
224
+
225
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
226
+ """ Decode text to chars or BPE
227
+ Inplace operation
228
+
229
+ Args:
230
+ data: Iterable[{key, wav, txt, sample_rate}]
231
+
232
+ Returns:
233
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
234
+ """
235
+ tokenizer = get_tokenizer()
236
+ for sample in data:
237
+ assert 'text' in sample
238
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
239
+ if mode == 'inference':
240
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
241
+ yield sample
242
+
243
+
244
+ def shuffle(data, shuffle_size=10000, mode='train'):
245
+ """ Local shuffle the data
246
+
247
+ Args:
248
+ data: Iterable[{key, feat, label}]
249
+ shuffle_size: buffer size for shuffle
250
+
251
+ Returns:
252
+ Iterable[{key, feat, label}]
253
+ """
254
+ buf = []
255
+ for sample in data:
256
+ buf.append(sample)
257
+ if len(buf) >= shuffle_size:
258
+ random.shuffle(buf)
259
+ for x in buf:
260
+ yield x
261
+ buf = []
262
+ # The sample left over
263
+ random.shuffle(buf)
264
+ for x in buf:
265
+ yield x
266
+
267
+
268
+ def sort(data, sort_size=500, mode='train'):
269
+ """ Sort the data by feature length.
270
+ Sort is used after shuffle and before batch, so we can group
271
+ utts with similar lengths into a batch, and `sort_size` should
272
+ be less than `shuffle_size`
273
+
274
+ Args:
275
+ data: Iterable[{key, feat, label}]
276
+ sort_size: buffer size for sort
277
+
278
+ Returns:
279
+ Iterable[{key, feat, label}]
280
+ """
281
+
282
+ buf = []
283
+ for sample in data:
284
+ buf.append(sample)
285
+ if len(buf) >= sort_size:
286
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
287
+ for x in buf:
288
+ yield x
289
+ buf = []
290
+ # The sample left over
291
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
292
+ for x in buf:
293
+ yield x
294
+
295
+
296
+ def static_batch(data, batch_size=16):
297
+ """ Static batch the data by `batch_size`
298
+
299
+ Args:
300
+ data: Iterable[{key, feat, label}]
301
+ batch_size: batch size
302
+
303
+ Returns:
304
+ Iterable[List[{key, feat, label}]]
305
+ """
306
+ buf = []
307
+ for sample in data:
308
+ buf.append(sample)
309
+ if len(buf) >= batch_size:
310
+ yield buf
311
+ buf = []
312
+ if len(buf) > 0:
313
+ yield buf
314
+
315
+
316
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
317
+ """ Dynamic batch the data until the total frames in batch
318
+ reach `max_frames_in_batch`
319
+
320
+ Args:
321
+ data: Iterable[{key, feat, label}]
322
+ max_frames_in_batch: max_frames in one batch
323
+
324
+ Returns:
325
+ Iterable[List[{key, feat, label}]]
326
+ """
327
+ buf = []
328
+ longest_frames = 0
329
+ for sample in data:
330
+ assert 'speech_feat' in sample
331
+ assert isinstance(sample['speech_feat'], torch.Tensor)
332
+ new_sample_frames = sample['speech_feat'].size(0)
333
+ longest_frames = max(longest_frames, new_sample_frames)
334
+ frames_after_padding = longest_frames * (len(buf) + 1)
335
+ if frames_after_padding > max_frames_in_batch:
336
+ yield buf
337
+ buf = [sample]
338
+ longest_frames = new_sample_frames
339
+ else:
340
+ buf.append(sample)
341
+ if len(buf) > 0:
342
+ yield buf
343
+
344
+
345
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
346
+ """ Wrapper for static/dynamic batch
347
+ """
348
+ if mode == 'inference':
349
+ return static_batch(data, 1)
350
+ else:
351
+ if batch_type == 'static':
352
+ return static_batch(data, batch_size)
353
+ elif batch_type == 'dynamic':
354
+ return dynamic_batch(data, max_frames_in_batch)
355
+ else:
356
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
357
+
358
+
359
+ def padding(data, use_spk_embedding, mode='train', gan=False):
360
+ """ Padding the data into training data
361
+
362
+ Args:
363
+ data: Iterable[List[{key, feat, label}]]
364
+
365
+ Returns:
366
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
367
+ """
368
+ for sample in data:
369
+ assert isinstance(sample, list)
370
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
371
+ dtype=torch.int32)
372
+ order = torch.argsort(speech_feat_len, descending=True)
373
+
374
+ utts = [sample[i]['utt'] for i in order]
375
+ speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
376
+ speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
377
+ speech = pad_sequence(speech, batch_first=True, padding_value=0)
378
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
379
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
380
+ speech_token = pad_sequence(speech_token,
381
+ batch_first=True,
382
+ padding_value=0)
383
+ speech_feat = [sample[i]['speech_feat'] for i in order]
384
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
385
+ speech_feat = pad_sequence(speech_feat,
386
+ batch_first=True,
387
+ padding_value=0)
388
+ text = [sample[i]['text'] for i in order]
389
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
390
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
391
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
392
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
393
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
394
+ batch = {
395
+ "utts": utts,
396
+ "speech": speech,
397
+ "speech_len": speech_len,
398
+ "speech_token": speech_token,
399
+ "speech_token_len": speech_token_len,
400
+ "speech_feat": speech_feat,
401
+ "speech_feat_len": speech_feat_len,
402
+ "text": text,
403
+ "text_token": text_token,
404
+ "text_token_len": text_token_len,
405
+ "utt_embedding": utt_embedding,
406
+ "spk_embedding": spk_embedding,
407
+ }
408
+ if gan is True:
409
+ # in gan train, we need pitch_feat
410
+ pitch_feat = [sample[i]['pitch_feat'] for i in order]
411
+ pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
412
+ pitch_feat = pad_sequence(pitch_feat,
413
+ batch_first=True,
414
+ padding_value=0)
415
+ batch["pitch_feat"] = pitch_feat
416
+ batch["pitch_feat_len"] = pitch_feat_len
417
+ else:
418
+ # only gan train needs speech, delete it to save memory
419
+ del batch["speech"]
420
+ del batch["speech_len"]
421
+ if mode == 'inference':
422
+ tts_text = [sample[i]['tts_text'] for i in order]
423
+ tts_index = [sample[i]['tts_index'] for i in order]
424
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
425
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
426
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
427
+ batch.update({'tts_text': tts_text,
428
+ 'tts_index': tts_index,
429
+ 'tts_text_token': tts_text_token,
430
+ 'tts_text_token_len': tts_text_token_len})
431
+ if use_spk_embedding is True:
432
+ batch["embedding"] = batch["spk_embedding"]
433
+ else:
434
+ batch["embedding"] = batch["utt_embedding"]
435
+ yield batch
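For readers skimming the diff: `dynamic_batch` above flushes a batch as soon as padding every queued sample to the longest one would exceed the frame budget. Below is a minimal, self-contained sketch of just that rule; `toy_dynamic_batch` and the frame counts are illustrative and not part of the repository.

```python
# Toy illustration only: mirrors the flush rule of dynamic_batch above.
def toy_dynamic_batch(frame_counts, max_frames_in_batch=12000):
    buf, longest = [], 0
    for frames in frame_counts:
        longest = max(longest, frames)
        # flush once padding every sample to `longest` would blow the budget
        if longest * (len(buf) + 1) > max_frames_in_batch:
            yield buf
            buf, longest = [frames], frames
        else:
            buf.append(frames)
    if buf:
        yield buf

print(list(toy_dynamic_batch([3000, 4000, 5000, 1000, 6000])))
# -> [[3000, 4000], [5000, 1000], [6000]]
```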
third_party/cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import pack, rearrange, repeat
18
+ from cosyvoice.utils.common import mask_to_bias
19
+ from cosyvoice.utils.mask import add_optional_chunk_mask
20
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
21
+ from matcha.models.components.transformer import BasicTransformerBlock
22
+
23
+
24
+ class Transpose(torch.nn.Module):
25
+ def __init__(self, dim0: int, dim1: int):
26
+ super().__init__()
27
+ self.dim0 = dim0
28
+ self.dim1 = dim1
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = torch.transpose(x, self.dim0, self.dim1)
32
+ return x
33
+
34
+
35
+ class CausalBlock1D(Block1D):
36
+ def __init__(self, dim: int, dim_out: int):
37
+ super(CausalBlock1D, self).__init__(dim, dim_out)
38
+ self.block = torch.nn.Sequential(
39
+ CausalConv1d(dim, dim_out, 3),
40
+ Transpose(1, 2),
41
+ nn.LayerNorm(dim_out),
42
+ Transpose(1, 2),
43
+ nn.Mish(),
44
+ )
45
+
46
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
47
+ output = self.block(x * mask)
48
+ return output * mask
49
+
50
+
51
+ class CausalResnetBlock1D(ResnetBlock1D):
52
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
53
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
54
+ self.block1 = CausalBlock1D(dim, dim_out)
55
+ self.block2 = CausalBlock1D(dim_out, dim_out)
56
+
57
+
58
+ class CausalConv1d(torch.nn.Conv1d):
59
+ def __init__(
60
+ self,
61
+ in_channels: int,
62
+ out_channels: int,
63
+ kernel_size: int,
64
+ stride: int = 1,
65
+ dilation: int = 1,
66
+ groups: int = 1,
67
+ bias: bool = True,
68
+ padding_mode: str = 'zeros',
69
+ device=None,
70
+ dtype=None
71
+ ) -> None:
72
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
73
+ kernel_size, stride,
74
+ padding=0, dilation=dilation,
75
+ groups=groups, bias=bias,
76
+ padding_mode=padding_mode,
77
+ device=device, dtype=dtype)
78
+ assert stride == 1
79
+ self.causal_padding = (kernel_size - 1, 0)
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ x = F.pad(x, self.causal_padding)
83
+ x = super(CausalConv1d, self).forward(x)
84
+ return x
85
+
86
+
87
+ class ConditionalDecoder(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_channels,
91
+ out_channels,
92
+ causal=False,
93
+ channels=(256, 256),
94
+ dropout=0.05,
95
+ attention_head_dim=64,
96
+ n_blocks=1,
97
+ num_mid_blocks=2,
98
+ num_heads=4,
99
+ act_fn="snake",
100
+ ):
101
+ """
102
+ This decoder requires an input with the same shape as the target. So, if your text content
103
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
104
+ """
105
+ super().__init__()
106
+ channels = tuple(channels)
107
+ self.in_channels = in_channels
108
+ self.out_channels = out_channels
109
+ self.causal = causal
110
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
111
+ time_embed_dim = channels[0] * 4
112
+ self.time_mlp = TimestepEmbedding(
113
+ in_channels=in_channels,
114
+ time_embed_dim=time_embed_dim,
115
+ act_fn="silu",
116
+ )
117
+ self.down_blocks = nn.ModuleList([])
118
+ self.mid_blocks = nn.ModuleList([])
119
+ self.up_blocks = nn.ModuleList([])
120
+
121
+ output_channel = in_channels
122
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
123
+ input_channel = output_channel
124
+ output_channel = channels[i]
125
+ is_last = i == len(channels) - 1
126
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
127
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
128
+ transformer_blocks = nn.ModuleList(
129
+ [
130
+ BasicTransformerBlock(
131
+ dim=output_channel,
132
+ num_attention_heads=num_heads,
133
+ attention_head_dim=attention_head_dim,
134
+ dropout=dropout,
135
+ activation_fn=act_fn,
136
+ )
137
+ for _ in range(n_blocks)
138
+ ]
139
+ )
140
+ downsample = (
141
+ Downsample1D(output_channel) if not is_last else
142
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
143
+ )
144
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
145
+
146
+ for _ in range(num_mid_blocks):
147
+ input_channel = channels[-1]
148
+ out_channels = channels[-1]
149
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
150
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
151
+
152
+ transformer_blocks = nn.ModuleList(
153
+ [
154
+ BasicTransformerBlock(
155
+ dim=output_channel,
156
+ num_attention_heads=num_heads,
157
+ attention_head_dim=attention_head_dim,
158
+ dropout=dropout,
159
+ activation_fn=act_fn,
160
+ )
161
+ for _ in range(n_blocks)
162
+ ]
163
+ )
164
+
165
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
166
+
167
+ channels = channels[::-1] + (channels[0],)
168
+ for i in range(len(channels) - 1):
169
+ input_channel = channels[i] * 2
170
+ output_channel = channels[i + 1]
171
+ is_last = i == len(channels) - 2
172
+ resnet = CausalResnetBlock1D(
173
+ dim=input_channel,
174
+ dim_out=output_channel,
175
+ time_emb_dim=time_embed_dim,
176
+ ) if self.causal else ResnetBlock1D(
177
+ dim=input_channel,
178
+ dim_out=output_channel,
179
+ time_emb_dim=time_embed_dim,
180
+ )
181
+ transformer_blocks = nn.ModuleList(
182
+ [
183
+ BasicTransformerBlock(
184
+ dim=output_channel,
185
+ num_attention_heads=num_heads,
186
+ attention_head_dim=attention_head_dim,
187
+ dropout=dropout,
188
+ activation_fn=act_fn,
189
+ )
190
+ for _ in range(n_blocks)
191
+ ]
192
+ )
193
+ upsample = (
194
+ Upsample1D(output_channel, use_conv_transpose=True)
195
+ if not is_last
196
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
197
+ )
198
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
199
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
200
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
201
+ self.initialize_weights()
202
+
203
+ def initialize_weights(self):
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv1d):
206
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
207
+ if m.bias is not None:
208
+ nn.init.constant_(m.bias, 0)
209
+ elif isinstance(m, nn.GroupNorm):
210
+ nn.init.constant_(m.weight, 1)
211
+ nn.init.constant_(m.bias, 0)
212
+ elif isinstance(m, nn.Linear):
213
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
214
+ if m.bias is not None:
215
+ nn.init.constant_(m.bias, 0)
216
+
217
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
218
+ """Forward pass of the UNet1DConditional model.
219
+
220
+ Args:
221
+ x (torch.Tensor): shape (batch_size, in_channels, time)
222
+ mask (_type_): shape (batch_size, 1, time)
223
+ t (_type_): shape (batch_size)
224
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
225
+ cond (_type_, optional): placeholder for future use. Defaults to None.
226
+
227
+ Raises:
228
+ ValueError: _description_
229
+ ValueError: _description_
230
+
231
+ Returns:
232
+ _type_: _description_
233
+ """
234
+
235
+ t = self.time_embeddings(t).to(t.dtype)
236
+ t = self.time_mlp(t)
237
+
238
+ x = pack([x, mu], "b * t")[0]
239
+
240
+ if spks is not None:
241
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
242
+ x = pack([x, spks], "b * t")[0]
243
+ if cond is not None:
244
+ x = pack([x, cond], "b * t")[0]
245
+
246
+ hiddens = []
247
+ masks = [mask]
248
+ for resnet, transformer_blocks, downsample in self.down_blocks:
249
+ mask_down = masks[-1]
250
+ x = resnet(x, mask_down, t)
251
+ x = rearrange(x, "b c t -> b t c").contiguous()
252
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
253
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
254
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
255
+ for transformer_block in transformer_blocks:
256
+ x = transformer_block(
257
+ hidden_states=x,
258
+ attention_mask=attn_mask,
259
+ timestep=t,
260
+ )
261
+ x = rearrange(x, "b t c -> b c t").contiguous()
262
+ hiddens.append(x) # Save hidden states for skip connections
263
+ x = downsample(x * mask_down)
264
+ masks.append(mask_down[:, :, ::2])
265
+ masks = masks[:-1]
266
+ mask_mid = masks[-1]
267
+
268
+ for resnet, transformer_blocks in self.mid_blocks:
269
+ x = resnet(x, mask_mid, t)
270
+ x = rearrange(x, "b c t -> b t c").contiguous()
271
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
272
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
273
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
274
+ for transformer_block in transformer_blocks:
275
+ x = transformer_block(
276
+ hidden_states=x,
277
+ attention_mask=attn_mask,
278
+ timestep=t,
279
+ )
280
+ x = rearrange(x, "b t c -> b c t").contiguous()
281
+
282
+ for resnet, transformer_blocks, upsample in self.up_blocks:
283
+ mask_up = masks.pop()
284
+ skip = hiddens.pop()
285
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
286
+ x = resnet(x, mask_up, t)
287
+ x = rearrange(x, "b c t -> b t c").contiguous()
288
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
289
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
290
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
291
+ for transformer_block in transformer_blocks:
292
+ x = transformer_block(
293
+ hidden_states=x,
294
+ attention_mask=attn_mask,
295
+ timestep=t,
296
+ )
297
+ x = rearrange(x, "b t c -> b c t").contiguous()
298
+ x = upsample(x * mask_up)
299
+ x = self.final_block(x, mask_up)
300
+ output = self.final_proj(x * mask_up)
301
+ return output * mask
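`CausalConv1d` above keeps the convolution causal by padding only on the left with `kernel_size - 1` zeros instead of symmetric padding, so each output frame only sees current and past inputs. A small sanity-check sketch (assuming `third_party/` is on `PYTHONPATH` so the class imports as `cosyvoice.flow.decoder.CausalConv1d`; tensor shapes are illustrative):

```python
import torch
from cosyvoice.flow.decoder import CausalConv1d  # assumes third_party/ is on PYTHONPATH

conv = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3)
x = torch.randn(1, 4, 10)                  # (batch, channels, time)
y = conv(x)
print(y.shape)                             # torch.Size([1, 4, 10]) -- length preserved

# Perturbing future frames must not change earlier outputs:
x2 = x.clone()
x2[:, :, 7:] += 1.0
y2 = conv(x2)
print(torch.allclose(y[:, :, :7], y2[:, :, :7]))   # True
```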
third_party/cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
41
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
42
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+ self.output_size = output_size
46
+ self.decoder_conf = decoder_conf
47
+ self.mel_feat_conf = mel_feat_conf
48
+ self.vocab_size = vocab_size
49
+ self.output_type = output_type
50
+ self.input_frame_rate = input_frame_rate
51
+ logging.info(f"input frame rate={self.input_frame_rate}")
52
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
53
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
54
+ self.encoder = encoder
55
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
56
+ self.decoder = decoder
57
+ self.length_regulator = length_regulator
58
+ self.only_mask_loss = only_mask_loss
59
+
60
+ def forward(
61
+ self,
62
+ batch: dict,
63
+ device: torch.device,
64
+ ) -> Dict[str, Optional[torch.Tensor]]:
65
+ token = batch['speech_token'].to(device)
66
+ token_len = batch['speech_token_len'].to(device)
67
+ feat = batch['speech_feat'].to(device)
68
+ feat_len = batch['speech_feat_len'].to(device)
69
+ embedding = batch['embedding'].to(device)
70
+
71
+ # xvec projection
72
+ embedding = F.normalize(embedding, dim=1)
73
+ embedding = self.spk_embed_affine_layer(embedding)
74
+
75
+ # mask and embed speech tokens
76
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
77
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
78
+
79
+ # text encode
80
+ h, h_lengths = self.encoder(token, token_len)
81
+ h = self.encoder_proj(h)
82
+ h, h_lengths = self.length_regulator(h, feat_len)
83
+
84
+ # get conditions
85
+ conds = torch.zeros(feat.shape, device=token.device)
86
+ for i, j in enumerate(feat_len):
87
+ if random.random() < 0.5:
88
+ continue
89
+ index = random.randint(0, int(0.3 * j))
90
+ conds[i, :index] = feat[i, :index]
91
+ conds = conds.transpose(1, 2)
92
+
93
+ mask = (~make_pad_mask(feat_len)).to(h)
94
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
95
+ loss, _ = self.decoder.compute_loss(
96
+ feat.transpose(1, 2).contiguous(),
97
+ mask.unsqueeze(1),
98
+ h.transpose(1, 2).contiguous(),
99
+ embedding,
100
+ cond=conds
101
+ )
102
+ return {'loss': loss}
103
+
104
+ @torch.inference_mode()
105
+ def inference(self,
106
+ token,
107
+ token_len,
108
+ prompt_token,
109
+ prompt_token_len,
110
+ prompt_feat,
111
+ prompt_feat_len,
112
+ embedding,
113
+ flow_cache):
114
+ if self.fp16 is True:
115
+ prompt_feat = prompt_feat.half()
116
+ embedding = embedding.half()
117
+
118
+ assert token.shape[0] == 1
119
+ # xvec projection
120
+ embedding = F.normalize(embedding, dim=1)
121
+ embedding = self.spk_embed_affine_layer(embedding)
122
+
123
+ # concat text and prompt_text
124
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
125
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
126
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
127
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
128
+
129
+ # text encode
130
+ h, h_lengths = self.encoder(token, token_len)
131
+ h = self.encoder_proj(h)
132
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
133
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
134
+
135
+ # get conditions
136
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
137
+ conds[:, :mel_len1] = prompt_feat
138
+ conds = conds.transpose(1, 2)
139
+
140
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
141
+ feat, flow_cache = self.decoder(
142
+ mu=h.transpose(1, 2).contiguous(),
143
+ mask=mask.unsqueeze(1),
144
+ spks=embedding,
145
+ cond=conds,
146
+ n_timesteps=10,
147
+ prompt_len=mel_len1,
148
+ flow_cache=flow_cache
149
+ )
150
+ feat = feat[:, :, mel_len1:]
151
+ assert feat.shape[2] == mel_len2
152
+ return feat.float(), flow_cache
153
+
154
+
155
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
156
+ def __init__(self,
157
+ input_size: int = 512,
158
+ output_size: int = 80,
159
+ spk_embed_dim: int = 192,
160
+ output_type: str = "mel",
161
+ vocab_size: int = 4096,
162
+ input_frame_rate: int = 50,
163
+ only_mask_loss: bool = True,
164
+ token_mel_ratio: int = 2,
165
+ pre_lookahead_len: int = 3,
166
+ encoder: torch.nn.Module = None,
167
+ decoder: torch.nn.Module = None,
168
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
169
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
170
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
171
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
172
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
173
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
174
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
175
+ super().__init__()
176
+ self.input_size = input_size
177
+ self.output_size = output_size
178
+ self.decoder_conf = decoder_conf
179
+ self.mel_feat_conf = mel_feat_conf
180
+ self.vocab_size = vocab_size
181
+ self.output_type = output_type
182
+ self.input_frame_rate = input_frame_rate
183
+ logging.info(f"input frame rate={self.input_frame_rate}")
184
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
185
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
186
+ self.encoder = encoder
187
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
188
+ self.decoder = decoder
189
+ self.only_mask_loss = only_mask_loss
190
+ self.token_mel_ratio = token_mel_ratio
191
+ self.pre_lookahead_len = pre_lookahead_len
192
+
193
+ @torch.inference_mode()
194
+ def inference(self,
195
+ token,
196
+ token_len,
197
+ prompt_token,
198
+ prompt_token_len,
199
+ prompt_feat,
200
+ prompt_feat_len,
201
+ embedding,
202
+ finalize):
203
+ if self.fp16 is True:
204
+ prompt_feat = prompt_feat.half()
205
+ embedding = embedding.half()
206
+
207
+ assert token.shape[0] == 1
208
+ # xvec projection
209
+ embedding = F.normalize(embedding, dim=1)
210
+ embedding = self.spk_embed_affine_layer(embedding)
211
+
212
+ # concat text and prompt_text
213
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
214
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
215
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
216
+
217
+ # text encode
218
+ h, h_lengths = self.encoder(token, token_len)
219
+ if finalize is False:
220
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
221
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
222
+ h = self.encoder_proj(h)
223
+
224
+ # get conditions
225
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
226
+ conds[:, :mel_len1] = prompt_feat
227
+ conds = conds.transpose(1, 2)
228
+
229
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
230
+ feat, _ = self.decoder(
231
+ mu=h.transpose(1, 2).contiguous(),
232
+ mask=mask.unsqueeze(1),
233
+ spks=embedding,
234
+ cond=conds,
235
+ n_timesteps=10
236
+ )
237
+ feat = feat[:, :, mel_len1:]
238
+ assert feat.shape[2] == mel_len2
239
+ return feat.float(), None
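`MaskedDiffWithXvec.inference` above converts the number of generated speech tokens into a mel frame count using the token rate (50 Hz) and the vocoder frame rate (22050 / 256 Hz). A tiny arithmetic sketch with an assumed token count:

```python
# Illustrative numbers only.
input_frame_rate = 50               # speech-token rate (tokens per second)
sample_rate, hop_size = 22050, 256  # mel frame rate = sample_rate / hop_size

token_len2 = 150                    # assumed: 3 seconds of generated speech tokens
mel_len2 = int(token_len2 / input_frame_rate * sample_rate / hop_size)
print(mel_len2)                     # 258 mel frames to synthesize for the new text
```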
third_party/cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+
19
+
20
+ class ConditionalCFM(BASECFM):
21
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
22
+ super().__init__(
23
+ n_feats=in_channels,
24
+ cfm_params=cfm_params,
25
+ n_spks=n_spks,
26
+ spk_emb_dim=spk_emb_dim,
27
+ )
28
+ self.t_scheduler = cfm_params.t_scheduler
29
+ self.training_cfg_rate = cfm_params.training_cfg_rate
30
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
31
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
32
+ # Just change the architecture of the estimator here
33
+ self.estimator = estimator
34
+ self.lock = threading.Lock()
35
+
36
+ @torch.inference_mode()
37
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
38
+ """Forward diffusion
39
+
40
+ Args:
41
+ mu (torch.Tensor): output of encoder
42
+ shape: (batch_size, n_feats, mel_timesteps)
43
+ mask (torch.Tensor): output_mask
44
+ shape: (batch_size, 1, mel_timesteps)
45
+ n_timesteps (int): number of diffusion steps
46
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
47
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
48
+ shape: (batch_size, spk_emb_dim)
49
+ cond: Not used but kept for future purposes
50
+
51
+ Returns:
52
+ sample: generated mel-spectrogram
53
+ shape: (batch_size, n_feats, mel_timesteps)
54
+ """
55
+
56
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
57
+ cache_size = flow_cache.shape[2]
58
+ # keep the prompt and overlap parts of mu and z fixed from the cache
59
+ if cache_size != 0:
60
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
61
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
62
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
63
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
64
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
65
+
66
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
67
+ if self.t_scheduler == 'cosine':
68
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
69
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
70
+
71
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
72
+ """
73
+ Fixed-step Euler ODE solver.
74
+ Args:
75
+ x (torch.Tensor): random noise
76
+ t_span (torch.Tensor): n_timesteps interpolated
77
+ shape: (n_timesteps + 1,)
78
+ mu (torch.Tensor): output of encoder
79
+ shape: (batch_size, n_feats, mel_timesteps)
80
+ mask (torch.Tensor): output_mask
81
+ shape: (batch_size, 1, mel_timesteps)
82
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
83
+ shape: (batch_size, spk_emb_dim)
84
+ cond: Not used but kept for future purposes
85
+ """
86
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
87
+ t = t.unsqueeze(dim=0)
88
+
89
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
90
+ # Or, in the future, we might add a return_all_steps flag
91
+ sol = []
92
+
93
+ # Do not use concat here; it may change the memory format and make TensorRT inference produce wrong results!
94
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
95
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
96
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
97
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
98
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
99
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
100
+ for step in range(1, len(t_span)):
101
+ # Classifier-Free Guidance inference introduced in VoiceBox
102
+ x_in[:] = x
103
+ mask_in[:] = mask
104
+ mu_in[0] = mu
105
+ t_in[:] = t.unsqueeze(0)
106
+ spks_in[0] = spks
107
+ cond_in[0] = cond
108
+ dphi_dt = self.forward_estimator(
109
+ x_in, mask_in,
110
+ mu_in, t_in,
111
+ spks_in,
112
+ cond_in
113
+ )
114
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
115
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
116
+ x = x + dt * dphi_dt
117
+ t = t + dt
118
+ sol.append(x)
119
+ if step < len(t_span) - 1:
120
+ dt = t_span[step + 1] - t
121
+
122
+ return sol[-1].float()
123
+
124
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
125
+ if isinstance(self.estimator, torch.nn.Module):
126
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
127
+ else:
128
+ with self.lock:
129
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
130
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
131
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
132
+ self.estimator.set_input_shape('t', (2,))
133
+ self.estimator.set_input_shape('spks', (2, 80))
134
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
135
+ # run trt engine
136
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
137
+ mask.contiguous().data_ptr(),
138
+ mu.contiguous().data_ptr(),
139
+ t.contiguous().data_ptr(),
140
+ spks.contiguous().data_ptr(),
141
+ cond.contiguous().data_ptr(),
142
+ x.data_ptr()])
143
+ return x
144
+
145
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
146
+ """Computes diffusion loss
147
+
148
+ Args:
149
+ x1 (torch.Tensor): Target
150
+ shape: (batch_size, n_feats, mel_timesteps)
151
+ mask (torch.Tensor): target mask
152
+ shape: (batch_size, 1, mel_timesteps)
153
+ mu (torch.Tensor): output of encoder
154
+ shape: (batch_size, n_feats, mel_timesteps)
155
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
156
+ shape: (batch_size, spk_emb_dim)
157
+
158
+ Returns:
159
+ loss: conditional flow matching loss
160
+ y: conditional flow
161
+ shape: (batch_size, n_feats, mel_timesteps)
162
+ """
163
+ b, _, t = mu.shape
164
+
165
+ # random timestep
166
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
167
+ if self.t_scheduler == 'cosine':
168
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
169
+ # sample noise p(x_0)
170
+ z = torch.randn_like(x1)
171
+
172
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
173
+ u = x1 - (1 - self.sigma_min) * z
174
+
175
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
176
+ if self.training_cfg_rate > 0:
177
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
178
+ mu = mu * cfg_mask.view(-1, 1, 1)
179
+ spks = spks * cfg_mask.view(-1, 1)
180
+ cond = cond * cfg_mask.view(-1, 1, 1)
181
+
182
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
183
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
184
+ return loss, y
185
+
186
+
187
+ class CausalConditionalCFM(ConditionalCFM):
188
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
189
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
190
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
191
+
192
+ @torch.inference_mode()
193
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
194
+ """Forward diffusion
195
+
196
+ Args:
197
+ mu (torch.Tensor): output of encoder
198
+ shape: (batch_size, n_feats, mel_timesteps)
199
+ mask (torch.Tensor): output_mask
200
+ shape: (batch_size, 1, mel_timesteps)
201
+ n_timesteps (int): number of diffusion steps
202
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
203
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
204
+ shape: (batch_size, spk_emb_dim)
205
+ cond: Not used but kept for future purposes
206
+
207
+ Returns:
208
+ sample: generated mel-spectrogram
209
+ shape: (batch_size, n_feats, mel_timesteps)
210
+ """
211
+
212
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
213
+ # fix prompt and overlap part mu and z
214
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
215
+ if self.t_scheduler == 'cosine':
216
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
217
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
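`compute_loss` above builds the conditional flow-matching target: interpolate noise toward the data along a (cosine-warped) time `t` and regress the constant velocity. A minimal sketch of just that target construction (shapes and `sigma_min` are illustrative; `sigma_min` matches the default `cfm_params` shown earlier):

```python
import torch

sigma_min = 1e-6
x1 = torch.randn(2, 80, 100)                 # target mel-spectrograms
z = torch.randn_like(x1)                     # noise sample x_0
t = torch.rand(2, 1, 1)
t = 1 - torch.cos(t * 0.5 * torch.pi)        # cosine t-scheduler, as in the code

y = (1 - (1 - sigma_min) * t) * z + t * x1   # point on the conditional path
u = x1 - (1 - sigma_min) * z                 # regression target for the estimator
print(y.shape, u.shape)                      # torch.Size([2, 80, 100]) twice
```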
third_party/cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
+ class InterpolateRegulator(nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ sampling_ratios: Tuple,
26
+ out_channels: int = None,
27
+ groups: int = 1,
28
+ ):
29
+ super().__init__()
30
+ self.sampling_ratios = sampling_ratios
31
+ out_channels = out_channels or channels
32
+ model = nn.ModuleList([])
33
+ if len(sampling_ratios) > 0:
34
+ for _ in sampling_ratios:
35
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
36
+ norm = nn.GroupNorm(groups, channels)
37
+ act = nn.Mish()
38
+ model.extend([module, norm, act])
39
+ model.append(
40
+ nn.Conv1d(channels, out_channels, 1, 1)
41
+ )
42
+ self.model = nn.Sequential(*model)
43
+
44
+ def forward(self, x, ylens=None):
45
+ # x in (B, T, D)
46
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48
+ out = self.model(x).transpose(1, 2).contiguous()
49
+ olens = ylens
50
+ return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53
+ # in inference mode, interpolate the prompt token and the generated token (head/mid/tail) separately, so we get a clear separation point in the mel
54
+ # x in (B, T, D)
55
+ if x2.shape[1] > 40:
56
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
57
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
58
+ mode='linear')
59
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
60
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
61
+ else:
62
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
63
+ if x1.shape[1] != 0:
64
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
65
+ x = torch.concat([x1, x2], dim=2)
66
+ else:
67
+ x = x2
68
+ out = self.model(x).transpose(1, 2).contiguous()
69
+ return out, mel_len1 + mel_len2
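At its core, `InterpolateRegulator.forward` above stretches token-rate encoder features to the mel length with linear interpolation before the small conv stack refines them. A minimal sketch of that interpolation step (tensor sizes are assumed for illustration):

```python
import torch
import torch.nn.functional as F

h = torch.randn(1, 150, 80)        # (B, T_token, D) encoder output -- assumed shape
mel_len = 258                      # target number of mel frames
h_up = F.interpolate(h.transpose(1, 2).contiguous(), size=mel_len, mode='linear')
print(h_up.shape)                  # torch.Size([1, 80, 258])
```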