Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +11 -0
- 00000309-00000300.wav +3 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.gitattributes +38 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.msc +0 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.mv +1 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/README.md +119 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/__init__.py +0 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/added_tokens.json +3 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/config.json +39 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/generation_config.json +11 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/hf_rwkv_tokenizer.py +279 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/modeling_rwkv7.py +4 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/rwkv_vocab_v20230424.txt +0 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/special_tokens_map.json +6 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/tokenizer_config.json +28 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/README.md +14 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/asset/dingding.png +0 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/campplus.onnx +3 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/configuration.json +1 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/cosyvoice.yaml +116 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/flow.encoder.fp16.zip +3 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/hift.pt +3 -0
- CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/spk2info.pt +3 -0
- Inference.md +98 -0
- LICENSE +201 -0
- README.md +181 -3
- Trump.wav +3 -0
- _config.yml +3 -0
- another.wav +3 -0
- badXT_71.wav +3 -0
- data/cosy/data/data_processor.py +128 -0
- data/cosy/test/test_vq.py +171 -0
- data/utils/convert_embeddings_2_pt.py +34 -0
- data/utils/create_embeddings_from_raw.py +263 -0
- data/utils/create_lm_corpus_from_raw.py +156 -0
- data/utils/llm_dataset.py +206 -0
- data/utils/test_utilities.py +31 -0
- data/utils/utilitie.py +767 -0
- eval/eval_seed_generate.py +66 -0
- gradio/tts_demo_page.py +81 -0
- mine.wav +0 -0
- new.mp3 +0 -0
- new.wav +3 -0
- run_multiple_process.sh +137 -0
- rwkvtts_requirements.txt +264 -0
- third_party/cosyvoice/dataset/processor.py +435 -0
- third_party/cosyvoice/flow/decoder.py +301 -0
- third_party/cosyvoice/flow/flow.py +239 -0
- third_party/cosyvoice/flow/flow_matching.py +217 -0
- third_party/cosyvoice/flow/length_regulator.py +69 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Trump.wav filter=lfs diff=lfs merge=lfs -text
|
37 |
+
new.wav filter=lfs diff=lfs merge=lfs -text
|
38 |
+
badXT_71.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
zero_shot_prompt.wav filter=lfs diff=lfs merge=lfs -text
|
40 |
+
00000309-00000300.wav filter=lfs diff=lfs merge=lfs -text
|
41 |
+
another.wav filter=lfs diff=lfs merge=lfs -text
|
42 |
+
zero_2_0.wav filter=lfs diff=lfs merge=lfs -text
|
43 |
+
zero_shot_0.wav filter=lfs diff=lfs merge=lfs -text
|
44 |
+
zero_1_0.wav filter=lfs diff=lfs merge=lfs -text
|
45 |
+
zero_3_0.wav filter=lfs diff=lfs merge=lfs -text
|
46 |
+
zero_0_0.wav filter=lfs diff=lfs merge=lfs -text
|
00000309-00000300.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:631608f5c8b931ece1d45adc7f40a3b3b0ae2ec056a8a08a3565b04cc5750a4b
|
3 |
+
size 243244
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.gitattributes
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
flow.decoder.estimator.fp16.a10.plan filter=lfs diff=lfs merge=lfs -text
|
37 |
+
flow.decoder.estimator.fp16.l20.plan filter=lfs diff=lfs merge=lfs -text
|
38 |
+
flow.decoder.estimator.fp16.v100.plan filter=lfs diff=lfs merge=lfs -text
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.msc
ADDED
Binary file (1.66 kB). View file
|
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/.mv
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Revision:master,CreatedAt:1736490687
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/README.md
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
- zh
|
6 |
+
- ja
|
7 |
+
- ko
|
8 |
+
- fr
|
9 |
+
- ar
|
10 |
+
- es
|
11 |
+
- pt
|
12 |
+
metrics:
|
13 |
+
- accuracy
|
14 |
+
base_model:
|
15 |
+
- BlinkDL/rwkv-7-world
|
16 |
+
pipeline_tag: text-generation
|
17 |
+
---
|
18 |
+
|
19 |
+
# rwkv7-1.5B-world
|
20 |
+
|
21 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
22 |
+
|
23 |
+
This is RWKV-7 model under flash-linear attention format.
|
24 |
+
|
25 |
+
## Model Details
|
26 |
+
|
27 |
+
|
28 |
+
### Model Description
|
29 |
+
|
30 |
+
<!-- Provide a longer summary of what this model is. -->
|
31 |
+
|
32 |
+
- **Developed by:** Bo Peng, Yu Zhang, Songlin Yang, Ruichong Zhang
|
33 |
+
- **Funded by:** RWKV Project (Under LF AI & Data Foundation)
|
34 |
+
- **Model type:** RWKV7
|
35 |
+
- **Language(s) (NLP):** English
|
36 |
+
- **License:** Apache-2.0
|
37 |
+
- **Parameter count:** 1.52B
|
38 |
+
- **Tokenizer:** RWKV World tokenizer
|
39 |
+
- **Vocabulary size:** 65,536
|
40 |
+
|
41 |
+
### Model Sources
|
42 |
+
|
43 |
+
<!-- Provide the basic links for the model. -->
|
44 |
+
|
45 |
+
- **Repository:** https://github.com/fla-org/flash-linear-attention ; https://github.com/BlinkDL/RWKV-LM
|
46 |
+
- **Paper:** With in Progress
|
47 |
+
|
48 |
+
## Uses
|
49 |
+
|
50 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
51 |
+
Install `flash-linear-attention` and the latest version of `transformers` before using this model:
|
52 |
+
|
53 |
+
```bash
|
54 |
+
pip install git+https://github.com/fla-org/flash-linear-attention
|
55 |
+
pip install 'transformers>=4.48.0'
|
56 |
+
```
|
57 |
+
|
58 |
+
### Direct Use
|
59 |
+
|
60 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
61 |
+
You can use this model just as any other HuggingFace models:
|
62 |
+
```python
|
63 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
64 |
+
model = AutoModelForCausalLM.from_pretrained('fla-hub/rwkv7-1.5B-world', trust_remote_code=True)
|
65 |
+
tokenizer = AutoTokenizer.from_pretrained('fla-hub/rwkv7-1.5B-world', trust_remote_code=True)
|
66 |
+
|
67 |
+
model = model.cuda()
|
68 |
+
prompt = "What is a large language model?"
|
69 |
+
messages = [
|
70 |
+
{"role": "user", "content": "Who are you?"},
|
71 |
+
{"role": "assistant", "content": "I am a GPT-3 based model."},
|
72 |
+
{"role": "user", "content": prompt}
|
73 |
+
]
|
74 |
+
text = tokenizer.apply_chat_template(
|
75 |
+
messages,
|
76 |
+
tokenize=False,
|
77 |
+
add_generation_prompt=True
|
78 |
+
)
|
79 |
+
|
80 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
81 |
+
|
82 |
+
generated_ids = model.generate(
|
83 |
+
**model_inputs,
|
84 |
+
max_new_tokens=1024,
|
85 |
+
)
|
86 |
+
generated_ids = [
|
87 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
88 |
+
]
|
89 |
+
|
90 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
91 |
+
print(response)
|
92 |
+
```
|
93 |
+
|
94 |
+
## Training Details
|
95 |
+
|
96 |
+
### Training Data
|
97 |
+
|
98 |
+
This model is trained on the World v3 with a total of 3.119 trillion tokens.
|
99 |
+
|
100 |
+
#### Training Hyperparameters
|
101 |
+
|
102 |
+
- **Training regime:** bfloat16, lr 4e-4 to 1e-5 "delayed" cosine decay, wd 0.1 (with increasing batch sizes during the middle)
|
103 |
+
- **Final Loss:** 1.9965
|
104 |
+
- **Token Count:** 3.119 trillion
|
105 |
+
|
106 |
+
## Evaluation
|
107 |
+
|
108 |
+
#### Metrics
|
109 |
+
|
110 |
+
`lambada_openai`:
|
111 |
+
|
112 |
+
before conversion: ppl 4.13 acc 69.4%
|
113 |
+
|
114 |
+
after conversion: ppl 4.26 acc 68.8% (without apply temple)
|
115 |
+
|
116 |
+
## FAQ
|
117 |
+
Q: safetensors metadata is none.
|
118 |
+
|
119 |
+
A: upgrade transformers to >=4.48.0: `pip install 'transformers>=4.48.0'`
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/__init__.py
ADDED
File without changes
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/added_tokens.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<|rwkv_tokenizer_end_of_text|>": 0
|
3 |
+
}
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/config.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_attn_implementation_autoset": true,
|
3 |
+
"a_low_rank_dim": 96,
|
4 |
+
"architectures": [
|
5 |
+
"RWKV7ForCausalLM"
|
6 |
+
],
|
7 |
+
"attn": null,
|
8 |
+
"attn_mode": "fused_recurrent",
|
9 |
+
"auto_map": {
|
10 |
+
"AutoConfig": "modeling_rwkv7.RWKV7Config",
|
11 |
+
"AutoModel": "modeling_rwkv7.RWKV7Model",
|
12 |
+
"AutoModelForCausalLM": "modeling_rwkv7.RWKV7ForCausalLM"
|
13 |
+
},
|
14 |
+
"bos_token_id": 1,
|
15 |
+
"decay_low_rank_dim": 96,
|
16 |
+
"eos_token_id": 2,
|
17 |
+
"fuse_cross_entropy": true,
|
18 |
+
"fuse_norm": false,
|
19 |
+
"gate_low_rank_dim": 256,
|
20 |
+
"head_dim": 64,
|
21 |
+
"hidden_act": "sqrelu",
|
22 |
+
"hidden_ratio": 4.0,
|
23 |
+
"hidden_size": 2048,
|
24 |
+
"initializer_range": 0.02,
|
25 |
+
"intermediate_size": 8192,
|
26 |
+
"max_position_embeddings": 2048,
|
27 |
+
"model_type": "rwkv7",
|
28 |
+
"norm_bias": true,
|
29 |
+
"norm_eps": 1e-05,
|
30 |
+
"norm_first": true,
|
31 |
+
"num_heads": null,
|
32 |
+
"num_hidden_layers": 24,
|
33 |
+
"tie_word_embeddings": false,
|
34 |
+
"torch_dtype": "float32",
|
35 |
+
"transformers_version": "4.48.1",
|
36 |
+
"use_cache": true,
|
37 |
+
"v_low_rank_dim": 64,
|
38 |
+
"vocab_size": 65536
|
39 |
+
}
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/generation_config.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 0,
|
3 |
+
"eos_token_id": 0,
|
4 |
+
"pad_token_id": 0,
|
5 |
+
"max_window_size": 2147483647,
|
6 |
+
"do_sample": true,
|
7 |
+
"top_k": 65536,
|
8 |
+
"top_p": 1.0,
|
9 |
+
"temperature": 1.0,
|
10 |
+
"transformers_version": "4.48.0"
|
11 |
+
}
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/hf_rwkv_tokenizer.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for RWKV."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
20 |
+
|
21 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
22 |
+
from transformers.utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
pass
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
VOCAB_FILES_NAMES = {
|
32 |
+
"vocab_file": "rwkv_vocab_v20230424.txt",
|
33 |
+
}
|
34 |
+
|
35 |
+
class TRIE:
|
36 |
+
__slots__ = tuple("ch,to,values,front".split(","))
|
37 |
+
to: list
|
38 |
+
values: set
|
39 |
+
|
40 |
+
def __init__(self, front=None, ch=None):
|
41 |
+
self.ch = ch
|
42 |
+
self.to = [None for ch in range(256)]
|
43 |
+
self.values = set()
|
44 |
+
self.front = front
|
45 |
+
|
46 |
+
def __repr__(self):
|
47 |
+
fr = self
|
48 |
+
ret = []
|
49 |
+
while fr != None:
|
50 |
+
if fr.ch != None:
|
51 |
+
ret.append(fr.ch)
|
52 |
+
fr = fr.front
|
53 |
+
return "<TRIE %s %s>" % (ret[::-1], self.values)
|
54 |
+
|
55 |
+
def add(self, key: bytes, idx: int = 0, val=None):
|
56 |
+
if idx == len(key):
|
57 |
+
if val is None:
|
58 |
+
val = key
|
59 |
+
self.values.add(val)
|
60 |
+
return self
|
61 |
+
ch = key[idx]
|
62 |
+
if self.to[ch] is None:
|
63 |
+
self.to[ch] = TRIE(front=self, ch=ch)
|
64 |
+
return self.to[ch].add(key, idx=idx + 1, val=val)
|
65 |
+
|
66 |
+
def find_longest(self, key: bytes, idx: int = 0):
|
67 |
+
u: TRIE = self
|
68 |
+
ch: int = key[idx]
|
69 |
+
|
70 |
+
while u.to[ch] is not None:
|
71 |
+
u = u.to[ch]
|
72 |
+
idx += 1
|
73 |
+
if u.values:
|
74 |
+
ret = idx, u, u.values
|
75 |
+
if idx == len(key):
|
76 |
+
break
|
77 |
+
ch = key[idx]
|
78 |
+
return ret
|
79 |
+
|
80 |
+
|
81 |
+
class RWKV_TOKENIZER:
|
82 |
+
def __init__(self, file_name):
|
83 |
+
self.idx2token = {}
|
84 |
+
sorted = [] # must be already sorted
|
85 |
+
with open(file_name, "r", encoding="utf-8") as f:
|
86 |
+
lines = f.readlines()
|
87 |
+
for l in lines:
|
88 |
+
idx = int(l[: l.index(" ")])
|
89 |
+
x = eval(l[l.index(" ") : l.rindex(" ")])
|
90 |
+
x = x.encode("utf-8") if isinstance(x, str) else x
|
91 |
+
assert isinstance(x, bytes)
|
92 |
+
|
93 |
+
assert len(x) == int(l[l.rindex(" ") :])
|
94 |
+
sorted += [x]
|
95 |
+
self.idx2token[idx] = x
|
96 |
+
|
97 |
+
self.token2idx = {}
|
98 |
+
for k, v in self.idx2token.items():
|
99 |
+
self.token2idx[v] = int(k)
|
100 |
+
|
101 |
+
self.root = TRIE()
|
102 |
+
for t, i in self.token2idx.items():
|
103 |
+
_ = self.root.add(t, val=(t, i))
|
104 |
+
|
105 |
+
def encodeBytes(self, src: bytes):
|
106 |
+
idx: int = 0
|
107 |
+
tokens = []
|
108 |
+
while idx < len(src):
|
109 |
+
_idx: int = idx
|
110 |
+
idx, _, values = self.root.find_longest(src, idx)
|
111 |
+
assert idx != _idx
|
112 |
+
_, token = next(iter(values))
|
113 |
+
tokens.append(token)
|
114 |
+
return tokens
|
115 |
+
|
116 |
+
def decodeBytes(self, tokens):
|
117 |
+
return b"".join(map(lambda i: self.idx2token[i], tokens))
|
118 |
+
|
119 |
+
def encode(self, src):
|
120 |
+
if isinstance(src, str):
|
121 |
+
return [self.encodeBytes(src.encode("utf-8"))]
|
122 |
+
elif isinstance(src, list):
|
123 |
+
return [self.encodeBytes(s.encode("utf-8")) for s in src]
|
124 |
+
|
125 |
+
def decode(self, tokens):
|
126 |
+
return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
|
127 |
+
# try:
|
128 |
+
# return self.decodeBytes(tokens).decode('utf-8')
|
129 |
+
# except:
|
130 |
+
# return '\ufffd' # bad utf-8
|
131 |
+
|
132 |
+
def printTokens(self, tokens):
|
133 |
+
for i in tokens:
|
134 |
+
s = self.idx2token[i]
|
135 |
+
try:
|
136 |
+
s = s.decode("utf-8")
|
137 |
+
except:
|
138 |
+
pass
|
139 |
+
print(f"{repr(s)}{i}", end=" ")
|
140 |
+
print()
|
141 |
+
|
142 |
+
|
143 |
+
class RwkvTokenizer(PreTrainedTokenizer):
|
144 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
145 |
+
model_input_names = ["input_ids", "attention_mask"]
|
146 |
+
|
147 |
+
def __init__(
|
148 |
+
self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
|
149 |
+
):
|
150 |
+
if not os.path.isfile(vocab_file):
|
151 |
+
raise ValueError(
|
152 |
+
f"Can't find a vocabulary file at path '{vocab_file}'."
|
153 |
+
)
|
154 |
+
|
155 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
156 |
+
tokens = reader.readlines()
|
157 |
+
|
158 |
+
if "add_bos_token" in kwargs:
|
159 |
+
self.add_bos_token = kwargs["add_bos_token"]
|
160 |
+
else:
|
161 |
+
self.add_bos_token = False
|
162 |
+
self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
|
163 |
+
vocab = self.trie_tokenizer.token2idx
|
164 |
+
self.encoder = vocab
|
165 |
+
self.decoder = {v: k for k, v in vocab.items()}
|
166 |
+
self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
|
167 |
+
super().__init__(
|
168 |
+
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
|
169 |
+
)
|
170 |
+
|
171 |
+
@property
|
172 |
+
def vocab_size(self):
|
173 |
+
return len(self.encoder)
|
174 |
+
|
175 |
+
def get_vocab(self):
|
176 |
+
vocab = self.encoder
|
177 |
+
vocab.update(self.added_tokens_encoder)
|
178 |
+
vocab = dict(sorted(vocab.items(), key=lambda item: item[1]))
|
179 |
+
return vocab
|
180 |
+
|
181 |
+
def _tokenize(self, text, split_special_tokens=False):
|
182 |
+
# return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
|
183 |
+
return self.trie_tokenizer.encode(text)[0]
|
184 |
+
|
185 |
+
def _convert_token_to_id(self, token):
|
186 |
+
return token
|
187 |
+
|
188 |
+
def _convert_id_to_token(self, index):
|
189 |
+
"""Converts an index (integer) in a token (byte) using the vocab."""
|
190 |
+
token = self.decoder.get(index, self.unk_token)
|
191 |
+
if isinstance(token, (bytes)):
|
192 |
+
token = token.decode("utf-8", errors="replace")
|
193 |
+
return token
|
194 |
+
|
195 |
+
def convert_tokens_to_string(self, tokens):
|
196 |
+
"""Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
|
197 |
+
out_string = b"".join(
|
198 |
+
[k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
|
199 |
+
).decode("utf-8")
|
200 |
+
return out_string
|
201 |
+
|
202 |
+
def save_vocabulary(
|
203 |
+
self, save_directory: str, filename_prefix: Optional[str] = None
|
204 |
+
) -> Tuple[str]:
|
205 |
+
index = 0
|
206 |
+
if os.path.isdir(save_directory):
|
207 |
+
vocab_file = os.path.join(
|
208 |
+
save_directory,
|
209 |
+
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
vocab_file = (
|
213 |
+
filename_prefix + "-" if filename_prefix else ""
|
214 |
+
) + save_directory
|
215 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
216 |
+
for token, token_index in sorted(
|
217 |
+
self.encoder.items(), key=lambda kv: kv[1]
|
218 |
+
):
|
219 |
+
if index != token_index:
|
220 |
+
logger.warning(
|
221 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
222 |
+
" Please check that the vocabulary is not corrupted!"
|
223 |
+
)
|
224 |
+
index = token_index
|
225 |
+
writer.write(str(token) + "\n")
|
226 |
+
index += 1
|
227 |
+
return (vocab_file,)
|
228 |
+
|
229 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
230 |
+
if self.add_bos_token:
|
231 |
+
bos_token_ids = [self.bos_token_id]
|
232 |
+
else:
|
233 |
+
bos_token_ids = []
|
234 |
+
|
235 |
+
output = bos_token_ids + token_ids_0
|
236 |
+
|
237 |
+
if token_ids_1 is None:
|
238 |
+
return output
|
239 |
+
|
240 |
+
return output + bos_token_ids + token_ids_1
|
241 |
+
|
242 |
+
def get_special_tokens_mask(
|
243 |
+
self,
|
244 |
+
token_ids_0: List[int],
|
245 |
+
token_ids_1: Optional[List[int]] = None,
|
246 |
+
already_has_special_tokens: bool = False,
|
247 |
+
) -> List[int]:
|
248 |
+
"""
|
249 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
250 |
+
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
token_ids_0 (`List[int]`):
|
254 |
+
List of IDs.
|
255 |
+
token_ids_1 (`List[int]`, *optional*):
|
256 |
+
Optional second list of IDs for sequence pairs.
|
257 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
258 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
262 |
+
"""
|
263 |
+
if already_has_special_tokens:
|
264 |
+
return super().get_special_tokens_mask(
|
265 |
+
token_ids_0=token_ids_0,
|
266 |
+
token_ids_1=token_ids_1,
|
267 |
+
already_has_special_tokens=True,
|
268 |
+
)
|
269 |
+
|
270 |
+
if not self.add_bos_token:
|
271 |
+
return super().get_special_tokens_mask(
|
272 |
+
token_ids_0=token_ids_0,
|
273 |
+
token_ids_1=token_ids_1,
|
274 |
+
already_has_special_tokens=False,
|
275 |
+
)
|
276 |
+
|
277 |
+
if token_ids_1 is None:
|
278 |
+
return [1] + ([0] * len(token_ids_0))
|
279 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/modeling_rwkv7.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rwkvfla.models.rwkv7 import RWKV7ForCausalLM, RWKV7Model, RWKV7Config
|
2 |
+
RWKV7ForCausalLM = RWKV7ForCausalLM
|
3 |
+
RWKV7Model = RWKV7Model
|
4 |
+
RWKV7Config = RWKV7Config
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/rwkv_vocab_v20230424.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/special_tokens_map.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<|rwkv_tokenizer_end_of_text|>",
|
3 |
+
"eos_token": "\n\n",
|
4 |
+
"unk_token": "<|rwkv_tokenizer_end_of_text|>",
|
5 |
+
"pad_token": "<|rwkv_tokenizer_end_of_text|>"
|
6 |
+
}
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/CosyVoice-BlankEN/tokenizer_config.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"added_tokens_decoder": {
|
4 |
+
"0": {
|
5 |
+
"content": "<|rwkv_tokenizer_end_of_text|>",
|
6 |
+
"lstrip": false,
|
7 |
+
"normalized": false,
|
8 |
+
"rstrip": false,
|
9 |
+
"single_word": false,
|
10 |
+
"special": true
|
11 |
+
}
|
12 |
+
},
|
13 |
+
"auto_map": {
|
14 |
+
"AutoTokenizer": [
|
15 |
+
"hf_rwkv_tokenizer.RwkvTokenizer",
|
16 |
+
null
|
17 |
+
]
|
18 |
+
},
|
19 |
+
"bos_token": "<|rwkv_tokenizer_end_of_text|>",
|
20 |
+
"pad_token": "<|rwkv_tokenizer_end_of_text|>",
|
21 |
+
"clean_up_tokenization_spaces": false,
|
22 |
+
"eos_token": "\n\n",
|
23 |
+
"model_max_length": 1000000000000000019884624838656,
|
24 |
+
"tokenizer_class": "RwkvTokenizer",
|
25 |
+
"unk_token": "<|rwkv_tokenizer_end_of_text|>",
|
26 |
+
"use_fast": false,
|
27 |
+
"chat_template": "{{ '<|rwkv_tokenizer_end_of_text|>' }}{% for message in messages %}{% if message['role'] == 'user' %}{{'User: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'system' %}{{'System: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'Assistant: ' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
28 |
+
}
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- zh
|
4 |
+
- en
|
5 |
+
- ko
|
6 |
+
- ja
|
7 |
+
base_model:
|
8 |
+
- fla-hub/rwkv7-1.5B-world
|
9 |
+
pipeline_tag: text-to-speech
|
10 |
+
---
|
11 |
+
This is TTS model combined with Cosy's FSQ and RWKV Language model.
|
12 |
+
Please refer :
|
13 |
+
https://github.com/yynil/RWKVTTS/blob/main/Inference.md
|
14 |
+
to use this checkpoint.
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/asset/dingding.png
ADDED
![]() |
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/campplus.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
|
3 |
+
size 28303423
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/configuration.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"framework":"Pytorch","task":"text-to-speech"}
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/cosyvoice.yaml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# set random seed, so that you may reproduce your result.
|
2 |
+
__set_seed1: !apply:random.seed [1986]
|
3 |
+
__set_seed2: !apply:numpy.random.seed [1986]
|
4 |
+
__set_seed3: !apply:torch.manual_seed [1986]
|
5 |
+
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
6 |
+
|
7 |
+
# fixed params
|
8 |
+
sample_rate: 24000
|
9 |
+
llm_input_size: 2048
|
10 |
+
llm_output_size: 2048
|
11 |
+
spk_embed_dim: 192
|
12 |
+
qwen_pretrain_path: ''
|
13 |
+
|
14 |
+
# model params
|
15 |
+
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
16 |
+
# for system/third_party class/function, we do not require this.
|
17 |
+
llm: !new:model.llm.llm.RWKV7LM
|
18 |
+
llm_input_size: !ref <llm_input_size>
|
19 |
+
llm_output_size: !ref <llm_output_size>
|
20 |
+
speech_token_size: 6561
|
21 |
+
length_normalized_loss: True
|
22 |
+
lsm_weight: 0
|
23 |
+
vocab_size: 65548
|
24 |
+
llm: !ref <qwen_pretrain_path>
|
25 |
+
sampling: !name:cosyvoice.utils.common.ras_sampling
|
26 |
+
top_p: 0.8
|
27 |
+
top_k: 25
|
28 |
+
win_size: 10
|
29 |
+
tau_r: 0.1
|
30 |
+
|
31 |
+
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
|
32 |
+
input_size: 512
|
33 |
+
output_size: 80
|
34 |
+
spk_embed_dim: !ref <spk_embed_dim>
|
35 |
+
output_type: 'mel'
|
36 |
+
vocab_size: 6561
|
37 |
+
input_frame_rate: 25
|
38 |
+
only_mask_loss: True
|
39 |
+
token_mel_ratio: 2
|
40 |
+
pre_lookahead_len: 3
|
41 |
+
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
|
42 |
+
output_size: 512
|
43 |
+
attention_heads: 8
|
44 |
+
linear_units: 2048
|
45 |
+
num_blocks: 6
|
46 |
+
dropout_rate: 0.1
|
47 |
+
positional_dropout_rate: 0.1
|
48 |
+
attention_dropout_rate: 0.1
|
49 |
+
normalize_before: True
|
50 |
+
input_layer: 'linear'
|
51 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
52 |
+
selfattention_layer_type: 'rel_selfattn'
|
53 |
+
input_size: 512
|
54 |
+
use_cnn_module: False
|
55 |
+
macaron_style: False
|
56 |
+
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
|
57 |
+
in_channels: 240
|
58 |
+
n_spks: 1
|
59 |
+
spk_emb_dim: 80
|
60 |
+
cfm_params: !new:omegaconf.DictConfig
|
61 |
+
content:
|
62 |
+
sigma_min: 1e-06
|
63 |
+
solver: 'euler'
|
64 |
+
t_scheduler: 'cosine'
|
65 |
+
training_cfg_rate: 0.2
|
66 |
+
inference_cfg_rate: 0.7
|
67 |
+
reg_loss_type: 'l1'
|
68 |
+
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
69 |
+
in_channels: 320
|
70 |
+
out_channels: 80
|
71 |
+
causal: True
|
72 |
+
channels: [256]
|
73 |
+
dropout: 0.0
|
74 |
+
attention_head_dim: 64
|
75 |
+
n_blocks: 4
|
76 |
+
num_mid_blocks: 12
|
77 |
+
num_heads: 8
|
78 |
+
act_fn: 'gelu'
|
79 |
+
|
80 |
+
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
81 |
+
in_channels: 80
|
82 |
+
base_channels: 512
|
83 |
+
nb_harmonics: 8
|
84 |
+
sampling_rate: !ref <sample_rate>
|
85 |
+
nsf_alpha: 0.1
|
86 |
+
nsf_sigma: 0.003
|
87 |
+
nsf_voiced_threshold: 10
|
88 |
+
upsample_rates: [8, 5, 3]
|
89 |
+
upsample_kernel_sizes: [16, 11, 7]
|
90 |
+
istft_params:
|
91 |
+
n_fft: 16
|
92 |
+
hop_len: 4
|
93 |
+
resblock_kernel_sizes: [3, 7, 11]
|
94 |
+
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
95 |
+
source_resblock_kernel_sizes: [7, 7, 11]
|
96 |
+
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
97 |
+
lrelu_slope: 0.1
|
98 |
+
audio_limit: 0.99
|
99 |
+
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
100 |
+
num_class: 1
|
101 |
+
in_channels: 80
|
102 |
+
cond_channels: 512
|
103 |
+
|
104 |
+
# processor functions
|
105 |
+
get_tokenizer: !name:utils.utilities.get_tokenizer
|
106 |
+
model_dir: !ref <qwen_pretrain_path>
|
107 |
+
allowed_special: 'all'
|
108 |
+
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
109 |
+
n_fft: 1920
|
110 |
+
num_mels: 80
|
111 |
+
sampling_rate: !ref <sample_rate>
|
112 |
+
hop_size: 480
|
113 |
+
win_size: 1920
|
114 |
+
fmin: 0
|
115 |
+
fmax: 8000
|
116 |
+
center: False
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/flow.encoder.fp16.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46d2539ad8bdb90026cd50cb42e45bd389f10108111d742b912feddca105aeb6
|
3 |
+
size 116703414
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/hift.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d4af0d661a416c69544eec83ff9c070dc80c37ee53ef44af3a37d910c95bc21
|
3 |
+
size 83364158
|
CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO/spk2info.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fbc8f9064db35ee8163b538c0f6ed9fe0c3e2fe0f560cca910e578138d961285
|
3 |
+
size 3281245
|
Inference.md
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Install the code base and the dependencies
|
2 |
+
```bash
|
3 |
+
git clone https://github.com/yynil/RWKVTTS
|
4 |
+
```
|
5 |
+
Add these two directories to the PYTHONPATH
|
6 |
+
```bash
|
7 |
+
export PYTHONPATH=$PYTHONPATH:/home/user/RWKVTTS:/home/user/RWKVTTS/third_party
|
8 |
+
```
|
9 |
+
# Install the dependencies
|
10 |
+
```bash
|
11 |
+
conda create -n rwkvtts-311 -y python=3.11
|
12 |
+
conda activate rwkvtts-311
|
13 |
+
conda install -y -c conda-forge pynini==2.1.6
|
14 |
+
cd RWKVTTS
|
15 |
+
pip install -r rwkvtts_requirements.txt
|
16 |
+
```
|
17 |
+
|
18 |
+
Download the pretrained models from the following links:
|
19 |
+
https://huggingface.co/yueyulin/CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO
|
20 |
+
|
21 |
+
Place the CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO to local directory. Let's say /home/user/CosyVoice2-0.5B-RWKV-7-1.5B-Instruct-CHENJPKO
|
22 |
+
|
23 |
+
Add two directories to the PYTHONPATH
|
24 |
+
|
25 |
+
The example code for inference is as follows:
|
26 |
+
```python
|
27 |
+
def do_tts(tts_text,prompt_texts,cosyvoice):
|
28 |
+
import logging
|
29 |
+
for i, (prompt_audio_file, prompt_text) in enumerate(zip(prompt_audios, prompt_texts)):
|
30 |
+
logging.info(f'Processing {prompt_text}')
|
31 |
+
prompt_speech_16k = load_wav(prompt_audio_file, 16000)
|
32 |
+
with torch.no_grad():
|
33 |
+
if prompt_text is not None:
|
34 |
+
for j, k in enumerate(cosyvoice.inference_zero_shot(tts_text,prompt_text, prompt_speech_16k, stream=False,speed=1)):
|
35 |
+
torchaudio.save('zero_{}_{}.wav'.format(i, j), k['tts_speech'], cosyvoice.sample_rate)
|
36 |
+
else:
|
37 |
+
for j, k in enumerate(cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=False,speed=1)):
|
38 |
+
torchaudio.save('zero_{}_{}.wav'.format(i, j), k['tts_speech'], cosyvoice.sample_rate)
|
39 |
+
logging.info(f'Finished processing {prompt_text}')
|
40 |
+
if __name__ == '__main__':
|
41 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
42 |
+
import torch
|
43 |
+
import sys
|
44 |
+
# model_path = '/home/yueyulin/models/CosyVoice2-0.5B_RWKV_0.19B/'
|
45 |
+
# device = 'cuda:0'
|
46 |
+
print(sys.argv)
|
47 |
+
model_path = sys.argv[1]
|
48 |
+
device = sys.argv[2] if len(sys.argv) > 2 else 'cuda:0'
|
49 |
+
is_flow_only = sys.argv[3]=='True' if len(sys.argv) > 3 else False
|
50 |
+
print(f'is_flow_only: {is_flow_only}')
|
51 |
+
cosyvoice = CosyVoice2(model_path,device=device,fp16=False,load_jit=False)
|
52 |
+
|
53 |
+
from cosyvoice.utils.file_utils import load_wav
|
54 |
+
import torchaudio
|
55 |
+
prompt_audios = [
|
56 |
+
'/home/yueyulin/github/RWKVTTS/zero_shot_prompt.wav',
|
57 |
+
'/home/yueyulin/github/RWKVTTS/mine.wav',
|
58 |
+
'/home/yueyulin/github/RWKVTTS/new.wav',
|
59 |
+
'/home/yueyulin/github/RWKVTTS/Trump.wav',
|
60 |
+
]
|
61 |
+
|
62 |
+
if not is_flow_only:
|
63 |
+
prompt_texts = [
|
64 |
+
'希望你以后做的比我还好呦。',
|
65 |
+
'少年强则中国强。',
|
66 |
+
'我随便说一句话,我喊开始录就开始录。',
|
67 |
+
'numbers of Latino, African American, Asian American and native American voters.'
|
68 |
+
]
|
69 |
+
else:
|
70 |
+
prompt_texts = [
|
71 |
+
None,
|
72 |
+
None,
|
73 |
+
None,
|
74 |
+
None
|
75 |
+
]
|
76 |
+
do_tts('Make America great again!',prompt_texts,cosyvoice)
|
77 |
+
```
|
78 |
+
More examples can be found in the model/test directory.
|
79 |
+
|
80 |
+
[Instruct example](model/test/test_instructed.py) is an example to use the instructed voice flow to generate the audio.
|
81 |
+
[Embedded ref voice example](model/test/test_speaker_adapter.py) is an example to use the speaker adapter to generate the audio.
|
82 |
+
|
83 |
+
Please refer the [Service Call URL](service/README.md) for the instructions and reference voices.
|
84 |
+
|
85 |
+
If you pass the prompt_texts as None, the engine will only clone the voice flow and texture which is good to clone voice cross lingual. If you pass the correct prompt texts to the engine, the engine will try to continue to finish the audio tokens following the prompt audio you provided. This will be good to continue the audio you provided but it will be weird when you try to mix languages.
|
86 |
+
|
87 |
+
The test source code is [test code](model/test/test_initialize.py).
|
88 |
+
|
89 |
+
Please change the paths to the correct paths in your system.
|
90 |
+
|
91 |
+
You can also use your own prompt audio and text. Since the llm module is to finish your audio tokens for you, so please make sure the audio is clean,complete and the text is correct. Otherwise, the result may not be good.
|
92 |
+
|
93 |
+
The following table shows the example results of the above code:
|
94 |
+
| Prompt Audio | Prompt Text | TTS Text | Result |
|
95 |
+
| --- | --- | --- | --- |
|
96 |
+
| https://github.com/yynil/RWKVTTS/raw/main/zero_shot_prompt.wav | 希望你以后做的比我还好呦。 | 中国在东亚,是世界上最大的国家,也是世界上人口最多的国家。 | https://github.com/yynil/RWKVTTS/raw/main/zero_0_0.wav |
|
97 |
+
| https://github.com/yynil/RWKVTTS/raw/main/mine.wav| 少年强则中国强。 | 中国在东亚,是世界上最大的国家,也是世界上人口最多的国家。 | https://github.com/yynil/RWKVTTS/raw/main/zero_1_0.wav |
|
98 |
+
| https://github.com/yynil/RWKVTTS/raw/main/new.wav | 我随便说一句话,我喊开始录就开始录。 | 中国在东亚,是世界上最大的国家,也是世界上人口最���的国家。 | https://github.com/yynil/RWKVTTS/raw/main/zero_2_0.wav |
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,181 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RWKVTTS
|
2 |
+
This project is to train an RWKV LLM for TTS generation which compatible to other TTS engine(like fish/cosy/chattts).
|
3 |
+
|
4 |
+
For most of modern LLM based TTS engine, there are two parts :
|
5 |
+
1. VQ VAE: this model is to encode audio to audio tokens and decode audio tokens to audio.
|
6 |
+
2. LLM: this model is to generate audio tokens using text tokens and the prompt audio tokens. The prompt audio tokens are also from VQ VAE.
|
7 |
+
|
8 |
+
Typically the training of the LLM based TTS involves VQ-VAE training and LLM training, like CosyTTS, ChatTTS and FishTTS. However we focus to train an RWKV LLM model to replace the LLM part in these TTS engines.
|
9 |
+
|
10 |
+
```mermaid
|
11 |
+
flowchart TB
|
12 |
+
node_1[["Input Prompt Text"]]
|
13 |
+
node_2(["Text Tokenizer"])
|
14 |
+
node_3(["Audio Tokenizer(VQ)"])
|
15 |
+
node_4[["Input Reference Audio"]]
|
16 |
+
node_5[["Text Tokens"]]
|
17 |
+
node_6[["Audio Tokens"]]
|
18 |
+
node_7(["Text Embedder"])
|
19 |
+
node_8(["Audio Embedder"])
|
20 |
+
node_9[["Text Embeddings"]]
|
21 |
+
node_10[["Audio Embeddings"]]
|
22 |
+
node_11(["Concatenate Embeddings"])
|
23 |
+
node_12[["Input Embeddings"]]
|
24 |
+
node_13{{"Language Model"}}
|
25 |
+
node_14[["Hidden States"]]
|
26 |
+
node_15(["Audio Head"])
|
27 |
+
node_16{"Continue to decode?"}
|
28 |
+
node_17(["Next Step Input"])
|
29 |
+
node_18(["Finish Decode"])
|
30 |
+
node_1 --> node_2
|
31 |
+
node_4 --> node_3
|
32 |
+
node_2 --> node_5
|
33 |
+
node_3 --> node_6
|
34 |
+
node_5 --> node_7
|
35 |
+
node_6 --> node_8
|
36 |
+
node_7 --> node_9
|
37 |
+
node_8 --> node_10
|
38 |
+
node_9 --> node_11
|
39 |
+
node_10 --> node_11
|
40 |
+
node_11 --> node_12
|
41 |
+
node_12 --> node_13
|
42 |
+
node_13 --> node_14
|
43 |
+
node_14 --> node_15
|
44 |
+
node_15 --> node_16
|
45 |
+
node_16 --"Yes"--> node_17
|
46 |
+
node_17 --> node_13
|
47 |
+
node_16 --"No"--> node_18
|
48 |
+
```
|
49 |
+
|
50 |
+
Different TTS engines might have different data layout and different special control token, so we need to prepare different data and train a RWKV LLM model for each TTS engine.
|
51 |
+
|
52 |
+
# Process to train LLM for different TTS engine
|
53 |
+
|
54 |
+
## Cosy 2.0
|
55 |
+
|
56 |
+
### Cosy 2.0 Data Layout
|
57 |
+
|
58 |
+
The layout of Cosy 2.0 for LLM:
|
59 |
+
|
60 |
+
```mermaid
|
61 |
+
|
62 |
+
flowchart LR
|
63 |
+
node_1[["SOS Embeddings"]]
|
64 |
+
node_2[["Text Embeddings"]]
|
65 |
+
node_3[["Task ID Embedings"]]
|
66 |
+
node_4[["Audio Embeddings"]]
|
67 |
+
node_5[["Last Audio Embeddings"]]
|
68 |
+
node_1 --- node_2
|
69 |
+
node_2 --- node_3
|
70 |
+
node_3 --> node_4
|
71 |
+
node_4 --> node_5
|
72 |
+
|
73 |
+
```
|
74 |
+
|
75 |
+
The forward of LLM for cosy 2.0:
|
76 |
+
```mermaid
|
77 |
+
graph TD
|
78 |
+
A[Input: batch] --> B[Extract tokens and lengths]
|
79 |
+
B --> C1[Prepare LLM Target]
|
80 |
+
B --> C2[Encode Text Tokens]
|
81 |
+
B --> C3[Generate SOS/EOS and Task ID Embeddings]
|
82 |
+
B --> C4[Encode Speech Tokens]
|
83 |
+
|
84 |
+
C1[Prepare LLM Target] --> D1["Create target sequence for each sample<br>[IGNORE_ID, ..., speech_tokens, EOS]"]
|
85 |
+
D1 --> D2[Pad and move target to device]
|
86 |
+
|
87 |
+
C2[Encode Text Tokens] --> E1[Apply text_embedding layer]
|
88 |
+
|
89 |
+
C3[Generate SOS/EOS and Task ID Embeddings] --> F1[Get SOS/EOS embeddings from llm_embedding]
|
90 |
+
C3 --> F2[Get task_id embeddings from llm_embedding]
|
91 |
+
|
92 |
+
C4[Encode Speech Tokens] --> G1[Apply speech_embedding layer]
|
93 |
+
|
94 |
+
E1 --> H[Unpad and pad sequence]
|
95 |
+
F1 --> H
|
96 |
+
F2 --> H
|
97 |
+
G1 --> H
|
98 |
+
|
99 |
+
H --> I1[Generate LM input]
|
100 |
+
H --> I2[Create attention mask]
|
101 |
+
|
102 |
+
I1 --> J[Run LLM forward pass]
|
103 |
+
I2 --> J
|
104 |
+
|
105 |
+
J --> K[Extract hidden states]
|
106 |
+
K --> L[Generate logits through llm_decoder]
|
107 |
+
|
108 |
+
D2 --> M[Compute loss and accuracy]
|
109 |
+
L --> M
|
110 |
+
|
111 |
+
M --> N[Return loss and accuracy]
|
112 |
+
```
|
113 |
+
|
114 |
+
There are some points to note for Cosy 2.0:
|
115 |
+
1. The prompt audio tokens are used to act reference audio, LLM will generate audio tokens mimic the reference audio.
|
116 |
+
2. '<|endofprompt|>' is used for prompt text, it is a special token to indicate this prompt is an instruction.
|
117 |
+
|
118 |
+
### Cosy 2.0 Data Preparation
|
119 |
+
|
120 |
+
1. Download reference audio files from https://huggingface.co/datasets/yueyulin/TTS_Reference and put them to folder $REF_AUDIO_DIR. These audios are used to generate audio tokens.
|
121 |
+
2. Download Cosy 2.0-0.5B model from https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B and put it to folder $MODEL_DIR.
|
122 |
+
3. Clone the Cosy 2.0 repo from:https://github.com/yynil/CosyVoice and follow the instruction to install the environment. In this repository, I change the codes to allow user to specify cuda device for multiple processes generation. If you have installed torch 2.6, please remember to force triton downgrading to 3.1.0.
|
123 |
+
4. Prepare the text data for audio tokens's training dataset. Currently we support parquet files and jsonl files. The text field is the only required field in the data file. I download the parquet from [wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia) for Chinese and Engish parquet files.
|
124 |
+
5. Generate the audio tokens using the following command:
|
125 |
+
```bash
|
126 |
+
bash run_multiple_process.sh --parquet_files /home/yueyulin/data/wiki/zh/train-00000-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00001-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00002-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00003-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00004-of-00006.parquet /home/yueyulin/data/wiki/zh/train-00005-of-00006.parquet --language zh --prompts_dir extract_data/prompts/zh --device cuda:0 --output_dir /home/yueyulin/data/speech_corpus
|
127 |
+
```
|
128 |
+
The prompts_dir is the $REF_AUDIO_DIR, the parquet_files are the list of files downloaded from wikimedia, each file is processed by one file. In my experience, one 4090 can process 6 files at the same time. The output_dir is the dirctory that audio tokens generated and saved.
|
129 |
+
|
130 |
+
|
131 |
+
### Cosy 2.0 LLM Training
|
132 |
+
After data is generated and saved, we will get the JSONL files like :
|
133 |
+
```json
|
134 |
+
{"text": "甄别重点监测企业是确保监测数据全面性和代表性的基础。首先,需要根据预警机制的覆盖范围和目标,明确监测企业的选择标准。选择标准可以包括企业规模、市场份额、行业影响力等。其次,通过企业调查、行业协会推荐等方式,初步筛选出符合条件的潜在监测企业。", "tts_speech_tokens": [2031, 4137, 6405, 6405, 6405, 6405, 6405, 6324, 6324, 6324, 6324, 6324, 6324, 4218, 1761, 4509, 2333, 4483, 5934, 6258, 1929, 3482, 314, 2300, 957, 5163, 6309, 5064, 6425, 3992, 1932, 80, 305, 734, 1479, 5650, 2472, 4778, 4487, 6175, 5667, 5373, 2187, 4851, 137, 141, 4919, 4407, 2436, 1295, 2024, 1294, 4940, 4778, 2330, 764, 1762, 2031, 1788, 5943, 5319, 5238, 5338, 3872, 1614, 4920, 6055, 6027, 3084, 5343, 4605, 2330, 218, 2172, 572, 1949, 1331, 865, 4921, 2472, 4688, 4379, 5850, 6342, 6373, 2997, 2529, 5087, 623, 3700, 6292, 6291, 5823, 5830, 2102, 1041, 6225, 6316, 3887, 889, 5487, 3813, 1626, 953, 734, 909, 4314, 4804, 4821, 4463, 23, 4683, 4678, 2724, 4832, 992, 1238, 2673, 324, 2099, 2486, 135, 2001, 4537, 5271, 2519, 957, 1699, 953, 1304, 1028, 4752, 2553, 5560, 4154, 1287, 59, 879, 4921, 2499, 5748, 5019, 240, 5889, 6264, 4293, 2186, 2105, 2005, 6405, 6405, 6324, 6324, 6324, 4137, 4218, 3651, 6048, 3132, 1433, 1457, 3962, 4515, 2482, 4490, 4561, 4669, 6054, 6270, 6316, 4615, 4781, 575, 632, 2031, 183, 4598, 4479, 6181, 5496, 4128, 3887, 1943, 1861, 6288, 5343, 6072, 3319, 2733, 322, 1187, 1727, 1807, 4921, 4677, 5668, 5019, 2427, 2976, 6066, 5332, 63, 73, 380, 4239, 6534, 6543, 5101, 1452, 213, 5921, 2273, 6453, 4347, 4537, 4459, 11, 2124, 866, 386, 485, 2511, 333, 632, 4317, 5772, 5803, 1457, 2163, 889, 5021, 2381, 5675, 5056, 5092, 1951, 3888, 3645, 4218, 6405, 6324, 4137, 1884, 1646, 2726, 377, 3992, 5529, 2481, 6054, 3822, 5340, 2330, 71, 2733, 2499, 5012, 4463, 5850, 6342, 6373, 2268, 4851, 137, 151, 4921, 4435, 4650, 528, 1295, 1295, 2023, 2753, 4850, 4570, 2243, 1047, 56, 113, 4512, 5568, 1662, 971, 5, 1480, 6387, 1045, 65, 460, 2160, 5102, 4568, 5056, 5098, 1602, 6048, 4367, 956, 59, 1524, 6405, 6405, 6324, 6324, 6324, 6324, 6324, 4137, 2031, 2706, 5325, 1653, 3887, 2219, 3667, 5664, 803, 4592, 2163, 5587, 4598, 5026, 5089, 1692, 5976, 1937, 146, 41, 1507, 1950, 2031, 0, 2349, 343, 4607, 5019, 566, 1683, 2166, 5051, 5678, 5057, 5830, 573, 2835, 2856, 5099, 707, 947, 1113, 4675, 4408, 4623, 1294, 2024, 2023, 3481, 4778, 2411, 1208, 1302, 660, 5827, 5345, 5074, 4560, 6501, 1403, 635, 716, 680, 5057, 4970, 1947, 3645, 1458, 1707, 6024, 6049, 5238, 5340, 1696, 5244, 1468, 1946, 509, 1318, 6534, 2800, 4510, 2234, 1991, 2017, 2018, 1370, 470, 2891, 4997, 1972, 1701, 5832, 1458, 1950, 4860, 5589, 1946, 1949, 509, 5369, 4966, 5019, 4849, 2411, 314, 1293, 1267, 377, 6421, 4800, 4416, 4893, 8, 1946, 1967, 1584, 4615, 5019, 2510, 867, 63, 245, 533, 1991, 4218, 6405, 6405, 6324, 6324, 6324, 6324, 6324, 4137, 1950, 4920, 4516, 276, 2024, 4777, 4194, 6373, 5643, 4851, 4448, 65, 1517, 1978, 4218, 6405, 4218, 2112, 1350, 4860, 5074, 5772, 6262, 672, 5097, 5090, 221, 1032, 4675, 4408, 285, 1295, 1294, 557, 4490, 228, 276, 4858, 4807, 2870, 1675, 6051, 1539, 4141, 1946, 4133, 6320, 4699, 982, 1950, 5832, 5835, 3645, 1947, 5589, 5589, 4136, 1946, 1235, 4642, 4993, 4857, 4598, 62, 4431, 4675, 285, 1043, 314, 2414, 2760, 2850, 5094, 3158, 1214, 1032, 2997, 2763, 5345, 5100, 402, 4677, 4857, 4543, 5, 1482, 2004, 56, 515, 1970, 2077, 6534, 3488, 5591, 5690, 5869, 5319, 2331, 5342, 1688, 1679, 1735, 4218, 6324, 6324, 6405, 4218, 2031, 5886, 6291, 6480, 2883, 5829, 5826, 2175, 5799, 5826, 2186, 2183, 5940, 5322, 120, 5918, 4571, 4687, 3813, 962, 737, 1561, 5886, 4077, 1429, 5831, 6560, 3644, 6429, 6507, 6534, 2101, 2186, 5097, 2682, 2673, 2017, 2576, 4594, 1005, 4785, 2760, 854, 1946, 683, 4844, 2733, 4695, 4840, 2192, 1482, 72, 29, 788, 1761, 4921, 4408, 2517, 566, 35, 2192, 5934, 4209, 5652, 4537, 5920, 278, 160, 3462, 4686, 5021, 4490, 5853, 3912, 6374, 2997, 4716, 2567, 140, 3462, 4435, 2436, 1295, 1295, 2023, 3482, 4769, 4598, 89, 1736, 4218, 6405, 6405, 6324, 6324, 4137], "prompt_text": "那么就在两侧的象限同时忙碌。", "llm_prompt_speech_token": [3686, 6324, 4137, 1959, 3666, 4376, 2836, 2127, 578, 2441, 1041, 2337, 6073, 3560, 1369, 5650, 4691, 5192, 2924, 89, 1687, 1539, 4218, 1848, 160, 4760, 2825, 1463, 1946, 1223, 1313, 2067, 5648, 2997, 2268, 2277, 4842, 4763, 308, 1038, 140, 842, 2983, 4672, 4650, 4696, 5995, 5603, 1238, 1238, 4672, 4650, 4777, 2474, 8, 767, 1731, 4299, 2079, 4941, 4947, 665, 719, 4319, 6424, 5067, 5967, 6048, 5967, 5238, 1523, 3875, 3872, 4314, 661, 1946, 1217, 500, 6422, 1506, 4852, 5831, 1457, 1448]}
|
135 |
+
{"text": "Once all the Cabinet and Cabinet-level officers have been invested, the act of their investiture usually ends with a \"family photo\" of the new Administration around the new president and vice-president. For this photo, the new ministers' alignment and proximity to the president is dictated by the order of precedence, with the ministers who head older departments standing in the first row, and the heads of the newer departments standing in the back rows. Some departments, such as the Department of Defence, take precedence from prior departments now abolished.", "tts_speech_tokens": [764, 35, 1896, 4299, 6486, 4299, 4299, 4299, 4218, 651, 2112, 2131, 1403, 2792, 2207, 1725, 5401, 281, 575, 683, 4997, 3474, 4492, 195, 87, 5109, 5846, 6077, 2270, 2172, 3828, 4424, 4543, 1520, 1753, 6258, 4075, 141, 5109, 5845, 3647, 1188, 3987, 3750, 4414, 1516, 4180, 5014, 5348, 1441, 6534, 5075, 5100, 1274, 1301, 3569, 3488, 3996, 6183, 4752, 4919, 2328, 3158, 6071, 5264, 5482, 5403, 5844, 5837, 191, 2139, 1839, 2255, 831, 4508, 4576, 6255, 1857, 29, 2, 2228, 5482, 6459, 2004, 2253, 2267, 2255, 885, 2112, 1788, 5916, 5835, 5919, 5919, 5919, 4056, 4299, 2058, 2982, 1295, 305, 1463, 3647, 2383, 2112, 3054, 4603, 3043, 4272, 2260, 4841, 6029, 6062, 5329, 6256, 6465, 2386, 2921, 2204, 4429, 5647, 2085, 2490, 809, 159, 546, 5325, 5298, 917, 1688, 3863, 3872, 3884, 3481, 3480, 4130, 5993, 5979, 5322, 5257, 5634, 4691, 4533, 5100, 1277, 764, 5111, 5, 47, 3748, 4929, 2376, 3583, 2990, 6456, 2232, 2306, 6507, 6210, 4463, 5840, 2270, 4071, 5693, 4663, 5100, 5226, 6510, 6534, 2900, 2567, 137, 882, 1199, 2831, 632, 389, 4251, 4191, 73, 49, 3831, 404, 971, 4853, 4613, 4074, 4314, 2417, 3750, 4507, 4416, 4594, 3624, 5325, 962, 224, 404, 5295, 4596, 2238, 3670, 3848, 4339, 1676, 812, 2441, 6097, 3934, 2261, 3750, 1564, 3401, 6074, 5823, 1383, 4293, 3816, 3734, 2219, 4450, 5482, 2996, 150, 3063, 143, 3019, 3667, 149, 3748, 4278, 4347, 3485, 5270, 4858, 5239, 2568, 2028, 4050, 3011, 32, 2264, 4672, 2991, 888, 804, 149, 2234, 5934, 1744, 2112, 3975, 5916, 5943, 5919, 5943, 5919, 5946, 5916, 3972, 4299, 6402, 6534, 1927, 140, 1038, 2263, 4567, 4413, 5563, 4672, 3999, 6264, 4826, 2810, 2567, 228, 227, 2324, 2504, 1773, 6375, 77, 3831, 754, 3401, 4612, 6498, 4311, 2411, 831, 2255, 4414, 5320, 4920, 2328, 5345, 5169, 4752, 4763, 5014, 6449, 2687, 3413, 3647, 2276, 3670, 4069, 1883, 2330, 4499, 1525, 1762, 1490, 2921, 1639, 2166, 4050, 4304, 2837, 732, 6049, 5405, 2266, 910, 4315, 2399, 798, 4859, 4857, 1923, 4434, 4485, 5152, 4206, 4447, 1917, 2136, 3807, 3740, 5, 2264, 5166, 5409, 806, 2982, 878, 2258, 860, 1525, 1762, 3320, 5169, 2166, 546, 2994, 4526, 4056, 2112, 60, 2274, 2528, 5084, 231, 4450, 4597, 1938, 2163, 650, 5108, 2335, 4188, 4859, 1760, 2096, 2903, 4349, 1684, 873, 3872, 6059, 6058, 5976, 4299, 2136, 4050, 3740, 2, 4432, 6455, 2226, 886, 3063, 881, 71, 2234, 5937, 5650, 5238, 4296, 1422, 2342, 2139, 3462, 2261, 1641, 4314, 230, 186, 2965, 4523, 4509, 4999, 4839, 5345, 6070, 5263, 4839, 3813, 3018, 5825, 2926, 5106, 2924, 194, 147, 1433, 728, 2915, 477, 2325, 5330, 6070, 1527, 2421, 2166, 3564, 6166, 1865, 1676, 2092, 4068, 2255, 1483, 5658, 5726, 2085, 3219, 71, 35, 2219, 3828, 2210, 5047, 6100, 4526, 2934, 3909, 4511, 6453, 6534, 3367, 3863, 3146, 5241, 5323, 6054, 1872, 3881, 947, 380, 632, 2909, 2884, 4296, 5913, 5835, 5919, 5919, 5919, 5838, 3975, 2112, 3648, 2192, 831, 3906, 2222, 5118, 5111, 4487, 879, 5650, 4422, 5256, 6465, 4446, 4522, 3831, 2294, 5588, 5825, 3377, 6050, 1698, 147, 1920, 1404, 6328, 1622, 1676, 2083, 2124, 2336, 3669, 5402, 4269, 2490, 71, 8, 113, 1563, 395, 4238, 2510, 3016, 3936, 4430, 2163, 461, 5192, 5998, 5272, 1869, 651, 4302, 1685, 221, 380, 389, 803, 5412, 4753, 2244, 2028, 3648, 3729, 5916, 5919, 5916, 3732, 3975, 2112, 3894, 5239, 5648, 2250, 2918, 4807, 6258, 879, 4600, 2166, 3483, 6327, 6239, 1652, 1757, 1881, 128, 2264, 5935, 5631, 5729, 5482, 2198, 2309, 1329, 4756, 2263, 4448, 4437, 6454, 4272, 3465, 157, 66, 954, 2166, 5598, 3980, 3836, 1838, 2064, 4069, 2371, 2938, 4565, 4356, 789, 4612, 5940, 6510, 3270, 5, 737, 8, 2234, 3747, 5650, 5482, 4269, 303, 2193, 2447, 4849, 2112, 2085, 4050, 3739, 2192, 4428, 5486, 2253, 885, 2992, 2249, 5205, 3453, 4672, 6186, 6534, 6059, 4068, 2184, 4320, 3978, 4052, 1622, 926, 3140, 231, 157, 2160, 1404, 6084, 3809, 1598, 2092, 6255, 2234, 3750, 5405, 3459, 3669, 23, 1463, 974, 2675, 2891, 2166, 712, 5030, 5023, 5080, 2741, 308, 32, 2203, 5217, 4593, 1437, 303, 2112, 3975], "prompt_text": " So I am gonna do this right now. So let's do it.", "llm_prompt_speech_token": [1822, 5727, 5000, 930, 5015, 2912, 3616, 692, 1250, 1978, 4214, 3485, 2036, 1298, 2918, 5192, 5056, 5074, 5065, 4813, 3005, 3002, 3313, 4238, 795, 4523, 4520, 3038, 4496, 859, 1887, 2490, 3309, 6235, 5264, 6074, 6047, 5339, 5474, 4291, 2915, 2666, 3759, 4056, 4299, 3975, 6159, 6186, 6186, 6186, 5838, 5109, 3732, 2112, 2139, 3945, 4534, 4569, 4575, 6453, 5405, 4461, 4338, 5572, 3809, 2411, 1214, 1205, 3805, 4526, 4379, 2189, 3890, 3242, 1418, 2876, 5828, 2799, 5133, 5563, 5481, 2325, 155, 533, 2801, 3617, 725, 56, 4385, 834, 3444, 5482, 3273, 2166, 2328, 1908, 1372, 868]}
|
136 |
+
```
|
137 |
+
|
138 |
+
We use Deepspeed to train the model:
|
139 |
+
```bash
|
140 |
+
deepspeed --num_nodes 1 --num_gpus 4 train_scripts/train_llm.py --data_file /external_data/yueyudata/speech_corpus/ --model_name /external_data/models/rwkv7-1.5B-world/ --output_dir /external_data/yueyudata/cosy_voice_llm --max_length 2048 --wandb_project toy_cosy_llm --wandb_run_name server2_rwkv_7_1.5B --ds_param_offload True --ds_optimizer_offload True --ds_stage 2 --gradient_checkpointing True --logging_steps 10 --per_device_train_batch_size 8
|
141 |
+
```
|
142 |
+
The base model can be downloaded from https://huggingface.co/collections/fla-hub/rwkv7-6790fd37b4b6137b088a0d8a , just choose a proper model for your training.
|
143 |
+
|
144 |
+
|
145 |
+
### Cosy 2.0 LLM Inference
|
146 |
+
|
147 |
+
### Some samples
|
148 |
+
|
149 |
+
#### Zero shot inference
|
150 |
+
prompt audio :
|
151 |
+
[prompt audio](mine.wav)
|
152 |
+
|
153 |
+
prompt text: "今天天气挺不错的。"
|
154 |
+
|
155 |
+
tts text: "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
|
156 |
+
|
157 |
+
tts audio:
|
158 |
+
[tts audio](zero_shot_0.wav)
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
### TODO:
|
163 |
+
0. Drop prompt audio tokens randomly to simulate unconditional guided generation.
|
164 |
+
1. Add special control tokens in Cosy 2.0 in RWKV tokenizer and add them to generate audio tokens again:
|
165 |
+
```python
|
166 |
+
special_tokens = {
|
167 |
+
'eos_token': '<|endoftext|>',
|
168 |
+
'pad_token': '<|endoftext|>',
|
169 |
+
'additional_special_tokens': [
|
170 |
+
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
171 |
+
'[breath]', '<strong>', '</strong>', '[noise]',
|
172 |
+
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
173 |
+
'[quick_breath]',
|
174 |
+
"<laughter>", "</laughter>",
|
175 |
+
"[hissing]", "[sigh]", "[vocalized-noise]",
|
176 |
+
"[lipsmack]", "[mn]"
|
177 |
+
]
|
178 |
+
}
|
179 |
+
```
|
180 |
+
2. Add special control tokens like dialects in RWKV7LM and generate audio tokens for training.
|
181 |
+
3. Implement streaming generation for Cosy 2.0 in RWKV7LM.
|
Trump.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:296432bb06954080b77c04a88841d61928d936077f5162947359520fa17836be
|
3 |
+
size 342108
|
_config.yml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
markdown: kramdown
|
2 |
+
kramdown:
|
3 |
+
parse_block_html: true
|
another.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4d103efaf538db967559861dbcf9995b60eca582360a6add5cf27c3faf3a49e
|
3 |
+
size 199724
|
badXT_71.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1c5e28420eb8c4506a1988d484fe9270b8422161d733c567abfccd74c106ceb9
|
3 |
+
size 794726
|
data/cosy/data/data_processor.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pyexpat import model
|
2 |
+
import torchaudio
|
3 |
+
from hyperpyyaml import load_hyperpyyaml
|
4 |
+
import os
|
5 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
6 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
7 |
+
import json
|
8 |
+
import torch
|
9 |
+
|
10 |
+
def load_from_configuration(model_dir):
|
11 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
12 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
13 |
+
return configs
|
14 |
+
def init_process(model_dir,device):
|
15 |
+
cosyvoice = CosyVoice2(model_dir, load_jit=False, load_trt=False, fp16=True,device=device)
|
16 |
+
# configs = load_from_configuration(model_dir)
|
17 |
+
# frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
18 |
+
# configs['feat_extractor'],
|
19 |
+
# '{}/campplus.onnx'.format(model_dir),
|
20 |
+
# '{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
21 |
+
# '{}/spk2info.pt'.format(model_dir),
|
22 |
+
# configs['allowed_special'],
|
23 |
+
# device)
|
24 |
+
frontend = cosyvoice.frontend
|
25 |
+
llm = cosyvoice.model.llm
|
26 |
+
return frontend,llm,cosyvoice
|
27 |
+
|
28 |
+
|
29 |
+
def preprocess_prompts(frontend,prompts_dir):
|
30 |
+
language_results = {}
|
31 |
+
final_rate = 24000
|
32 |
+
for root, dirs, files in os.walk(prompts_dir):
|
33 |
+
for file in files:
|
34 |
+
if file.endswith('.json'):
|
35 |
+
json_file = os.path.join(root, file)
|
36 |
+
print(f"处理文件 {json_file}")
|
37 |
+
language = json_file.split('/')[-2]
|
38 |
+
if language not in language_results:
|
39 |
+
language_results[language] = []
|
40 |
+
|
41 |
+
# 尝试不同的编码格式读取文件
|
42 |
+
try:
|
43 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
44 |
+
json_data = json.load(f)
|
45 |
+
except UnicodeDecodeError:
|
46 |
+
try:
|
47 |
+
# 尝试 GB2312/GBK 编码 (常用于中文)
|
48 |
+
with open(json_file, 'r', encoding='gbk') as f:
|
49 |
+
json_data = json.load(f)
|
50 |
+
except UnicodeDecodeError:
|
51 |
+
try:
|
52 |
+
# 尝试 GB18030 编码 (扩展的中文编码)
|
53 |
+
with open(json_file, 'r', encoding='gb18030') as f:
|
54 |
+
json_data = json.load(f)
|
55 |
+
except Exception as e:
|
56 |
+
print(f"无法读取文件 {json_file}: {e}")
|
57 |
+
continue
|
58 |
+
|
59 |
+
wav_file = json_file.replace('.json', '.wav')
|
60 |
+
prompt_text = json_data['text']
|
61 |
+
prompt_speech = torchaudio.load(wav_file, backend='soundfile')[0]
|
62 |
+
fake_tts_text = "a"
|
63 |
+
with torch.no_grad():
|
64 |
+
model_input = frontend.frontend_zero_shot(fake_tts_text, prompt_text, prompt_speech,final_rate)
|
65 |
+
language_results[language].append((model_input,prompt_text))
|
66 |
+
return language_results
|
67 |
+
|
68 |
+
def generate_speech_tokens(llm,frontend,tts_text,model_input,device):
|
69 |
+
tts_text = frontend.text_normalize(tts_text,split=False, text_frontend=True)
|
70 |
+
tts_text_token, tts_text_token_len = frontend._extract_text_token(tts_text)
|
71 |
+
tts_text_token_len = torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(device)
|
72 |
+
prompt_text = model_input['prompt_text'].to(device)
|
73 |
+
prompt_text_len = torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(device)
|
74 |
+
llm_prompt_speech_token = model_input['llm_prompt_speech_token'].to(device)
|
75 |
+
prompt_speech_token_len = torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(device)
|
76 |
+
flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
|
77 |
+
prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
|
78 |
+
llm_embedding = model_input['llm_embedding'].to(device)
|
79 |
+
flow_embedding = model_input['flow_embedding'].to(device)
|
80 |
+
speech_tokens = []
|
81 |
+
for i in llm.inference(text = tts_text_token,
|
82 |
+
text_len = tts_text_token_len,
|
83 |
+
prompt_text = prompt_text,
|
84 |
+
prompt_text_len = prompt_text_len,
|
85 |
+
prompt_speech_token = llm_prompt_speech_token,
|
86 |
+
prompt_speech_token_len = prompt_speech_token_len,
|
87 |
+
embedding=llm_embedding
|
88 |
+
):
|
89 |
+
speech_tokens.append(i)
|
90 |
+
tts_speech_tokens = torch.tensor(speech_tokens).unsqueeze(dim=0).to(device)
|
91 |
+
return tts_speech_tokens
|
92 |
+
|
93 |
+
if __name__ == '__main__':
|
94 |
+
model_dir = '/data/yueyu/models/CosyVoice2-0.5B'
|
95 |
+
prompts_dir = 'extract_data/prompts'
|
96 |
+
|
97 |
+
device = 'cuda:0'
|
98 |
+
frontend,llm,cosyvoice = init_process(model_dir
|
99 |
+
,device)
|
100 |
+
prompts = preprocess_prompts(frontend,prompts_dir)
|
101 |
+
print(prompts)
|
102 |
+
model_input = prompts['zh'][0][0]
|
103 |
+
prompt_text = prompts['zh'][0][1]
|
104 |
+
tts_text = '扫一扫,立即体验中国银行信用卡好礼、绑卡立减等热门活动,实时掌握更多优惠信息。'
|
105 |
+
tts_text = '在中国的一个偏远山区,有一位名叫李远的年轻人,他对集群通信系统有着浓厚的兴趣。每天晚上,他都会在自己的小屋里研究各种关于集群通信系统的资料,试图弄懂其中的原理和运作机制。他对这个领域的研究不仅仅停留在理论层面,还亲手制作了一些模型,试图通过实践来加深理解。'
|
106 |
+
tts_text = "歷史(现代汉语词汇,古典文言文称之为史),指人类社会过去的事件和行动,以及对这些事件行为有系统的记录、诠释和研究。歷史可提供今人理解過去,作為未來行事的參考依據,与伦理、哲学和艺术同属人类精神文明的重要成果。历史的第二个含义,即对过去事件的记录和研究,又称历史学”,或简称“史学”。隶属于历史学或与其密切相关的学科有年代学、编纂学、家谱学、古文字学、计量历史学、考古学、社会学和新闻学等,参见历史学。记录和研究历史的人称为历史学家,简称“史学家”,中国古代称为史官。记录历史的书籍称为史书,如《史記》、《汉书》等,粗分為「官修」與「民載」兩類。"
|
107 |
+
tts_text = "### 如何提高花样游泳水平"
|
108 |
+
tts_speech_tokens = generate_speech_tokens(llm,frontend,tts_text,model_input,device)
|
109 |
+
print(tts_speech_tokens)
|
110 |
+
|
111 |
+
|
112 |
+
flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
|
113 |
+
prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
|
114 |
+
llm_embedding = model_input['llm_embedding'].to(device)
|
115 |
+
flow_embedding = model_input['flow_embedding'].to(device)
|
116 |
+
cosyvoice.model.hift_cache_dict['xxxx'] = None
|
117 |
+
tts_speech = cosyvoice.model.token2wav(token=tts_speech_tokens,
|
118 |
+
prompt_token=flow_prompt_speech_token,
|
119 |
+
prompt_feat=prompt_speech_feat,
|
120 |
+
embedding=flow_embedding,
|
121 |
+
uuid='xxxx',
|
122 |
+
token_offset=0,
|
123 |
+
finalize=True,
|
124 |
+
speed=1.0)
|
125 |
+
print(f'tts_speech shape:{tts_speech.shape}')
|
126 |
+
tts_speech = tts_speech.cpu()
|
127 |
+
torchaudio.save('zh_tts_S.wav', tts_speech, 24000)
|
128 |
+
print(model_input)
|
data/cosy/test/test_vq.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from turtle import back
|
2 |
+
from click import prompt
|
3 |
+
import torch
|
4 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
5 |
+
print(torch.cuda.is_available())
|
6 |
+
print(torch.cuda.current_device())
|
7 |
+
print(torch.cuda.device(0))
|
8 |
+
print(torch.cuda.device_count())
|
9 |
+
model_path = '/data/yueyu/models/CosyVoice2-0.5B'
|
10 |
+
# cosyvoice = CosyVoice2(model_path, load_jit=False, load_trt=False, fp16=False)
|
11 |
+
# print(cosyvoice)
|
12 |
+
# from cosyvoice.utils.file_utils import load_wav
|
13 |
+
# import torchaudio
|
14 |
+
# prompt_speech_16k = load_wav('/home/yueyulin/github/CosyVoice/asset/zero_shot_prompt.wav', 16000)
|
15 |
+
# # prompt_speech_16k = torch.rand((1, 16000))
|
16 |
+
# for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
17 |
+
# torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
18 |
+
|
19 |
+
# for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
|
20 |
+
# torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
21 |
+
# # instruct usage
|
22 |
+
# for i, j in enumerate(cosyvoice.inference_instruct2('吾今朝早上去外婆家吃饭。', '用上海话说这句话', prompt_speech_16k, stream=False)):
|
23 |
+
# torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
24 |
+
|
25 |
+
from hyperpyyaml import load_hyperpyyaml
|
26 |
+
import os
|
27 |
+
def load_from_configuration(model_dir):
|
28 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
29 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
30 |
+
return configs
|
31 |
+
|
32 |
+
configs = load_from_configuration(model_path)
|
33 |
+
print(configs)
|
34 |
+
|
35 |
+
import torchaudio
|
36 |
+
def load_wav(wav, target_sr):
|
37 |
+
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
|
38 |
+
speech = speech.mean(dim=0, keepdim=True)
|
39 |
+
if sample_rate != target_sr:
|
40 |
+
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
41 |
+
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
42 |
+
return speech
|
43 |
+
|
44 |
+
zh_prompt_tar_file="/data/yueyu/data/Emilia-Dataset/Emilia/ZH/ZH-B000000.tar"
|
45 |
+
en_prompt_tar_file="/data/yueyu/data/Emilia-Dataset/Emilia/EN/EN-B000000.tar"
|
46 |
+
|
47 |
+
|
48 |
+
def load_file_list(tar_file):
|
49 |
+
#the files are FILE_NAME.mp3/FILE_NAME.json
|
50 |
+
#return all FILE_NAME as a list which has a mp3 and json
|
51 |
+
import tarfile
|
52 |
+
with tarfile.open(tar_file, 'r') as f:
|
53 |
+
file_names = f.getnames()
|
54 |
+
mp3_files = [i for i in file_names if i.endswith('.mp3')]
|
55 |
+
json_files = [i for i in file_names if i.endswith('.json')]
|
56 |
+
|
57 |
+
#filter mp3_files without corresponded json
|
58 |
+
mp3_files = [i for i in mp3_files if i.replace('.mp3', '.json') in json_files]
|
59 |
+
return mp3_files
|
60 |
+
|
61 |
+
zh_files = load_file_list(zh_prompt_tar_file)
|
62 |
+
print(zh_files[:10])
|
63 |
+
en_files = load_file_list(en_prompt_tar_file)
|
64 |
+
print(en_files[:10])
|
65 |
+
import io
|
66 |
+
|
67 |
+
def load_random_samples_from_tar(tar_file, files, num_samples,target_sr,max_duration=10):
|
68 |
+
import random
|
69 |
+
import tarfile
|
70 |
+
import json
|
71 |
+
samples = []
|
72 |
+
with tarfile.open(tar_file, 'r') as f:
|
73 |
+
for i in random.sample(files, len(files)):
|
74 |
+
mp3 = f.extractfile(i)
|
75 |
+
mp3_bytes = io.BytesIO(mp3.read())
|
76 |
+
speech, sample_rate = torchaudio.load(mp3_bytes,backend='soundfile')
|
77 |
+
json_file = f.extractfile(i.replace('.mp3', '.json'))
|
78 |
+
json_data = json.load(json_file)
|
79 |
+
duration = json_data['duration']
|
80 |
+
if duration > max_duration:
|
81 |
+
continue
|
82 |
+
speech = speech.mean(dim=0, keepdim=True)
|
83 |
+
if sample_rate != target_sr:
|
84 |
+
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
85 |
+
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
86 |
+
samples.append((speech, json_data,sample_rate))
|
87 |
+
if len(samples) == num_samples:
|
88 |
+
break
|
89 |
+
return samples
|
90 |
+
target_sr = 16000
|
91 |
+
zh_samples = load_random_samples_from_tar(zh_prompt_tar_file, zh_files, 10, target_sr)
|
92 |
+
|
93 |
+
one_sample,one_json,sample_rate = zh_samples[0]
|
94 |
+
print(one_json)
|
95 |
+
print(sample_rate)
|
96 |
+
torchaudio.save('zh_sample.wav', one_sample, target_sr)
|
97 |
+
print(len(zh_samples))
|
98 |
+
|
99 |
+
en_samples = load_random_samples_from_tar(en_prompt_tar_file, en_files, 10, target_sr)
|
100 |
+
one_sample,one_json,sample_rate = en_samples[0]
|
101 |
+
print(one_json)
|
102 |
+
print(sample_rate)
|
103 |
+
torchaudio.save('en_sample.wav', one_sample, target_sr)
|
104 |
+
print(len(en_samples))
|
105 |
+
|
106 |
+
def resample_audio(samples, target_sr):
|
107 |
+
resampled_samples = []
|
108 |
+
for i in samples:
|
109 |
+
speech, sample_rate = i
|
110 |
+
if sample_rate != target_sr:
|
111 |
+
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
112 |
+
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
113 |
+
resampled_samples.append((speech, sample_rate))
|
114 |
+
return resampled_samples
|
115 |
+
|
116 |
+
prompt_text = zh_samples[0][1]['text']
|
117 |
+
prompt_speech = zh_samples[0][0]
|
118 |
+
print(prompt_text)
|
119 |
+
print(prompt_speech)
|
120 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
121 |
+
cosyvoice = CosyVoice2(model_path, load_jit=False, load_trt=False, fp16=True)
|
122 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
123 |
+
frontend = cosyvoice.frontend
|
124 |
+
prompt_text = frontend.text_normalize(prompt_text,split=False, text_frontend=True)
|
125 |
+
print(f'normalized prompt_text:{prompt_text}')
|
126 |
+
tts_text = '扫一扫,立即体验中国银行信用卡好礼、绑卡立减等热门活动,实时掌握更多优惠信息。'
|
127 |
+
tts_text = "在中国的一个偏远山区,有一位名叫李远的年轻人,他对集群通信系统有着浓厚的兴趣。每天晚上,他都会在自己的小屋里研究各种关于集群通信系统的资料,试图弄懂其中的原理和运作机制。他对这个领域的研究不仅仅停留在理论层面,还亲手制作了一些模型,试图通过实践来加深理解。"
|
128 |
+
tts_text = "歷史(现代汉语词汇,古典文言文称之为史),指人类社会过去的事件和行动,以及对这些事件行为有系统的记录、诠释和研究。歷史可提供今人理解過去,作為未來行事的參考依據,与伦理、哲学和艺术同属人类精神文明的重要成果。历史的第二个含义,即对过去事件的记录和研究,又称历史学”,或简称“史学”。隶属于历史学或与其密切相关的学科有年代学、编纂学、家谱学、古文字学、计量历史学、考古学、社会学和新闻学等,参见历史学。记录和研究历史的人称为历史学家,简称“史学家”,中国古代称为史官。记录历史的书籍称为史书,如《史記》、《汉书》等,粗分為「官修」與「民載」兩類。"
|
129 |
+
tts_text = frontend.text_normalize(tts_text,split=False, text_frontend=True)
|
130 |
+
print(f'normalized tts_text:{tts_text}')
|
131 |
+
final_rate = 24000
|
132 |
+
model_input = frontend.frontend_zero_shot(tts_text, prompt_text, prompt_speech,final_rate)
|
133 |
+
print(model_input)
|
134 |
+
llm = cosyvoice.model.llm
|
135 |
+
device = cosyvoice.model.device
|
136 |
+
text = model_input['text'].to(device)
|
137 |
+
text_len = torch.tensor([text.shape[1]], dtype=torch.int32).to(device)
|
138 |
+
prompt_text = model_input['prompt_text'].to(device)
|
139 |
+
prompt_text_len = torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(device)
|
140 |
+
llm_prompt_speech_token = model_input['llm_prompt_speech_token'].to(device)
|
141 |
+
prompt_speech_token_len = torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(device)
|
142 |
+
flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
|
143 |
+
prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
|
144 |
+
llm_embedding = model_input['llm_embedding'].to(device)
|
145 |
+
flow_embedding = model_input['flow_embedding'].to(device)
|
146 |
+
speech_tokens = []
|
147 |
+
for i in llm.inference(text = text,
|
148 |
+
text_len = text_len,
|
149 |
+
prompt_text = prompt_text,
|
150 |
+
prompt_text_len = prompt_text_len,
|
151 |
+
prompt_speech_token = llm_prompt_speech_token,
|
152 |
+
prompt_speech_token_len = prompt_speech_token_len,
|
153 |
+
embedding=llm_embedding
|
154 |
+
):
|
155 |
+
speech_tokens.append(i)
|
156 |
+
print(speech_tokens)
|
157 |
+
|
158 |
+
tts_speech_tokens = torch.tensor(speech_tokens).unsqueeze(dim=0).to(device)
|
159 |
+
print(f'tts_speech_tokens shape:{tts_speech_tokens.shape}')
|
160 |
+
cosyvoice.model.hift_cache_dict['xxxx'] = None
|
161 |
+
tts_speech = cosyvoice.model.token2wav(token=tts_speech_tokens,
|
162 |
+
prompt_token=flow_prompt_speech_token,
|
163 |
+
prompt_feat=prompt_speech_feat,
|
164 |
+
embedding=flow_embedding,
|
165 |
+
uuid='xxxx',
|
166 |
+
token_offset=0,
|
167 |
+
finalize=True,
|
168 |
+
speed=1.0)
|
169 |
+
print(f'tts_speech shape:{tts_speech.shape}')
|
170 |
+
tts_speech = tts_speech.cpu()
|
171 |
+
torchaudio.save('zh_tts.wav', tts_speech, final_rate)
|
data/utils/convert_embeddings_2_pt.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
jsonl_dir = sys.argv[1]
|
8 |
+
output_file_name = sys.argv[2]
|
9 |
+
|
10 |
+
# Load the embeddings from jsonl files the key is the name of the file
|
11 |
+
embeddings = {}
|
12 |
+
for file in os.listdir(jsonl_dir):
|
13 |
+
print("Processing", file)
|
14 |
+
if file.endswith("_embeddings.json"):
|
15 |
+
with open(os.path.join(jsonl_dir, file), "r") as f:
|
16 |
+
print("Loading", file)
|
17 |
+
data = json.load(f)
|
18 |
+
key_name = os.path.basename(file).replace("_embeddings.json", "")
|
19 |
+
np_array = np.array(data)
|
20 |
+
if np_array.shape[0] == 1:
|
21 |
+
np_array = np_array[0]
|
22 |
+
else:
|
23 |
+
#find the cluster center of the embeddings using kmeans
|
24 |
+
kmeans = KMeans(n_clusters=1, random_state=0, n_init = 'auto').fit(np_array)
|
25 |
+
np_array = kmeans.cluster_centers_[0]
|
26 |
+
|
27 |
+
embeddings[key_name]= {'embedding' : torch.tensor(np_array, dtype=torch.float32).unsqueeze(0)}
|
28 |
+
torch.save(embeddings, output_file_name)
|
29 |
+
print("Embeddings saved to", output_file_name)
|
30 |
+
|
31 |
+
state_dict = torch.load(output_file_name)
|
32 |
+
print("Loaded embeddings from", output_file_name)
|
33 |
+
for key in state_dict:
|
34 |
+
print(key, state_dict[key]['embedding'].shape)
|
data/utils/create_embeddings_from_raw.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from re import A
|
3 |
+
import whisper
|
4 |
+
from librosa import resample
|
5 |
+
import multiprocessing
|
6 |
+
from tqdm import tqdm
|
7 |
+
import onnxruntime
|
8 |
+
from onnxruntime import InferenceSession
|
9 |
+
import torch
|
10 |
+
import pyarrow.parquet as pq
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
import io
|
14 |
+
import soundfile as sf
|
15 |
+
import torchaudio
|
16 |
+
import torchaudio.compliance.kaldi as kaldi
|
17 |
+
import mmap
|
18 |
+
import os
|
19 |
+
import pyarrow.parquet as pq
|
20 |
+
import io
|
21 |
+
import soundfile as sf
|
22 |
+
import torchaudio.compliance.kaldi as kaldi
|
23 |
+
import torch
|
24 |
+
import numpy as np
|
25 |
+
import onnxruntime
|
26 |
+
|
27 |
+
def process_file(file_info):
|
28 |
+
"""处理单个parquet文件的函数,每个进程调用一次"""
|
29 |
+
parquet_file, output_path, speaker_extractor, device = file_info
|
30 |
+
|
31 |
+
# 为每个进程创建独立的speech_tokenizer_session
|
32 |
+
option = onnxruntime.SessionOptions()
|
33 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
34 |
+
option.intra_op_num_threads = 1
|
35 |
+
ort_session = onnxruntime.InferenceSession(speaker_extractor, sess_options=option,
|
36 |
+
providers=["CPUExecutionProvider"])
|
37 |
+
results = {}
|
38 |
+
try:
|
39 |
+
# 创建目标文件名
|
40 |
+
base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
|
41 |
+
output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")
|
42 |
+
|
43 |
+
# 使用PyArrow读取parquet文件的元数据,获取总行数
|
44 |
+
parquet_metadata = pq.read_metadata(parquet_file)
|
45 |
+
total_rows = parquet_metadata.num_rows
|
46 |
+
batch_size = 100
|
47 |
+
|
48 |
+
# 使用 mmap 读取 parquet 文件
|
49 |
+
with open(parquet_file, 'rb') as f:
|
50 |
+
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
51 |
+
|
52 |
+
# 使用 io.BytesIO 将 mmap 对象包装成文件对象
|
53 |
+
buffer = io.BytesIO(mm)
|
54 |
+
|
55 |
+
pf = pq.ParquetFile(buffer) # 使用 mmap 包装的 buffer
|
56 |
+
|
57 |
+
progress = tqdm(total=total_rows,
|
58 |
+
desc=f"Processing {os.path.basename(parquet_file)}",
|
59 |
+
position=multiprocessing.current_process()._identity[0] % 10)
|
60 |
+
|
61 |
+
current_row = 0
|
62 |
+
idx = 0
|
63 |
+
for batch in pf.iter_batches(batch_size=batch_size):
|
64 |
+
df_batch = batch.to_pandas()
|
65 |
+
|
66 |
+
# 处理当前批次中的每一行
|
67 |
+
for _, row in df_batch.iterrows():
|
68 |
+
current_row += 1
|
69 |
+
audio_obj = row['audio']
|
70 |
+
audio_data = audio_obj['bytes']
|
71 |
+
transcription = row['transcription']
|
72 |
+
language = row['language']
|
73 |
+
speaker = row['speaker']
|
74 |
+
if speaker not in results:
|
75 |
+
results[speaker] = {}
|
76 |
+
if language not in results[speaker]:
|
77 |
+
results[speaker][language] = []
|
78 |
+
if len(results[speaker][language]) >= 10:
|
79 |
+
progress.update(1)
|
80 |
+
continue
|
81 |
+
|
82 |
+
with io.BytesIO(audio_data) as audio_buffer:
|
83 |
+
prompt_data, sample_rate = sf.read(audio_buffer)
|
84 |
+
# 确保是单声道,并转换为float32
|
85 |
+
if len(prompt_data.shape) > 1:
|
86 |
+
prompt_data = prompt_data[:, 0]
|
87 |
+
prompt_data = prompt_data.astype(np.float32)
|
88 |
+
|
89 |
+
# 重采样到16kHz (如果需要)
|
90 |
+
if sample_rate != 16000:
|
91 |
+
prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)
|
92 |
+
|
93 |
+
prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
|
94 |
+
|
95 |
+
feat = kaldi.fbank(prompt_speech_16k,
|
96 |
+
num_mel_bins=80,
|
97 |
+
dither=0,
|
98 |
+
sample_frequency=16000)
|
99 |
+
feat = feat - feat.mean(dim=0,keepdim=True)
|
100 |
+
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
101 |
+
|
102 |
+
results[speaker][language].append(embedding)
|
103 |
+
|
104 |
+
progress.update(1)
|
105 |
+
|
106 |
+
# 关闭 mmap 对象
|
107 |
+
mm.close()
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
print(f'All speakers {results.keys()}')
|
112 |
+
for speaker in results:
|
113 |
+
print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
|
114 |
+
return results
|
115 |
+
except Exception as e:
|
116 |
+
import traceback
|
117 |
+
traceback.print_exc()
|
118 |
+
return f"Error processing {parquet_file}: {str(e)}"
|
119 |
+
def process_file_x(file_info):
|
120 |
+
"""处理单个parquet文件的函数,每个进程调用一次"""
|
121 |
+
parquet_file, output_path, speaker_extractor, device = file_info
|
122 |
+
|
123 |
+
# 为每个进程创���独立的speech_tokenizer_session
|
124 |
+
option = onnxruntime.SessionOptions()
|
125 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
126 |
+
option.intra_op_num_threads = 1
|
127 |
+
ort_session = InferenceSession(speaker_extractor, sess_options=option,
|
128 |
+
providers=["CPUExecutionProvider"])
|
129 |
+
results = {}
|
130 |
+
try:
|
131 |
+
# 创建目标文件名
|
132 |
+
base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
|
133 |
+
output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")
|
134 |
+
|
135 |
+
# 使用PyArrow读取parquet文件的元数据,获取总行数
|
136 |
+
parquet_metadata = pq.read_metadata(parquet_file)
|
137 |
+
total_rows = parquet_metadata.num_rows
|
138 |
+
batch_size = 100
|
139 |
+
|
140 |
+
pf = pq.ParquetFile(parquet_file)
|
141 |
+
|
142 |
+
progress = tqdm(total=total_rows,
|
143 |
+
desc=f"Processing {os.path.basename(parquet_file)}",
|
144 |
+
position=multiprocessing.current_process()._identity[0] % 10)
|
145 |
+
|
146 |
+
current_row = 0
|
147 |
+
idx = 0
|
148 |
+
for batch in pf.iter_batches(batch_size=batch_size):
|
149 |
+
df_batch = batch.to_pandas()
|
150 |
+
|
151 |
+
# 处理当前批次中的每一行
|
152 |
+
for _, row in df_batch.iterrows():
|
153 |
+
current_row += 1
|
154 |
+
audio_obj = row['audio']
|
155 |
+
audio_data = audio_obj['bytes']
|
156 |
+
transcription = row['transcription']
|
157 |
+
language = row['language']
|
158 |
+
speaker = row['speaker']
|
159 |
+
if speaker not in results:
|
160 |
+
results[speaker] = {}
|
161 |
+
if language not in results[speaker]:
|
162 |
+
results[speaker][language] = []
|
163 |
+
if len(results[speaker][language]) >= 10:
|
164 |
+
progress.update(1)
|
165 |
+
continue
|
166 |
+
|
167 |
+
with io.BytesIO(audio_data) as buffer:
|
168 |
+
prompt_data, sample_rate = sf.read(buffer)
|
169 |
+
# 确保是单声道,并转换为float32
|
170 |
+
if len(prompt_data.shape) > 1:
|
171 |
+
prompt_data = prompt_data[:, 0]
|
172 |
+
prompt_data = prompt_data.astype(np.float32)
|
173 |
+
|
174 |
+
# 重采样到16kHz (如果需要)
|
175 |
+
if sample_rate != 16000:
|
176 |
+
prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)
|
177 |
+
|
178 |
+
prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
|
179 |
+
|
180 |
+
feat = kaldi.fbank(prompt_speech_16k,
|
181 |
+
num_mel_bins=80,
|
182 |
+
dither=0,
|
183 |
+
sample_frequency=16000)
|
184 |
+
feat = feat - feat.mean(dim=0,keepdim=True)
|
185 |
+
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
186 |
+
|
187 |
+
results[speaker][language].append(embedding)
|
188 |
+
|
189 |
+
progress.update(1)
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
print(f'All speakers {results.keys()}')
|
196 |
+
for speaker in results:
|
197 |
+
print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
|
198 |
+
return results
|
199 |
+
except Exception as e:
|
200 |
+
import traceback
|
201 |
+
traceback.print_exc()
|
202 |
+
return f"Error processing {parquet_file}: {str(e)}"
|
203 |
+
if __name__ == '__main__':
|
204 |
+
import argparse
|
205 |
+
parser = argparse.ArgumentParser()
|
206 |
+
parser.add_argument('--data_path', type=str, default='/external_data/yueyudata/starrail-voice')
|
207 |
+
parser.add_argument('--output_path',type=str,default='/external_data/yueyudata/starrail-voice-speaker-embeddings')
|
208 |
+
parser.add_argument('--speaker_extractor',type=str,default='/external_data/models/CosyVoice2-0.5B_RWKV_1.5B/campplus.onnx')
|
209 |
+
parser.add_argument('--device',type=str,default='cuda:0')
|
210 |
+
parser.add_argument('--num_processes',type=int,default=4)
|
211 |
+
args = parser.parse_args()
|
212 |
+
|
213 |
+
print(args)
|
214 |
+
data_path = args.data_path
|
215 |
+
output_path = args.output_path
|
216 |
+
device = args.device
|
217 |
+
speaker_extractor = args.speaker_extractor
|
218 |
+
num_processes = args.num_processes
|
219 |
+
|
220 |
+
# 确保输出目录存在
|
221 |
+
os.makedirs(output_path, exist_ok=True)
|
222 |
+
|
223 |
+
# 找到所有parquet文件
|
224 |
+
parquet_files = []
|
225 |
+
for root, dirs, files in os.walk(data_path):
|
226 |
+
for file in files:
|
227 |
+
if file.endswith('.parquet'):
|
228 |
+
parquet_files.append(os.path.join(root, file))
|
229 |
+
print(f'Found {len(parquet_files)} parquet files in {data_path}')
|
230 |
+
|
231 |
+
# 准备多进程参数
|
232 |
+
file_info_list = [(file, output_path, speaker_extractor, device) for file in parquet_files]
|
233 |
+
|
234 |
+
# 使用进程池处理文件
|
235 |
+
print(f"Starting processing with {num_processes} processes")
|
236 |
+
|
237 |
+
# 使用进程池处理文件
|
238 |
+
print(f"Starting processing with {num_processes} processes")
|
239 |
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
240 |
+
results = pool.map(process_file, file_info_list)
|
241 |
+
|
242 |
+
# 输出处理结果
|
243 |
+
print('Processing complete,merge results')
|
244 |
+
final_results = {}
|
245 |
+
for result in results:
|
246 |
+
if isinstance(result, dict):
|
247 |
+
for speaker in result:
|
248 |
+
if speaker not in final_results:
|
249 |
+
final_results[speaker] = {}
|
250 |
+
for language in result[speaker]:
|
251 |
+
if language not in final_results[speaker]:
|
252 |
+
final_results[speaker][language] = []
|
253 |
+
final_results[speaker][language].extend(result[speaker][language])
|
254 |
+
else:
|
255 |
+
print(result)
|
256 |
+
|
257 |
+
# 输出结果
|
258 |
+
for speaker in final_results:
|
259 |
+
for language in final_results[speaker]:
|
260 |
+
output_file = os.path.join(output_path, f"{speaker}_{language}_embeddings.json")
|
261 |
+
print(f"Writing embeddings for {speaker} ({language}) to {output_file}")
|
262 |
+
with open(output_file, 'w', encoding='utf-8') as f_out:
|
263 |
+
json.dump(final_results[speaker][language], f_out)
|
data/utils/create_lm_corpus_from_raw.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
import io
|
6 |
+
import torch
|
7 |
+
import soundfile as sf
|
8 |
+
import pyarrow.parquet as pq
|
9 |
+
import whisper
|
10 |
+
from librosa import resample
|
11 |
+
import multiprocessing
|
12 |
+
from tqdm import tqdm
|
13 |
+
import onnxruntime
|
14 |
+
from onnxruntime import InferenceSession
|
15 |
+
|
16 |
+
def process_file(file_info):
|
17 |
+
"""处理单个parquet文件的函数,每个进程调用一次"""
|
18 |
+
parquet_file, output_path, speech_tokenizer_model, device = file_info
|
19 |
+
|
20 |
+
# 为每个进程创建独立的speech_tokenizer_session
|
21 |
+
option = onnxruntime.SessionOptions()
|
22 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
23 |
+
option.intra_op_num_threads = 1
|
24 |
+
cuda_idx = int(device.split(':')[-1] if device is not None and 'cuda' in device else '0')
|
25 |
+
speech_tokenizer_session = InferenceSession(speech_tokenizer_model, sess_options=option,
|
26 |
+
providers=[("CUDAExecutionProvider", {"device_id": cuda_idx})
|
27 |
+
if torch.cuda.is_available() else "CPUExecutionProvider"])
|
28 |
+
|
29 |
+
try:
|
30 |
+
# 创建目标文件名
|
31 |
+
base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
|
32 |
+
output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")
|
33 |
+
|
34 |
+
# 使用PyArrow读取parquet文件的元数据,获取总行数
|
35 |
+
parquet_metadata = pq.read_metadata(parquet_file)
|
36 |
+
total_rows = parquet_metadata.num_rows
|
37 |
+
batch_size = 1000
|
38 |
+
|
39 |
+
# 检查是否有已经处理过的文件,计算已处理的行数
|
40 |
+
processed_rows = 0
|
41 |
+
if os.path.exists(output_file):
|
42 |
+
with open(output_file, 'r', encoding='utf-8') as f_check:
|
43 |
+
for _ in f_check:
|
44 |
+
processed_rows += 1
|
45 |
+
print(f"Found existing file {output_file} with {processed_rows} processed rows")
|
46 |
+
|
47 |
+
# 如果已经处理完所有行,跳过此文件
|
48 |
+
if processed_rows >= total_rows:
|
49 |
+
return f"Skipped {parquet_file}: all {total_rows} rows already processed"
|
50 |
+
|
51 |
+
# 逐批处理数据,以追加方式打开输出文件
|
52 |
+
with open(output_file, 'a' if processed_rows > 0 else 'w', encoding='utf-8') as f_out:
|
53 |
+
pf = pq.ParquetFile(parquet_file)
|
54 |
+
progress = tqdm(total=total_rows, initial=processed_rows,
|
55 |
+
desc=f"Processing {os.path.basename(parquet_file)}",
|
56 |
+
position=multiprocessing.current_process()._identity[0] % 10)
|
57 |
+
|
58 |
+
skip_rows = processed_rows
|
59 |
+
current_row = 0
|
60 |
+
|
61 |
+
for batch in pf.iter_batches(batch_size=batch_size):
|
62 |
+
df_batch = batch.to_pandas()
|
63 |
+
|
64 |
+
# 处理当前批次中的每一行
|
65 |
+
for _, row in df_batch.iterrows():
|
66 |
+
current_row += 1
|
67 |
+
|
68 |
+
# 跳过已处理的行
|
69 |
+
if current_row <= skip_rows:
|
70 |
+
continue
|
71 |
+
|
72 |
+
audio_obj = row['audio']
|
73 |
+
audio_data = audio_obj['bytes']
|
74 |
+
transcription = row['transcription']
|
75 |
+
language = row['language']
|
76 |
+
speaker = row['speaker']
|
77 |
+
|
78 |
+
with io.BytesIO(audio_data) as buffer:
|
79 |
+
prompt_data, sample_rate = sf.read(buffer)
|
80 |
+
# 确保是单声道,并转换为float32
|
81 |
+
if len(prompt_data.shape) > 1:
|
82 |
+
prompt_data = prompt_data[:, 0]
|
83 |
+
prompt_data = prompt_data.astype(np.float32)
|
84 |
+
|
85 |
+
# 重采样到16kHz (如果需要)
|
86 |
+
if sample_rate != 16000:
|
87 |
+
prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)
|
88 |
+
|
89 |
+
prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
|
90 |
+
|
91 |
+
feat = whisper.log_mel_spectrogram(prompt_speech_16k, n_mels=128)
|
92 |
+
speech_token = speech_tokenizer_session.run(None,
|
93 |
+
{speech_tokenizer_session.get_inputs()[0].name:
|
94 |
+
feat.detach().cpu().numpy(),
|
95 |
+
speech_tokenizer_session.get_inputs()[1].name:
|
96 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
97 |
+
|
98 |
+
# 写入结果
|
99 |
+
f_out.write(json.dumps({'tts_speech_tokens':speech_token,
|
100 |
+
'text':transcription,
|
101 |
+
'language':language,
|
102 |
+
'speaker':speaker,
|
103 |
+
"prompt_text":"",
|
104 |
+
"llm_prompt_speech_token":[]},
|
105 |
+
ensure_ascii=False)+'\n')
|
106 |
+
progress.update(1)
|
107 |
+
|
108 |
+
# 释放内存
|
109 |
+
del df_batch
|
110 |
+
import gc
|
111 |
+
gc.collect()
|
112 |
+
|
113 |
+
return f"Successfully processed {parquet_file}: {total_rows-processed_rows} new rows processed"
|
114 |
+
except Exception as e:
|
115 |
+
return f"Error processing {parquet_file}: {str(e)}"
|
116 |
+
|
117 |
+
if __name__ == '__main__':
|
118 |
+
import argparse
|
119 |
+
parser = argparse.ArgumentParser()
|
120 |
+
parser.add_argument('--data_path', type=str, default='/external_data/yueyudata/starrail-voice')
|
121 |
+
parser.add_argument('--output_path',type=str,default='/external_data/yueyudata/starrail-voice-voice_tokens')
|
122 |
+
parser.add_argument('--speech_tokenizer_model',type=str,default='/external_data/models/CosyVoice2-0.5B_RWKV_1.5B/speech_tokenizer_v2.onnx')
|
123 |
+
parser.add_argument('--device',type=str,default='cuda:0')
|
124 |
+
parser.add_argument('--num_processes',type=int,default=4)
|
125 |
+
args = parser.parse_args()
|
126 |
+
|
127 |
+
data_path = args.data_path
|
128 |
+
output_path = args.output_path
|
129 |
+
device = args.device
|
130 |
+
speech_tokenizer_model = args.speech_tokenizer_model
|
131 |
+
num_processes = args.num_processes
|
132 |
+
|
133 |
+
# 确保输出目录存在
|
134 |
+
os.makedirs(output_path, exist_ok=True)
|
135 |
+
|
136 |
+
# 找到所有parquet文件
|
137 |
+
parquet_files = []
|
138 |
+
for root, dirs, files in os.walk(data_path):
|
139 |
+
for file in files:
|
140 |
+
if file.endswith('.parquet'):
|
141 |
+
parquet_files.append(os.path.join(root, file))
|
142 |
+
print(f'Found {len(parquet_files)} parquet files in {data_path}')
|
143 |
+
|
144 |
+
# 准备多进程参数
|
145 |
+
file_info_list = [(file, output_path, speech_tokenizer_model, device) for file in parquet_files]
|
146 |
+
|
147 |
+
# 使用进程池处理文件
|
148 |
+
print(f"Starting processing with {num_processes} processes")
|
149 |
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
150 |
+
results = pool.map(process_file, file_info_list)
|
151 |
+
|
152 |
+
# 输出处理结果
|
153 |
+
for result in results:
|
154 |
+
print(result)
|
155 |
+
|
156 |
+
print("All files processed successfully!")
|
data/utils/llm_dataset.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
random.seed(time.time())
|
8 |
+
import logging
|
9 |
+
from tqdm import tqdm
|
10 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
11 |
+
|
12 |
+
def verify_jsonl_files(data_files):
|
13 |
+
"""检查每个 jsonl 文件的有效性"""
|
14 |
+
invalid_files = []
|
15 |
+
|
16 |
+
for file_path in tqdm(data_files, desc="验证文件"):
|
17 |
+
try:
|
18 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
19 |
+
for i, line in enumerate(f):
|
20 |
+
try:
|
21 |
+
json.loads(line)
|
22 |
+
except json.JSONDecodeError:
|
23 |
+
invalid_files.append((file_path, i+1))
|
24 |
+
logging.error(f"文件 {file_path} 在第 {i+1} 行有无效的 JSON")
|
25 |
+
break
|
26 |
+
except Exception as e:
|
27 |
+
invalid_files.append((file_path, f"读取错误: {str(e)}"))
|
28 |
+
logging.error(f"无法读取文件 {file_path}: {str(e)}")
|
29 |
+
|
30 |
+
return invalid_files
|
31 |
+
def load_jsonl_dataset(directory,tokenizer):
|
32 |
+
'''
|
33 |
+
load jsonl files in a directory recursively
|
34 |
+
'''
|
35 |
+
data_files = []
|
36 |
+
for root, dirs, files in os.walk(directory):
|
37 |
+
for file in files:
|
38 |
+
if file.endswith('.jsonl'):
|
39 |
+
data_files.append(os.path.join(root, file))
|
40 |
+
|
41 |
+
logging.info(f"找到 {len(data_files)} 个 JSONL 文件")
|
42 |
+
# 验证文件
|
43 |
+
invalid_files = verify_jsonl_files(data_files)
|
44 |
+
if invalid_files:
|
45 |
+
logging.error(f"发现 {len(invalid_files)} 个无效文件:")
|
46 |
+
for file_info in invalid_files:
|
47 |
+
if isinstance(file_info[1], int):
|
48 |
+
logging.error(f" - {file_info[0]} (错误在第 {file_info[1]} 行)")
|
49 |
+
else:
|
50 |
+
logging.error(f" - {file_info[0]} ({file_info[1]})")
|
51 |
+
|
52 |
+
# 移除无效文件
|
53 |
+
valid_files = [f for f in data_files if f not in [info[0] for info in invalid_files]]
|
54 |
+
logging.info(f"继续处理剩余的 {len(valid_files)} 个有效文件")
|
55 |
+
data_files = valid_files
|
56 |
+
# 手动收集所有样本,确保特征一致性
|
57 |
+
all_samples = []
|
58 |
+
|
59 |
+
for file_path in tqdm(data_files, desc="加载数据集"):
|
60 |
+
try:
|
61 |
+
# 手动解析JSONL文件,避免datasets加载时的类型推断问题
|
62 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
63 |
+
for line in f:
|
64 |
+
try:
|
65 |
+
data = json.loads(line)
|
66 |
+
# 确保所有字段存在且类型一致
|
67 |
+
llm_prompt_speech_token = data.get('llm_prompt_speech_token', [])
|
68 |
+
tts_speech_tokens = data.get('tts_speech_tokens', [])
|
69 |
+
text = str(data.get('text', ""))
|
70 |
+
prompt_text = str(data.get('prompt_text', ""))
|
71 |
+
|
72 |
+
# 确保列表类型
|
73 |
+
if not isinstance(llm_prompt_speech_token, list):
|
74 |
+
llm_prompt_speech_token = []
|
75 |
+
if not isinstance(tts_speech_tokens, list):
|
76 |
+
tts_speech_tokens = []
|
77 |
+
|
78 |
+
# 添加处理后的样本
|
79 |
+
all_samples.append({
|
80 |
+
'llm_prompt_speech_token': llm_prompt_speech_token,
|
81 |
+
'tts_speech_tokens': tts_speech_tokens,
|
82 |
+
'text': text,
|
83 |
+
'prompt_text': prompt_text
|
84 |
+
})
|
85 |
+
except json.JSONDecodeError:
|
86 |
+
continue # 跳过无效的JSON行
|
87 |
+
except Exception as e:
|
88 |
+
logging.error(f"处理样本时出错: {str(e)}")
|
89 |
+
except Exception as e:
|
90 |
+
logging.error(f"打开文件 {file_path} 时出错: {str(e)}")
|
91 |
+
|
92 |
+
if not all_samples:
|
93 |
+
raise ValueError("没有成功加载任何样本")
|
94 |
+
|
95 |
+
# 创建数据集
|
96 |
+
logging.info(f"手动创建数据集,包含 {len(all_samples)} 个样本")
|
97 |
+
dataset = datasets.Dataset.from_list(all_samples)
|
98 |
+
|
99 |
+
logging.info(f"成功加载 {len(dataset)} 个样本")
|
100 |
+
|
101 |
+
#1. concatenate llm_prompt_speech_token and tts_speech_tokens (list of int)
|
102 |
+
#delay the concatenation to collate_fn since sometimes we want to drop the prompt
|
103 |
+
# dataset = dataset.map(lambda x: {'speech_token': x['llm_prompt_speech_token'] + x['tts_speech_tokens']},remove_columns=['tts_speech_tokens','llm_prompt_speech_token'])
|
104 |
+
#2. Filter the data either :
|
105 |
+
# 1. the length of the speech_token is less than 1
|
106 |
+
# 2. the length of the speech_token is greater than 1000
|
107 |
+
# 3. the length of the text is greater than 500
|
108 |
+
# 4. the length of the prompt_text is greater than 500
|
109 |
+
# 5. the length of the text_token is less than 1
|
110 |
+
# 6. the length of the prompt_text_token is less than 1
|
111 |
+
dataset = dataset.filter(lambda x:len(x['llm_prompt_speech_token']) < 2048 and len(x['tts_speech_tokens']) < 2048
|
112 |
+
and len(tokenizer.encode(x['text'])) < 2048 and len(tokenizer.encode(x['prompt_text'])) < 2048 )
|
113 |
+
logging.info(f"过滤后剩余 {len(dataset)} 个样本")
|
114 |
+
#2. tokenize the text to text_tokens and prompt_text to prompt_text_tokens
|
115 |
+
# dataset = dataset.map(lambda x: {'text_tokens': tokenizer.encode(x['text']), 'prompt_text_tokens': tokenizer.encode(x['prompt_text'])},remove_columns=['text','prompt_text'])
|
116 |
+
return dataset
|
117 |
+
|
118 |
+
def collate_fn(batch, tokenizer, pad_to_max_length=True, max_length=2048, drop_prompt_audio_rate=-0.1):
|
119 |
+
'''
|
120 |
+
convert the data to torch tensors
|
121 |
+
1. call tokenizer.encode('text') and tokenizer.encode('prompt_text'), concatenate them to get the text_token, record each sample's length to text_token_len
|
122 |
+
2. convert the text_tokens and text_token_len to torch tensor
|
123 |
+
3. record each sample's speech_token length to speech_token_len
|
124 |
+
4. convert the speech_token and speech_token_len to torch tensor
|
125 |
+
5. We will drop prompt with drop_prompt_audio_rate to ask model to learn generate audio without guaidance
|
126 |
+
By default we won't drop anything
|
127 |
+
'''
|
128 |
+
all_text_tokens = []
|
129 |
+
all_speech_tokens = []
|
130 |
+
speech_token_len = []
|
131 |
+
text_token_len = []
|
132 |
+
my_max_length = 0
|
133 |
+
is_drop_prompt = random.random() < drop_prompt_audio_rate
|
134 |
+
|
135 |
+
for sample in batch:
|
136 |
+
tts_speech_tokens = sample['tts_speech_tokens']
|
137 |
+
llm_prompt_speech_token = sample['llm_prompt_speech_token']
|
138 |
+
|
139 |
+
if is_drop_prompt:
|
140 |
+
# 只使用文本部分,不使用提示
|
141 |
+
text_tokens = tokenizer.encode(sample['text'])
|
142 |
+
all_text_tokens.append(torch.tensor(text_tokens, dtype=torch.int32))
|
143 |
+
text_token_len.append(len(text_tokens))
|
144 |
+
|
145 |
+
# 只使用语音部分,不使用提示语音
|
146 |
+
current_speech_tokens = tts_speech_tokens
|
147 |
+
all_speech_tokens.append(torch.tensor(current_speech_tokens, dtype=torch.int32))
|
148 |
+
speech_token_len.append(len(current_speech_tokens))
|
149 |
+
|
150 |
+
total_length = len(text_tokens) + len(current_speech_tokens)
|
151 |
+
else:
|
152 |
+
# 使用提示+文本
|
153 |
+
text_tokens = tokenizer.encode(sample['text'])
|
154 |
+
prompt_tokens = tokenizer.encode(sample['prompt_text'])
|
155 |
+
combined_text_tokens = prompt_tokens + text_tokens
|
156 |
+
all_text_tokens.append(torch.tensor(combined_text_tokens, dtype=torch.int32))
|
157 |
+
text_token_len.append(len(combined_text_tokens))
|
158 |
+
|
159 |
+
# 使用提示语音+语音
|
160 |
+
current_speech_tokens = llm_prompt_speech_token + tts_speech_tokens
|
161 |
+
all_speech_tokens.append(torch.tensor(current_speech_tokens, dtype=torch.int32))
|
162 |
+
speech_token_len.append(len(current_speech_tokens))
|
163 |
+
|
164 |
+
total_length = len(combined_text_tokens) + len(current_speech_tokens)
|
165 |
+
|
166 |
+
if total_length > my_max_length:
|
167 |
+
my_max_length = total_length
|
168 |
+
|
169 |
+
# 检查长度是否超出最大长度
|
170 |
+
skip = my_max_length > max_length
|
171 |
+
|
172 |
+
# 将列表转换为填充后的张量
|
173 |
+
all_text_tokens = torch.nn.utils.rnn.pad_sequence(all_text_tokens, batch_first=True, padding_value=0)
|
174 |
+
all_speech_tokens = torch.nn.utils.rnn.pad_sequence(all_speech_tokens, batch_first=True, padding_value=0)
|
175 |
+
|
176 |
+
# 如果需要填充到最大长度
|
177 |
+
if pad_to_max_length and not skip:
|
178 |
+
pad_length = max_length - my_max_length
|
179 |
+
if pad_length > 0:
|
180 |
+
all_speech_tokens = torch.nn.functional.pad(all_speech_tokens, (0, pad_length), value=0)
|
181 |
+
|
182 |
+
return {
|
183 |
+
'text_token': all_text_tokens,
|
184 |
+
'text_token_len': torch.tensor(text_token_len, dtype=torch.int32),
|
185 |
+
'speech_token': all_speech_tokens, # 确保命名一致
|
186 |
+
'speech_token_len': torch.tensor(speech_token_len, dtype=torch.int32),
|
187 |
+
'skip': skip
|
188 |
+
}
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == '__main__':
|
192 |
+
from transformers import AutoTokenizer
|
193 |
+
model_path = "/external_data/models/rwkv7-2.9B-world"
|
194 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
195 |
+
directory = '/external_data/yueyudata/speech_corpus'
|
196 |
+
dataset = load_jsonl_dataset(directory,tokenizer)
|
197 |
+
print(dataset)
|
198 |
+
print(dataset[0])
|
199 |
+
from functools import partial
|
200 |
+
collate_fn = partial(collate_fn,tokenizer=tokenizer,pad_to_max_length=False)
|
201 |
+
dataloader = torch.utils.data.DataLoader(dataset,batch_size=1,collate_fn=collate_fn)
|
202 |
+
for data in dataloader:
|
203 |
+
print(data)
|
204 |
+
print(data['speech_token'].shape)
|
205 |
+
print(data['text_token'].shape)
|
206 |
+
break
|
data/utils/test_utilities.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.utils.utilitie import generate_mixed_instructions
|
2 |
+
if __name__ == '__main__':
|
3 |
+
print(generate_mixed_instructions('我来自中国。'))
|
4 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
5 |
+
print(generate_mixed_instructions('I am from China.',language='en'))
|
6 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
7 |
+
print(generate_mixed_instructions('我来自中国。'))
|
8 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
9 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
10 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
11 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
12 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
13 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
14 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
15 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
16 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
17 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
18 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
19 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
20 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
21 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
22 |
+
print(generate_mixed_instructions('这是一个拥有悠久历史的城市。'))
|
23 |
+
print(generate_mixed_instructions('I am from China.',language='en'))
|
24 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
25 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
26 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
27 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
28 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
29 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
30 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
31 |
+
print(generate_mixed_instructions('This is a city with a long history.',language='en'))
|
data/utils/utilitie.py
ADDED
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import thread
|
2 |
+
from operator import is_
|
3 |
+
from librosa import ex
|
4 |
+
from regex import P
|
5 |
+
from torch import device
|
6 |
+
from tqdm import tqdm
|
7 |
+
import tarfile
|
8 |
+
import random
|
9 |
+
import time
|
10 |
+
import io
|
11 |
+
import torchaudio
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import multiprocessing
|
15 |
+
import torch
|
16 |
+
from data.cosy.data.data_processor import init_process, preprocess_prompts
|
17 |
+
import random
|
18 |
+
from typing import List
|
19 |
+
import torch
|
20 |
+
import torchaudio
|
21 |
+
import io
|
22 |
+
|
23 |
+
'''
|
24 |
+
Natural Language Instruction
|
25 |
+
Emotion: 高兴(Happy), 悲伤(Sad), 惊讶(Surprised), 愤怒(Angry), 恐惧(Fearful), 厌恶(Disgusted), 冷
|
26 |
+
静(Calm), 严肃(Serious)
|
27 |
+
Speaking Rate: 快速(Fast), 非常快速(Very Fast), 慢速(Slow), 非常慢速(Very Slow)
|
28 |
+
Dialect: 粤语, 四川话, 上海话, 郑州话, 长沙话, 天津话
|
29 |
+
Role-playing: 神秘(Mysterious), 凶猛(Fierce), 好奇(Curious), 优雅(Elegant), 孤独(Lonely), 机器
|
30 |
+
人(Robot), 小猪佩奇(Peppa), etc.
|
31 |
+
Fine-grained Instruction
|
32 |
+
Vocal Bursts: [laughter], [breath], etc.
|
33 |
+
Vocal Features: <laughter></laughter>, <strong></strong>
|
34 |
+
Examples
|
35 |
+
- 你能用高兴的情感说吗?< |endofprompt| >今天真是太开心了,马上要放假了!I’m so happy,
|
36 |
+
Spring Festival is coming!
|
37 |
+
- Please speaking very fast.< |endofprompt| >Today is a happy day, full of laughter and joy.
|
38 |
+
- 请问你能模仿粤语的口音吗?< |endofprompt| >多保重,早休息。
|
39 |
+
- 尝试一下以机器人的角色和我交流。< |endofprompt| >接收知识光波!
|
40 |
+
- [laughter]有时候,看着小孩子们的天真行为[laughter],我们总会会心一笑。
|
41 |
+
- She pursued her dreams with <strong>enthusiasm</strong> and <strong>grit</strong>.
|
42 |
+
'''
|
43 |
+
|
44 |
+
emotions = ['高兴', '悲伤', '惊讶', '愤怒', '恐惧', '厌恶', '冷静', '严肃']
|
45 |
+
emotions_in_english = ['Happy', 'Sad', 'Surprised', 'Angry', 'Fearful', 'Disgusted', 'Calm', 'Serious']
|
46 |
+
speaking_rates = ['快速', '非常快速', '慢速', '非常慢速']
|
47 |
+
speaking_rates_in_english = ['Fast', 'Very Fast', 'Slow', 'Very Slow']
|
48 |
+
dialects = ['普通话','粤语', '四川话', '上海话', '郑州话', '长沙话', '天津话']
|
49 |
+
dialects_in_english = ['Mandarin','Cantonese', 'Sichuanese', 'Shanghainese', 'Zhengzhou Dialect', 'Changsha Dialect', 'Tianjin Dialect']
|
50 |
+
role_playings = ['神秘', '凶猛', '好奇', '优雅', '孤独', '机器人', '小猪佩奇']
|
51 |
+
role_playings_in_english = ['Mysterious', 'Fierce', 'Curious', 'Elegant', 'Lonely', 'Robot', 'Peppa']
|
52 |
+
vocal_bursts = ['[laughter]', '[breath]']
|
53 |
+
vocal_features = ['<laughter></laughter>', '<strong></strong>']
|
54 |
+
end_of_prompt = '<|endofprompt|>'
|
55 |
+
|
56 |
+
def generate_in_emotion_in_chinese(text :str):
|
57 |
+
templates = [
|
58 |
+
'你能用{}的情感说吗?{}{}',
|
59 |
+
'请用{}的情感说。{}{}',
|
60 |
+
'请用{}的情感表达。{}{}',
|
61 |
+
'请用{}的情感说一下。{}{}',
|
62 |
+
'请用{}的情感说一句。{}{}'
|
63 |
+
]
|
64 |
+
select_emotion = random.choice(emotions)
|
65 |
+
return random.choice(templates).format(select_emotion,end_of_prompt,text)
|
66 |
+
|
67 |
+
def generate_in_emotion_in_english(text :str):
|
68 |
+
templates = [
|
69 |
+
'Can you say it with {} emotion?{}{}',
|
70 |
+
'Please say it with {} emotion.{}{}',
|
71 |
+
'Please express it with {} emotion.{}{}',
|
72 |
+
'Please say it with {} emotion.{}{}',
|
73 |
+
'Please say a sentence with {} emotion.{}{}'
|
74 |
+
]
|
75 |
+
select_emotion = random.choice(emotions_in_english)
|
76 |
+
return random.choice(templates).format(select_emotion,end_of_prompt,text)
|
77 |
+
|
78 |
+
def generate_speaking_rate_in_chinese(text :str):
|
79 |
+
templates = [
|
80 |
+
'请用{}的语速说。{}{}',
|
81 |
+
'请用{}的语速说一下。{}{}',
|
82 |
+
'请用{}的语速说一句。{}{}',
|
83 |
+
'请用{}的语速表达。{}{}',
|
84 |
+
'请用{}的语速说。{}{}',
|
85 |
+
'请{}地说。{}{}',
|
86 |
+
'请{}地说一下。{}{}',
|
87 |
+
'请{}地说一句。{}{}',
|
88 |
+
'{}的说。{}{}',
|
89 |
+
'{}的说一下。{}{}',
|
90 |
+
'{}的说一句。{}{}',
|
91 |
+
'{}的表达。{}{}'
|
92 |
+
|
93 |
+
]
|
94 |
+
select_rate = random.choice(speaking_rates)
|
95 |
+
template = random.choice(templates)
|
96 |
+
return template.format(select_rate,end_of_prompt,text)
|
97 |
+
|
98 |
+
def generate_speaking_rate_in_english(text :str):
|
99 |
+
templates = [
|
100 |
+
'Please say it with {} speaking rate.{}{}',
|
101 |
+
'Say it with {} speaking rate.{}{}',
|
102 |
+
'Please say a sentence with {} speaking rate.{}{}',
|
103 |
+
'Please express it with {} speaking rate.{}{}',
|
104 |
+
'Please speak {}ly.{}{}',
|
105 |
+
'Speak {}ly.{}{}',
|
106 |
+
'Please say it {}ly.{}{}',
|
107 |
+
'Say it {}ly.{}{}'
|
108 |
+
]
|
109 |
+
select_rate = random.choice(speaking_rates_in_english)
|
110 |
+
template = random.choice(templates)
|
111 |
+
return template.format(select_rate,end_of_prompt,text)
|
112 |
+
|
113 |
+
|
114 |
+
def load_file_list(tar_file):
|
115 |
+
#the files are FILE_NAME.mp3/FILE_NAME.json
|
116 |
+
#return all FILE_NAME as a list which has a mp3 and json
|
117 |
+
import tarfile
|
118 |
+
with tarfile.open(tar_file, 'r') as f:
|
119 |
+
file_names = f.getnames()
|
120 |
+
mp3_files = [i for i in file_names if i.endswith('.mp3')]
|
121 |
+
json_files = [i for i in file_names if i.endswith('.json')]
|
122 |
+
|
123 |
+
#filter mp3_files without corresponded json
|
124 |
+
mp3_files = [i for i in mp3_files if i.replace('.mp3', '.json') in json_files]
|
125 |
+
return mp3_files
|
126 |
+
|
127 |
+
def extract_prompt(input_tar_files, input_tar_languages, max_duration=5, num_samples=10, target_sr=16000, output_dir=None):
|
128 |
+
"""
|
129 |
+
Extract prompt from tar files
|
130 |
+
Args:
|
131 |
+
input_tar_files: list of str, input tar files
|
132 |
+
input_tar_languages: list of str, input tar languages for each tar file, must be the same length as input_tar_files
|
133 |
+
max_duration: float, max duration of audio
|
134 |
+
num_samples: int, number of samples to extract
|
135 |
+
target_sr: int, target sample rate
|
136 |
+
output_dir: str, output directory
|
137 |
+
"""
|
138 |
+
for tar_file, language in zip(input_tar_files, input_tar_languages):
|
139 |
+
print(f'Extracting prompt from {tar_file}...with language {language}')
|
140 |
+
random.seed(time.time())
|
141 |
+
samples = []
|
142 |
+
mp3_files = load_file_list(tar_file)
|
143 |
+
with tarfile.open(tar_file, 'r') as f:
|
144 |
+
progress_bar = tqdm(total=num_samples,desc=f'Extracting prompt from {tar_file}')
|
145 |
+
for i in random.sample(mp3_files, len(mp3_files)):
|
146 |
+
mp3 = f.extractfile(i)
|
147 |
+
mp3_bytes = io.BytesIO(mp3.read())
|
148 |
+
speech, sample_rate = torchaudio.load(mp3_bytes,backend='soundfile')
|
149 |
+
json_file = f.extractfile(i.replace('.mp3', '.json'))
|
150 |
+
json_data = json.load(json_file)
|
151 |
+
duration = json_data['duration']
|
152 |
+
if duration > max_duration:
|
153 |
+
continue
|
154 |
+
speech = speech.mean(dim=0, keepdim=True)
|
155 |
+
if sample_rate != target_sr:
|
156 |
+
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
157 |
+
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
158 |
+
samples.append((speech, json_data,sample_rate))
|
159 |
+
progress_bar.update(1)
|
160 |
+
if len(samples) == num_samples:
|
161 |
+
break
|
162 |
+
if output_dir is not None:
|
163 |
+
"""
|
164 |
+
json looks like:
|
165 |
+
{'id': 'ZH_B00000_S01450_W000017', 'wav': 'ZH_B00000/ZH_B00000_S01450/mp3/ZH_B00000_S01450_W000017.mp3', 'text': '因此,我们认为流通性具有更广泛的含义。', 'duration': 4.193, 'speaker': 'ZH_B00000_S01450', 'language': 'zh', 'dnsmos': 3.3709}
|
166 |
+
"""
|
167 |
+
output_dir_lang = os.path.join(output_dir, language)
|
168 |
+
os.makedirs(output_dir_lang, exist_ok=True)
|
169 |
+
progress_bar = tqdm(total=len(samples), desc=f'Saving samples to {output_dir_lang}')
|
170 |
+
for i, (speech, json_data, sample_rate) in enumerate(samples):
|
171 |
+
id = json_data['id']
|
172 |
+
wave_file = os.path.join(output_dir_lang, f'{id}.wav')
|
173 |
+
json_file = os.path.join(output_dir_lang, f'{id}.json')
|
174 |
+
torchaudio.save(wave_file, speech, target_sr)
|
175 |
+
with open(json_file, 'w') as f:
|
176 |
+
json.dump(json_data, f,ensure_ascii=False)
|
177 |
+
progress_bar.update(1)
|
178 |
+
print(f'Extracted {len(samples)} samples from {tar_file} with language {language}')
|
179 |
+
|
180 |
+
def generate_dialect_in_chinese(text: str):
|
181 |
+
templates = [
|
182 |
+
'请问你能模仿{}的口音吗?{}{}',
|
183 |
+
'请用{}的口音说一下。{}{}',
|
184 |
+
'用{}的口音说一句。{}{}',
|
185 |
+
'能用{}的口音读一下吗?{}{}',
|
186 |
+
'请尝试用{}的口音说这段话。{}{}',
|
187 |
+
'请以{}的口音表达。{}{}',
|
188 |
+
'请用{}的语调说。{}{}',
|
189 |
+
'试试用{}的方言说。{}{}',
|
190 |
+
'能否用{}的语调读出来?{}{}',
|
191 |
+
'请说一段{}。{}{}'
|
192 |
+
]
|
193 |
+
select_dialect = random.choice(dialects)
|
194 |
+
return random.choice(templates).format(select_dialect, end_of_prompt, text)
|
195 |
+
|
196 |
+
def generate_dialect_in_english(text: str):
|
197 |
+
templates = [
|
198 |
+
'Can you mimic the {} accent?{}{}',
|
199 |
+
'Please speak with a {} accent.{}{}',
|
200 |
+
'Say it with a {} accent.{}{}',
|
201 |
+
'Could you read this with a {} accent?{}{}',
|
202 |
+
'Please try to speak this with a {} accent.{}{}',
|
203 |
+
'Please express it with a {} accent.{}{}',
|
204 |
+
'Please use {} intonation.{}{}',
|
205 |
+
'Try speaking in {}.{}{}',
|
206 |
+
'Could you read this in {}?{}{}',
|
207 |
+
'Please say a passage in {}.{}{}'
|
208 |
+
]
|
209 |
+
select_dialect = random.choice(dialects_in_english)
|
210 |
+
return random.choice(templates).format(select_dialect, end_of_prompt, text)
|
211 |
+
|
212 |
+
def generate_role_playing_in_chinese(text: str):
|
213 |
+
templates = [
|
214 |
+
'尝试一下以{}的角色和我交流。{}{}',
|
215 |
+
'请以{}的角色说这句话。{}{}',
|
216 |
+
'假装你是{},说一下这句话。{}{}',
|
217 |
+
'扮演{}来说这段话。{}{}',
|
218 |
+
'请用{}的语气说。{}{}',
|
219 |
+
'以{}的形象来表达。{}{}',
|
220 |
+
'你能用{}的方式说吗?{}{}',
|
221 |
+
'模仿{}说话。{}{}',
|
222 |
+
'请���{}的口吻说一下。{}{}',
|
223 |
+
'像{}一样说这句话。{}{}'
|
224 |
+
]
|
225 |
+
select_role = random.choice(role_playings)
|
226 |
+
return random.choice(templates).format(select_role, end_of_prompt, text)
|
227 |
+
|
228 |
+
def generate_role_playing_in_english(text: str):
|
229 |
+
templates = [
|
230 |
+
'Try to communicate with me as a {} character.{}{}',
|
231 |
+
'Please say this as a {} character.{}{}',
|
232 |
+
'Pretend you are {}, say this sentence.{}{}',
|
233 |
+
'Act as {} to say this passage.{}{}',
|
234 |
+
'Please speak with a {} tone.{}{}',
|
235 |
+
'Express this with a {} image.{}{}',
|
236 |
+
'Can you say this in a {} way?{}{}',
|
237 |
+
'Mimic {} speaking.{}{}',
|
238 |
+
'Please say this in the manner of {}.{}{}',
|
239 |
+
'Say this like {}.{}{}'
|
240 |
+
]
|
241 |
+
select_role = random.choice(role_playings_in_english)
|
242 |
+
return random.choice(templates).format(select_role, end_of_prompt, text)
|
243 |
+
|
244 |
+
def generate_vocal_bursts(text: str):
|
245 |
+
"""
|
246 |
+
在文本中随机添加声音爆发标记,如[laughter]、[breath]等
|
247 |
+
"""
|
248 |
+
templates = [
|
249 |
+
'{}{}', # 在句首添加
|
250 |
+
'{}{}{}', # 在句中添加
|
251 |
+
'{}{}' # 在句末添加
|
252 |
+
]
|
253 |
+
|
254 |
+
burst = random.choice(vocal_bursts)
|
255 |
+
template_choice = random.choice(templates)
|
256 |
+
|
257 |
+
if template_choice == '{}{}': # 句首
|
258 |
+
return burst + text
|
259 |
+
elif template_choice == '{}{}{}': # 句中
|
260 |
+
words = text.split()
|
261 |
+
if len(words) <= 3: # 文本太短不分割
|
262 |
+
return burst + text
|
263 |
+
split_point = random.randint(1, len(words) - 1)
|
264 |
+
return ' '.join(words[:split_point]) + ' ' + burst + ' ' + ' '.join(words[split_point:])
|
265 |
+
else: # 句末
|
266 |
+
return text + ' ' + burst
|
267 |
+
|
268 |
+
def generate_vocal_features(text: str):
|
269 |
+
"""
|
270 |
+
在文本中随机添加声音特征标记,如<laughter></laughter>、<strong></strong>等
|
271 |
+
支持中文和英文文本
|
272 |
+
"""
|
273 |
+
feature = random.choice(vocal_features)
|
274 |
+
feature_start, feature_end = feature.split('><')
|
275 |
+
feature_start += '>'
|
276 |
+
feature_end = '<' + feature_end
|
277 |
+
|
278 |
+
# 检查是否为中文文本
|
279 |
+
has_chinese = any('\u4e00' <= char <= '\u9fff' for char in text)
|
280 |
+
|
281 |
+
if has_chinese:
|
282 |
+
# 处理中文文本
|
283 |
+
if len(text) <= 10: # 文本太短,整个加强
|
284 |
+
return feature_start + text + feature_end
|
285 |
+
|
286 |
+
# 对中文处理,随机选择一个字符范围
|
287 |
+
text_len = len(text)
|
288 |
+
# 随机选择一个起始位置和一个范围长度
|
289 |
+
start_pos = random.randint(1, max(1, text_len // 2)) # 避免总是从句首开始
|
290 |
+
span_length = random.randint(1, min(5, text_len - start_pos))
|
291 |
+
end_pos = start_pos + span_length - 1
|
292 |
+
|
293 |
+
# 在选定位置插入标记
|
294 |
+
result = text[:start_pos] + feature_start + text[start_pos:end_pos+1] + feature_end + text[end_pos+1:]
|
295 |
+
return result
|
296 |
+
else:
|
297 |
+
# 处理英文文本
|
298 |
+
words = text.split()
|
299 |
+
if len(words) <= 3: # 文本太短,整个加强
|
300 |
+
return feature_start + text + feature_end
|
301 |
+
|
302 |
+
# 随机选择一个词或短语来添加特征
|
303 |
+
start_idx = random.randint(0, len(words) - 1)
|
304 |
+
span_length = random.randint(1, min(3, len(words) - start_idx)) # 最多3个词
|
305 |
+
|
306 |
+
result = []
|
307 |
+
for i, word in enumerate(words):
|
308 |
+
if i == start_idx:
|
309 |
+
result.append(feature_start + word)
|
310 |
+
elif i == start_idx + span_length - 1:
|
311 |
+
result.append(word + feature_end)
|
312 |
+
else:
|
313 |
+
result.append(word)
|
314 |
+
|
315 |
+
return ' '.join(result)
|
316 |
+
|
317 |
+
def generate_mixed_instructions(text: str, language="zh"):
|
318 |
+
"""
|
319 |
+
混合多种指令类型,可以同时包含情感、语速、方言、角色扮演等
|
320 |
+
"""
|
321 |
+
instruction_generators = []
|
322 |
+
|
323 |
+
if language == "zh":
|
324 |
+
instruction_generators = [
|
325 |
+
generate_in_emotion_in_chinese,
|
326 |
+
generate_speaking_rate_in_chinese,
|
327 |
+
generate_dialect_in_chinese,
|
328 |
+
generate_role_playing_in_chinese
|
329 |
+
]
|
330 |
+
else: # 英文
|
331 |
+
instruction_generators = [
|
332 |
+
generate_in_emotion_in_english,
|
333 |
+
generate_speaking_rate_in_english,
|
334 |
+
generate_dialect_in_english,
|
335 |
+
generate_role_playing_in_english
|
336 |
+
]
|
337 |
+
|
338 |
+
# 随机选择1个generator
|
339 |
+
selected_generator = random.choice(instruction_generators)
|
340 |
+
|
341 |
+
# 可能会添加声音特征
|
342 |
+
text_with_features = text
|
343 |
+
if random.random() < 0.3: # 30%的概率添加声音特征
|
344 |
+
text_with_features = generate_vocal_features(text)
|
345 |
+
|
346 |
+
# 可能会添加声音爆发
|
347 |
+
if random.random() < 0.2: # 20%的概率添加声音爆发
|
348 |
+
text_with_features = generate_vocal_bursts(text_with_features)
|
349 |
+
|
350 |
+
# 应用选择的指令生成器
|
351 |
+
result = text_with_features
|
352 |
+
result = selected_generator(result)
|
353 |
+
|
354 |
+
return result
|
355 |
+
|
356 |
+
frontend = None
|
357 |
+
llm = None
|
358 |
+
cosyvoice = None
|
359 |
+
output_fp = None
|
360 |
+
prompts = None
|
361 |
+
global_device = None
|
362 |
+
processed_count = 0
|
363 |
+
def initialize_process(model_dir,prompts_dir,output_dir,device):
|
364 |
+
current_process = multiprocessing.current_process()
|
365 |
+
file_name = f'{output_dir}/{current_process.pid}.jsonl'
|
366 |
+
global frontend,llm,cosyvoice,output_fp,prompts,global_device
|
367 |
+
global_device = device
|
368 |
+
output_fp = open(file_name, 'w')
|
369 |
+
print(f'Initializing process with device {device} and output file {file_name}')
|
370 |
+
frontend,llm,cosyvoice = init_process(model_dir,device)
|
371 |
+
prompts = preprocess_prompts(frontend,prompts_dir)
|
372 |
+
print(f'load prompts {prompts.keys()}')
|
373 |
+
return frontend,llm,cosyvoice
|
374 |
+
|
375 |
+
def generate_speech_tokens(llm,frontend,tts_text,model_input,device):
|
376 |
+
tts_text = frontend.text_normalize(tts_text,split=False, text_frontend=True)
|
377 |
+
tts_text_token, tts_text_token_len = frontend._extract_text_token(tts_text)
|
378 |
+
tts_text_token_len = torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to(device)
|
379 |
+
prompt_text = model_input['prompt_text'].to(device) if 'prompt_text' in model_input else torch.zeros(1, 0, dtype=torch.int32).to(device)
|
380 |
+
prompt_text_len = torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(device) if prompt_text is not None else torch.zeros(1, 0, dtype=torch.int32).to(device)
|
381 |
+
llm_prompt_speech_token = model_input['llm_prompt_speech_token'].to(device) if 'llm_prompt_speech_token' in model_input else torch.zeros(1, 0, dtype=torch.int32).to(device)
|
382 |
+
prompt_speech_token_len = torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(device) if llm_prompt_speech_token is not None else None
|
383 |
+
flow_prompt_speech_token = model_input['flow_prompt_speech_token'].to(device)
|
384 |
+
prompt_speech_feat = model_input['prompt_speech_feat'].to(device)
|
385 |
+
llm_embedding = model_input['llm_embedding'].to(device)
|
386 |
+
flow_embedding = model_input['flow_embedding'].to(device)
|
387 |
+
speech_tokens = []
|
388 |
+
with torch.no_grad():
|
389 |
+
for i in llm.inference(text = tts_text_token,
|
390 |
+
text_len = tts_text_token_len,
|
391 |
+
prompt_text = prompt_text,
|
392 |
+
prompt_text_len = prompt_text_len,
|
393 |
+
prompt_speech_token = llm_prompt_speech_token,
|
394 |
+
prompt_speech_token_len = prompt_speech_token_len,
|
395 |
+
embedding=llm_embedding
|
396 |
+
):
|
397 |
+
speech_tokens.append(i)
|
398 |
+
return speech_tokens
|
399 |
+
|
400 |
+
def process_text(text,language):
|
401 |
+
global frontend,llm,cosyvoice,output_fp,prompts,processed_count,global_device
|
402 |
+
processed_count += 1
|
403 |
+
if processed_count % 100 == 0:
|
404 |
+
print(f'Processed {processed_count} samples')
|
405 |
+
tts_text = text
|
406 |
+
splits_txt_by_lines = tts_text.split('\n')
|
407 |
+
#remove the sentences with length less than 10
|
408 |
+
splits_txt_by_lines = [i.strip() for i in splits_txt_by_lines if len(i.strip()) > 10]
|
409 |
+
random.seed(time.time())
|
410 |
+
model_input,prompt_text = random.choice(prompts[language])
|
411 |
+
llm_prompt_speech_token = model_input['llm_prompt_speech_token'].cpu().tolist()
|
412 |
+
for tts_text in splits_txt_by_lines:
|
413 |
+
tts_speech_tokens = generate_speech_tokens(llm,frontend,tts_text,model_input,cosyvoice.device)
|
414 |
+
output_data = {
|
415 |
+
'text': tts_text,
|
416 |
+
'tts_speech_tokens': tts_speech_tokens,
|
417 |
+
'prompt_text': prompt_text,
|
418 |
+
'llm_prompt_speech_token': llm_prompt_speech_token[0]
|
419 |
+
}
|
420 |
+
output_fp.write(json.dumps(output_data,ensure_ascii=False)+'\n')
|
421 |
+
output_fp.flush()
|
422 |
+
return processed_count
|
423 |
+
def process_jsonl_file(jsonl_file,language,process_pool):
|
424 |
+
print(f'Processing {jsonl_file}...')
|
425 |
+
count = 0
|
426 |
+
import json
|
427 |
+
with open(jsonl_file, 'r') as f:
|
428 |
+
for line in f:
|
429 |
+
line = line.strip()
|
430 |
+
if len(line) == 0:
|
431 |
+
continue
|
432 |
+
data = json.loads(line)
|
433 |
+
text = data['text']
|
434 |
+
count += 1
|
435 |
+
future = process_pool.submit(process_text,text,language)
|
436 |
+
print(f'processed {future.result()} requests')
|
437 |
+
print(f'Processed {count} samples from {jsonl_file}')
|
438 |
+
return count
|
439 |
+
|
440 |
+
def process_parquet_file(parquet_file,language,process_pool):
|
441 |
+
print(f'Processing {parquet_file}...')
|
442 |
+
import pandas as pd
|
443 |
+
df = pd.read_parquet(parquet_file)
|
444 |
+
count = 0
|
445 |
+
for i in range(len(df)):
|
446 |
+
text = df.iloc[i]['text']
|
447 |
+
count += 1
|
448 |
+
future = process_pool.submit(process_text,text,language)
|
449 |
+
print(f'processed {future.result()} requests')
|
450 |
+
print(f'Processed {count} samples from {parquet_file}')
|
451 |
+
return count
|
452 |
+
|
453 |
+
def generate_speech_tokens_single_process(cosy_model_dir, prompts_dir, output_dir, language, jsonl_files=None, parquet_files=None, device="cuda:0",is_cross_lingual=False,is_instructed=False):
|
454 |
+
"""
|
455 |
+
单进程单线程版本的语音标记生成函数
|
456 |
+
"""
|
457 |
+
import torch
|
458 |
+
import json
|
459 |
+
import os
|
460 |
+
import random
|
461 |
+
import time
|
462 |
+
import traceback
|
463 |
+
import logging
|
464 |
+
import sys
|
465 |
+
from datetime import datetime
|
466 |
+
from data.cosy.data.data_processor import init_process, preprocess_prompts
|
467 |
+
|
468 |
+
# 设置日志
|
469 |
+
output_dir_lang = os.path.join(output_dir, language)
|
470 |
+
os.makedirs(output_dir_lang, exist_ok=True)
|
471 |
+
process_id = os.getpid()
|
472 |
+
log_file = os.path.join(output_dir_lang, f'process_{process_id}_log.txt')
|
473 |
+
|
474 |
+
# 配置日志输出到文件和控制台
|
475 |
+
logging.basicConfig(
|
476 |
+
level=logging.INFO,
|
477 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
478 |
+
handlers=[
|
479 |
+
logging.FileHandler(log_file),
|
480 |
+
logging.StreamHandler(sys.stdout)
|
481 |
+
]
|
482 |
+
)
|
483 |
+
logger = logging.getLogger(f'process_{process_id}')
|
484 |
+
|
485 |
+
# 记录启动信息
|
486 |
+
logger.info(f"='='='='='='='='='='='Instructed={is_instructed}'='='='='='='='='='='='='='='='='='")
|
487 |
+
logger.info(f"启动时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
488 |
+
logger.info(f"进程ID: {process_id}")
|
489 |
+
logger.info(f"设备: {device}")
|
490 |
+
logger.info(f"模型目录: {cosy_model_dir}")
|
491 |
+
logger.info(f"提示词目录: {prompts_dir}")
|
492 |
+
logger.info(f"输出目录: {output_dir_lang}")
|
493 |
+
if jsonl_files:
|
494 |
+
logger.info(f"JSONL文件: {jsonl_files}")
|
495 |
+
if parquet_files:
|
496 |
+
logger.info(f"Parquet文件: {parquet_files}")
|
497 |
+
logger.info(f"='='='='='='='='='='='='='='='='='='='='='='='='='='='='='")
|
498 |
+
|
499 |
+
output_fp = None
|
500 |
+
frontend = None
|
501 |
+
llm = None
|
502 |
+
cosyvoice = None
|
503 |
+
total_processed = 0
|
504 |
+
|
505 |
+
try:
|
506 |
+
# 初始化模型
|
507 |
+
logger.info(f'初始化模型,使用设备: {device}')
|
508 |
+
frontend, llm, cosyvoice = init_process(cosy_model_dir, device)
|
509 |
+
|
510 |
+
# 预处理提示
|
511 |
+
logger.info(f'开始预处理提示词')
|
512 |
+
prompts = preprocess_prompts(frontend, prompts_dir)
|
513 |
+
logger.info(f'加载提示完成: {prompts.keys()}')
|
514 |
+
|
515 |
+
output_file = os.path.join(output_dir_lang, f'{process_id}.jsonl')
|
516 |
+
output_fp = open(output_file, 'w')
|
517 |
+
|
518 |
+
# 处理函数
|
519 |
+
def process_single_text(text):
|
520 |
+
try:
|
521 |
+
tts_text = text
|
522 |
+
splits_txt_by_lines = tts_text.split('\n')
|
523 |
+
# 删除长度小于10的句子
|
524 |
+
splits_txt_by_lines = [i.strip() for i in splits_txt_by_lines if len(i.strip()) > 10]
|
525 |
+
|
526 |
+
if not splits_txt_by_lines:
|
527 |
+
logger.warning(f"文本没有有效句子: '{text[:100]}...'")
|
528 |
+
return 0
|
529 |
+
|
530 |
+
random.seed(time.time())
|
531 |
+
cross_linguals_map = {
|
532 |
+
'zh': 'en',
|
533 |
+
'en': 'zh'
|
534 |
+
}
|
535 |
+
try:
|
536 |
+
model_input, prompt_text = random.choice(prompts[language if not is_cross_lingual else cross_linguals_map[language]])
|
537 |
+
except KeyError:
|
538 |
+
logger.error(f"语言 '{language}' 在提示词中不存在! 可用语言: {list(prompts.keys())}")
|
539 |
+
return 0
|
540 |
+
|
541 |
+
llm_prompt_speech_token = model_input['llm_prompt_speech_token'].cpu().tolist() if 'llm_prompt_speech_token' in model_input else []
|
542 |
+
|
543 |
+
processed_count = 0
|
544 |
+
for tts_text in splits_txt_by_lines:
|
545 |
+
try:
|
546 |
+
if is_instructed:
|
547 |
+
tts_text = generate_mixed_instructions(tts_text, language)
|
548 |
+
prompt_text = ""
|
549 |
+
llm_prompt_speech_token[0]=[]
|
550 |
+
if 'prompt_text' in model_input:
|
551 |
+
del model_input['prompt_text']
|
552 |
+
if 'prompt_text_len' in model_input:
|
553 |
+
del model_input['prompt_text_len']
|
554 |
+
if 'llm_prompt_speech_token' in model_input:
|
555 |
+
del model_input['llm_prompt_speech_token']
|
556 |
+
if 'llm_prompt_speech_token_len' in model_input:
|
557 |
+
del model_input['llm_prompt_speech_token_len']
|
558 |
+
# 生成语音标记
|
559 |
+
tts_speech_tokens = generate_speech_tokens(llm, frontend, tts_text, model_input, device)
|
560 |
+
output_data = {
|
561 |
+
'text': tts_text,
|
562 |
+
'tts_speech_tokens': tts_speech_tokens,
|
563 |
+
'prompt_text': prompt_text,
|
564 |
+
'llm_prompt_speech_token': llm_prompt_speech_token[0]
|
565 |
+
}
|
566 |
+
output_fp.write(json.dumps(output_data, ensure_ascii=False) + '\n')
|
567 |
+
output_fp.flush()
|
568 |
+
processed_count += 1
|
569 |
+
except Exception as e:
|
570 |
+
logger.error(f"处理单个句子时出错: '{tts_text[:100]}...'")
|
571 |
+
logger.error(f"错误信息: {str(e)}")
|
572 |
+
logger.error(traceback.format_exc())
|
573 |
+
|
574 |
+
return processed_count
|
575 |
+
except Exception as e:
|
576 |
+
logger.error(f"处理文本块时出错")
|
577 |
+
logger.error(f"错误信息: {str(e)}")
|
578 |
+
logger.error(traceback.format_exc())
|
579 |
+
return 0
|
580 |
+
|
581 |
+
# 收集要处理的文件
|
582 |
+
files_to_process = []
|
583 |
+
|
584 |
+
# 处理JSONL文件
|
585 |
+
if jsonl_files is not None:
|
586 |
+
logger.info(f"处理指定的JSONL文件")
|
587 |
+
for file in jsonl_files:
|
588 |
+
if file.endswith('.jsonl'):
|
589 |
+
files_to_process.append(('jsonl', file))
|
590 |
+
logger.info(f"共有 {len([f for t, f in files_to_process if t == 'jsonl'])} 个JSONL文件需要处理")
|
591 |
+
|
592 |
+
# 处理Parquet文件
|
593 |
+
if parquet_files is not None:
|
594 |
+
logger.info(f"处理指定的Parquet文件")
|
595 |
+
for file in parquet_files:
|
596 |
+
if file.endswith('.parquet'):
|
597 |
+
files_to_process.append(('parquet', file))
|
598 |
+
logger.info(f"共有 {len([f for t, f in files_to_process if t == 'parquet'])} 个Parquet文件需要处理")
|
599 |
+
|
600 |
+
# 顺序处理所有文件
|
601 |
+
for file_type, file_path in files_to_process:
|
602 |
+
logger.info(f'开始处理文件: {file_path}')
|
603 |
+
try:
|
604 |
+
if file_type == 'jsonl':
|
605 |
+
# 处理JSONL文件
|
606 |
+
# 首先计算文件总行数,用于进度条
|
607 |
+
total_lines = 0
|
608 |
+
with open(file_path, 'r') as f:
|
609 |
+
for line in f:
|
610 |
+
if line.strip(): # 只计算非空行
|
611 |
+
total_lines += 1
|
612 |
+
|
613 |
+
logger.info(f"JSONL文件 {file_path} 共有 {total_lines} 行")
|
614 |
+
# 使用进度条处理文件
|
615 |
+
with open(file_path, 'r') as f:
|
616 |
+
from tqdm import tqdm
|
617 |
+
progress_bar = tqdm(total=total_lines, desc=f'处理JSONL文件: {os.path.basename(file_path)}')
|
618 |
+
file_processed = 0
|
619 |
+
for line in f:
|
620 |
+
line = line.strip()
|
621 |
+
if len(line) == 0:
|
622 |
+
continue
|
623 |
+
try:
|
624 |
+
data = json.loads(line)
|
625 |
+
text = data['text']
|
626 |
+
processed = process_single_text(text)
|
627 |
+
total_processed += processed
|
628 |
+
file_processed += processed
|
629 |
+
progress_bar.update(1)
|
630 |
+
progress_bar.set_postfix(total=total_processed)
|
631 |
+
except Exception as e:
|
632 |
+
logger.error(f"处理JSONL行时出错: {line[:100]}...")
|
633 |
+
logger.error(f"错误信息: {str(e)}")
|
634 |
+
logger.error(traceback.format_exc())
|
635 |
+
progress_bar.close()
|
636 |
+
logger.info(f"JSONL文件 {file_path} 完成处理,成功处理 {file_processed} 条记录")
|
637 |
+
|
638 |
+
elif file_type == 'parquet':
|
639 |
+
# 处理Parquet文件
|
640 |
+
try:
|
641 |
+
import pandas as pd
|
642 |
+
logger.info(f"加载Parquet文件: {file_path}")
|
643 |
+
df = pd.read_parquet(file_path)
|
644 |
+
logger.info(f"Parquet文件 {file_path} 共有 {len(df)} 行")
|
645 |
+
|
646 |
+
from tqdm import tqdm
|
647 |
+
progress_bar = tqdm(total=len(df), desc=f'处理Parquet文件: {os.path.basename(file_path)}')
|
648 |
+
file_processed = 0
|
649 |
+
for i in range(len(df)):
|
650 |
+
try:
|
651 |
+
text = df.iloc[i]['text']
|
652 |
+
processed = process_single_text(text)
|
653 |
+
total_processed += processed
|
654 |
+
file_processed += processed
|
655 |
+
progress_bar.update(1)
|
656 |
+
progress_bar.set_postfix(total=total_processed)
|
657 |
+
except Exception as e:
|
658 |
+
logger.error(f"处理Parquet行 {i} 时出错")
|
659 |
+
logger.error(f"错误信息: {str(e)}")
|
660 |
+
logger.error(traceback.format_exc())
|
661 |
+
progress_bar.close()
|
662 |
+
logger.info(f"Parquet文件 {file_path} 完成处理,成功处理 {file_processed} 条记录")
|
663 |
+
except ImportError:
|
664 |
+
logger.error("处理Parquet文件需要pandas库,请安装: pip install pandas")
|
665 |
+
except Exception as e:
|
666 |
+
logger.error(f"处理Parquet文件 {file_path} 时出现错误")
|
667 |
+
logger.error(f"错误信息: {str(e)}")
|
668 |
+
logger.error(traceback.format_exc())
|
669 |
+
except Exception as e:
|
670 |
+
logger.error(f"处理文件 {file_path} 时出现错误")
|
671 |
+
logger.error(f"错误信息: {str(e)}")
|
672 |
+
logger.error(traceback.format_exc())
|
673 |
+
|
674 |
+
logger.info(f'总共成功处理 {total_processed} 个样本,结果保存到 {output_file}')
|
675 |
+
|
676 |
+
except Exception as e:
|
677 |
+
logger.error("处理过程中出现全局错误")
|
678 |
+
logger.error(f"错误信息: {str(e)}")
|
679 |
+
logger.error(traceback.format_exc())
|
680 |
+
|
681 |
+
finally:
|
682 |
+
# 确保资源正确关闭
|
683 |
+
logger.info("清理资源...")
|
684 |
+
if output_fp is not None:
|
685 |
+
try:
|
686 |
+
output_fp.close()
|
687 |
+
logger.info(f"关闭输出文件")
|
688 |
+
except Exception as e:
|
689 |
+
logger.error(f"关闭输出文件时出错: {str(e)}")
|
690 |
+
|
691 |
+
# 释放GPU资源
|
692 |
+
if torch.cuda.is_available():
|
693 |
+
try:
|
694 |
+
torch.cuda.empty_cache()
|
695 |
+
logger.info("已清理GPU缓存")
|
696 |
+
except Exception as e:
|
697 |
+
logger.error(f"清理GPU缓存时出错: {str(e)}")
|
698 |
+
|
699 |
+
logger.info(f"处理结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
700 |
+
logger.info(f"='='='='='='='='='='='='='='='='='='='='='='='='='='='='='")
|
701 |
+
|
702 |
+
if __name__ == '__main__':
|
703 |
+
import argparse
|
704 |
+
"""
|
705 |
+
Parse arguments
|
706 |
+
task: str, including 'extract_prompt'
|
707 |
+
input_tar_files: list of str, input tar files
|
708 |
+
input_tar_languages: list of str, input tar languages for each tar file, must be the same length as input_tar_files
|
709 |
+
max_duration: float, max duration of audio
|
710 |
+
num_samples: int, number of samples to extract
|
711 |
+
target_sr: int, target sample rate
|
712 |
+
output_dir: str, output directory
|
713 |
+
num_processes: int, number of processes to use
|
714 |
+
prompt_dir: str, prompt directory which contains prompt jsonl files and audio files
|
715 |
+
language: str, language, zh or en
|
716 |
+
cosy_model_dir: str, cosy model directory
|
717 |
+
device: str, cuda device used to extract speech tokens
|
718 |
+
jsonl_files: list of str, jsonl files
|
719 |
+
parquet_files: list of str, parquet files
|
720 |
+
"""
|
721 |
+
parser = argparse.ArgumentParser()
|
722 |
+
parser.add_argument('--task', type=str, help='task')
|
723 |
+
parser.add_argument('--input_tar_files', nargs='+', type=str, help='input tar files')
|
724 |
+
parser.add_argument('--input_tar_languages', nargs='+', type=str, help='input tar languages for each tar file')
|
725 |
+
parser.add_argument('--output_dir', type=str, help='output directory',required=True)
|
726 |
+
parser.add_argument('--max_duration', type=float, default=5, help='max duration of audio')
|
727 |
+
parser.add_argument('--num_samples', type=int, default=10, help='number of samples to extract')
|
728 |
+
parser.add_argument('--target_sr', type=int, default=16000, help='target sample rate')
|
729 |
+
parser.add_argument('--num_processes', type=int, default=1, help='number of processes to use')
|
730 |
+
parser.add_argument('--prompts_dir', type=str, help='prompt directory which contains prompt jsonl files and audio files')
|
731 |
+
parser.add_argument('--language', type=str, help='language')
|
732 |
+
parser.add_argument('--cosy_model_dir', type=str, help='cosy model directory')
|
733 |
+
parser.add_argument('--device', type=str, help='cuda device used to extract speech tokens')
|
734 |
+
parser.add_argument('--jsonl_files', nargs='+', type=str, help='jsonl files')
|
735 |
+
parser.add_argument('--parquet_files', nargs='+', type=str, help='parquet files')
|
736 |
+
parser.add_argument('--is_cross_lingual', action='store_true', help='is cross lingual')
|
737 |
+
parser.add_argument('--is_instructed', action='store_true', help='is instructed')
|
738 |
+
args = parser.parse_args()
|
739 |
+
task = args.task
|
740 |
+
if task == 'extract_prompt':
|
741 |
+
input_tar_files = args.input_tar_files
|
742 |
+
input_tar_languages = args.input_tar_languages
|
743 |
+
output_dir = args.output_dir
|
744 |
+
assert len(input_tar_files) == len(input_tar_languages), 'input_tar_files and input_tar_languages must have the same length'
|
745 |
+
extract_prompt(input_tar_files, input_tar_languages, args.max_duration, args.num_samples, args.target_sr, output_dir)
|
746 |
+
elif task == 'generate_speech_tokens':
|
747 |
+
prompts_dir = args.prompts_dir
|
748 |
+
language = args.language
|
749 |
+
cosy_model_dir = args.cosy_model_dir
|
750 |
+
jsonl_files = args.jsonl_files
|
751 |
+
parquet_files = args.parquet_files
|
752 |
+
device = args.device
|
753 |
+
is_cross_lingual = args.is_cross_lingual
|
754 |
+
is_instructed = args.is_instructed
|
755 |
+
# 使用单进程单线程版本替代多进程版本
|
756 |
+
generate_speech_tokens_single_process(
|
757 |
+
cosy_model_dir=cosy_model_dir,
|
758 |
+
prompts_dir=prompts_dir,
|
759 |
+
output_dir=args.output_dir,
|
760 |
+
language=language,
|
761 |
+
jsonl_files=jsonl_files,
|
762 |
+
parquet_files=parquet_files,
|
763 |
+
device=device,
|
764 |
+
is_cross_lingual=is_cross_lingual,
|
765 |
+
is_instructed=is_instructed,
|
766 |
+
)
|
767 |
+
|
eval/eval_seed_generate.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#Download the evaluation file from:https://drive.google.com/file/d/1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP/edit
|
2 |
+
import os
|
3 |
+
voice_engine = None
|
4 |
+
def init_process_func(model_path,device):
|
5 |
+
global voice_engine
|
6 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
7 |
+
voice_engine = CosyVoice2(model_path,device=device,fp16=False,load_jit=False)
|
8 |
+
print(f'Finish loading cosyvoice model from {model_path} in process {os.getpid()}')
|
9 |
+
def do_tts(ID,tts_text,prompt_text,prompt_audio_file,output_dir):
|
10 |
+
from cosyvoice.utils.file_utils import load_wav
|
11 |
+
import torchaudio
|
12 |
+
global voice_engine
|
13 |
+
try:
|
14 |
+
final_output_file = os.path.join(output_dir,f'{ID}.wav')
|
15 |
+
prompt_speech_16k = load_wav(prompt_audio_file, 16000)
|
16 |
+
for output in voice_engine.inference_zero_shot(tts_text,prompt_text, prompt_speech_16k, stream=False,speed=1):
|
17 |
+
torchaudio.save(final_output_file, output['tts_speech'], voice_engine.sample_rate)
|
18 |
+
break # only save the first output
|
19 |
+
print(f'TTS {tts_text} and Save to {final_output_file} at process {os.getpid()}')
|
20 |
+
except Exception as e:
|
21 |
+
print(f'Error: {e}')
|
22 |
+
print(f'Error processing {ID} at process {os.getpid()}')
|
23 |
+
import traceback
|
24 |
+
traceback.print_exc()
|
25 |
+
return
|
26 |
+
if __name__ == '__main__':
|
27 |
+
import argparse
|
28 |
+
parser = argparse.ArgumentParser()
|
29 |
+
parser.add_argument("--eval_dir", type=str, default='eval_data/seedtts_testset')
|
30 |
+
parser.add_argument("--language", type=str, default='zh',choices=['zh','en'])
|
31 |
+
parser.add_argument("--model_path", type=str, default='/home/yueyulin/models/CosyVoice2-0.5B_RWKV_1.5B/')
|
32 |
+
parser.add_argument("--device", type=str, default='cuda:0')
|
33 |
+
parser.add_argument("--num_processes", type=int, default=2)
|
34 |
+
parser.add_argument("--output_dir", type=str, default='generated')
|
35 |
+
parser.add_argument("--list_file", type=str, default='meta.lst')
|
36 |
+
|
37 |
+
|
38 |
+
args = parser.parse_args()
|
39 |
+
print(args)
|
40 |
+
output_dir = os.path.join(args.eval_dir,args.language,args.output_dir)
|
41 |
+
#first delete the output_dir
|
42 |
+
if os.path.exists(output_dir):
|
43 |
+
import shutil
|
44 |
+
shutil.rmtree(output_dir)
|
45 |
+
os.makedirs(output_dir)
|
46 |
+
list_file = os.path.join(args.eval_dir,args.language,args.list_file)
|
47 |
+
with open(list_file) as f:
|
48 |
+
lines = f.readlines()
|
49 |
+
lines = [line.strip() for line in lines]
|
50 |
+
print(f'Processing {len(lines)} lines')
|
51 |
+
|
52 |
+
from multiprocessing import Pool
|
53 |
+
from functools import partial
|
54 |
+
import time
|
55 |
+
with Pool(args.num_processes,init_process_func,(args.model_path,args.device)) as p:
|
56 |
+
for line in lines:
|
57 |
+
# 10002287-00000095|在此奉劝大家别乱打美白针。|prompt-wavs/10002287-00000094.wav|简单地说,这相当于惠普把消费领域市场拱手相让了。
|
58 |
+
parts = line.split('|')
|
59 |
+
ID = parts[0]
|
60 |
+
tts_text = parts[3]
|
61 |
+
prompt_text = parts[1]
|
62 |
+
prompt_audio_file = os.path.join(args.eval_dir,args.language,parts[2])
|
63 |
+
p.apply_async(do_tts,(ID,tts_text,prompt_text,prompt_audio_file,output_dir))
|
64 |
+
p.close()
|
65 |
+
p.join()
|
66 |
+
print('All done')
|
gradio/tts_demo_page.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
import gradio as gr
|
6 |
+
from cosyvoice.cli.cosyvoice import CosyVoice2
|
7 |
+
from cosyvoice.utils.file_utils import load_wav
|
8 |
+
|
9 |
+
# 全局变量
|
10 |
+
model_path = '/external_data/models/CosyVoice2-0.5B_RWKV_0.19B/'
|
11 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
12 |
+
|
13 |
+
# 在应用启动时初始化模型(全局共享)
|
14 |
+
print("正在初始化 CosyVoice2 模型...")
|
15 |
+
cosyvoice = CosyVoice2(model_path, device=device, fp16=True)
|
16 |
+
# 预热模型
|
17 |
+
cosyvoice.model.llm.dummy_forward()
|
18 |
+
print("模型初始化完成!")
|
19 |
+
|
20 |
+
def synthesize_speech(audio_file, prompt_text, tts_text):
|
21 |
+
"""合成语音"""
|
22 |
+
global cosyvoice
|
23 |
+
|
24 |
+
if not audio_file or not prompt_text or not tts_text:
|
25 |
+
return None, "请提供所有必需的输入(提示音频、提示文本和要合成的文本)"
|
26 |
+
|
27 |
+
try:
|
28 |
+
# 加载提示音频
|
29 |
+
prompt_speech_16k = load_wav(audio_file, 16000)
|
30 |
+
|
31 |
+
# 执行推理
|
32 |
+
result = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=False)
|
33 |
+
|
34 |
+
# 获取合成的语音
|
35 |
+
output_speech = result[0]['tts_speech']
|
36 |
+
|
37 |
+
# 保存临时文件
|
38 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
|
39 |
+
temp_file.close()
|
40 |
+
torchaudio.save(temp_file.name, output_speech, cosyvoice.sample_rate)
|
41 |
+
|
42 |
+
return temp_file.name, f"语音合成成功!"
|
43 |
+
except Exception as e:
|
44 |
+
return None, f"合成过程中出错:{str(e)}"
|
45 |
+
|
46 |
+
# 创建 Gradio 界面
|
47 |
+
with gr.Blocks(title="RWKV TTS 演示") as demo:
|
48 |
+
gr.Markdown("# RWKV 语音合成演示")
|
49 |
+
gr.Markdown("### 语音合成系统已准备就绪,可直接使用")
|
50 |
+
|
51 |
+
with gr.Row():
|
52 |
+
with gr.Column():
|
53 |
+
audio_input = gr.Audio(type="filepath", label="上传提示音频文件(WAV 格式)")
|
54 |
+
prompt_text = gr.Textbox(label="提示文本(与提示音频对应的文字内容)", placeholder="例如:今天天气挺不错的。")
|
55 |
+
tts_text = gr.Textbox(label="要合成的文本", placeholder="例如:收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。")
|
56 |
+
synthesize_button = gr.Button("生成语音")
|
57 |
+
|
58 |
+
with gr.Column():
|
59 |
+
audio_output = gr.Audio(label="合成的语音")
|
60 |
+
output_message = gr.Textbox(label="状态信息")
|
61 |
+
|
62 |
+
synthesize_button.click(
|
63 |
+
fn=synthesize_speech,
|
64 |
+
inputs=[audio_input, prompt_text, tts_text],
|
65 |
+
outputs=[audio_output, output_message]
|
66 |
+
)
|
67 |
+
|
68 |
+
gr.Markdown("""
|
69 |
+
## 使用说明
|
70 |
+
|
71 |
+
1. 上传一个WAV格式的提示音频文件
|
72 |
+
2. 输入与提示音频对应的文本内容
|
73 |
+
3. 输入希望合成的文本
|
74 |
+
4. 点击"生成语音"按钮进行语音合成
|
75 |
+
|
76 |
+
注意:模型已在服务启动时预加载,所有用户共享同一个模型实例。
|
77 |
+
""")
|
78 |
+
|
79 |
+
# 启动应用
|
80 |
+
if __name__ == "__main__":
|
81 |
+
demo.launch()
|
mine.wav
ADDED
Binary file (97 kB). View file
|
|
new.mp3
ADDED
Binary file (25.7 kB). View file
|
|
new.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e62a130a15a7560ebf8c1bd73212a9d6410a50e595de9a809bc64775a4a6f07
|
3 |
+
size 141964
|
run_multiple_process.sh
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export PYTHONPATH=/home/yueyulin/github/CosyVoice:/home/yueyulin/github/CosyVoice/third_party/Matcha-TTS/:/home/yueyulin/github/RWKVTTS
|
2 |
+
|
3 |
+
# 设置默认参数
|
4 |
+
LANGUAGE="zh"
|
5 |
+
OUTPUT_DIR="/home/yueyulin/data/speech_corpus"
|
6 |
+
COSY_MODEL_DIR="/home/yueyulin/models/CosyVoice2-0.5B/"
|
7 |
+
PROMPTS_DIR="extract_data/prompts/zh"
|
8 |
+
DEVICE="cuda:0"
|
9 |
+
PARQUET_FILES=()
|
10 |
+
JSONL_FILES=()
|
11 |
+
FILE_TYPE="" # 用于标记文件类型
|
12 |
+
is_cross_lingual=""
|
13 |
+
is_instructed=""
|
14 |
+
|
15 |
+
# 解析命令行参数
|
16 |
+
while [[ $# -gt 0 ]]; do
|
17 |
+
case $1 in
|
18 |
+
--language)
|
19 |
+
LANGUAGE="$2"
|
20 |
+
shift 2
|
21 |
+
;;
|
22 |
+
--output_dir)
|
23 |
+
OUTPUT_DIR="$2"
|
24 |
+
shift 2
|
25 |
+
;;
|
26 |
+
--cosy_model_dir)
|
27 |
+
COSY_MODEL_DIR="$2"
|
28 |
+
shift 2
|
29 |
+
;;
|
30 |
+
--prompts_dir)
|
31 |
+
PROMPTS_DIR="$2"
|
32 |
+
shift 2
|
33 |
+
;;
|
34 |
+
--parquet_files)
|
35 |
+
# 接收多个parquet文件路径
|
36 |
+
shift
|
37 |
+
while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
|
38 |
+
PARQUET_FILES+=("$1")
|
39 |
+
shift
|
40 |
+
done
|
41 |
+
FILE_TYPE="parquet"
|
42 |
+
;;
|
43 |
+
--jsonl_files)
|
44 |
+
# 接收多个jsonl文件路径
|
45 |
+
shift
|
46 |
+
while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
|
47 |
+
JSONL_FILES+=("$1")
|
48 |
+
shift
|
49 |
+
done
|
50 |
+
FILE_TYPE="jsonl"
|
51 |
+
;;
|
52 |
+
--device)
|
53 |
+
DEVICE="$2"
|
54 |
+
shift 2
|
55 |
+
;;
|
56 |
+
--cross_lingual)
|
57 |
+
is_cross_lingual="--is_cross_lingual"
|
58 |
+
shift
|
59 |
+
;;
|
60 |
+
--instructed)
|
61 |
+
is_instructed="--is_instructed"
|
62 |
+
shift
|
63 |
+
;;
|
64 |
+
*)
|
65 |
+
echo "未知参数: $1"
|
66 |
+
exit 1
|
67 |
+
;;
|
68 |
+
esac
|
69 |
+
done
|
70 |
+
|
71 |
+
# 检查是否提供了文件
|
72 |
+
if [ "$FILE_TYPE" == "parquet" ]; then
|
73 |
+
if [ ${#PARQUET_FILES[@]} -eq 0 ]; then
|
74 |
+
echo "错误: 未指定parquet文件,请使用 --parquet_files 参数"
|
75 |
+
exit 1
|
76 |
+
fi
|
77 |
+
FILES=("${PARQUET_FILES[@]}")
|
78 |
+
FILE_ARG="--parquet_files"
|
79 |
+
echo "将处理 ${#FILES[@]} 个parquet文件"
|
80 |
+
elif [ "$FILE_TYPE" == "jsonl" ]; then
|
81 |
+
if [ ${#JSONL_FILES[@]} -eq 0 ]; then
|
82 |
+
echo "错误: 未指定jsonl文件,请使用 --jsonl_files 参数"
|
83 |
+
exit 1
|
84 |
+
fi
|
85 |
+
FILES=("${JSONL_FILES[@]}")
|
86 |
+
FILE_ARG="--jsonl_files"
|
87 |
+
echo "将处理 ${#FILES[@]} 个jsonl文件"
|
88 |
+
else
|
89 |
+
echo "错误: 请使用 --parquet_files 或 --jsonl_files 参数指定输入文件"
|
90 |
+
exit 1
|
91 |
+
fi
|
92 |
+
|
93 |
+
echo "运行参数:"
|
94 |
+
echo "语言: $LANGUAGE"
|
95 |
+
echo "输出目录: $OUTPUT_DIR"
|
96 |
+
echo "模型目录: $COSY_MODEL_DIR"
|
97 |
+
echo "提示词目录: $PROMPTS_DIR"
|
98 |
+
echo "设备: $DEVICE"
|
99 |
+
echo "文件类型: $FILE_TYPE"
|
100 |
+
|
101 |
+
# 确保输出目录存在
|
102 |
+
mkdir -p $OUTPUT_DIR
|
103 |
+
|
104 |
+
# 启动处理进程,每个文件一个进程
|
105 |
+
for ((i=0; i<${#FILES[@]}; i++)); do
|
106 |
+
FILE="${FILES[$i]}"
|
107 |
+
FILENAME=$(basename "$FILE")
|
108 |
+
|
109 |
+
echo "处理文件 $FILENAME 使用 $DEVICE"
|
110 |
+
|
111 |
+
# 在后台启动进程
|
112 |
+
nohup python data/utils/utilitie.py \
|
113 |
+
--task generate_speech_tokens \
|
114 |
+
--language $LANGUAGE \
|
115 |
+
$is_cross_lingual \
|
116 |
+
$FILE_ARG "$FILE" \
|
117 |
+
--output_dir $OUTPUT_DIR \
|
118 |
+
--cosy_model_dir $COSY_MODEL_DIR \
|
119 |
+
--prompts_dir $PROMPTS_DIR \
|
120 |
+
$is_instructed \
|
121 |
+
--device "$DEVICE" > "$OUTPUT_DIR/log_${FILENAME%.*}.log" 2>&1 &
|
122 |
+
|
123 |
+
# 记录进程ID
|
124 |
+
PID=$!
|
125 |
+
echo "启动进程 PID: $PID 处理文件: $FILENAME 使用 $DEVICE"
|
126 |
+
|
127 |
+
# 等待一点时间确保进程启动
|
128 |
+
sleep 5
|
129 |
+
done
|
130 |
+
|
131 |
+
echo "所有处理进程已启动,日志文件保存在 $OUTPUT_DIR 目录"
|
132 |
+
echo "使用 'ps aux | grep utilitie.py' 命令查看运行状态"
|
133 |
+
echo "使用 'nvidia-smi' 命令监控GPU使用情况"
|
134 |
+
|
135 |
+
# 等待所有后台进程完成
|
136 |
+
wait
|
137 |
+
echo "所有处理已完成"
|
rwkvtts_requirements.txt
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
aiofiles==23.2.1
|
3 |
+
aiohappyeyeballs==2.4.8
|
4 |
+
aiohttp==3.11.13
|
5 |
+
aiosignal==1.3.2
|
6 |
+
alembic==1.15.1
|
7 |
+
altair==5.5.0
|
8 |
+
annotated-types==0.7.0
|
9 |
+
antlr4-python3-runtime==4.9.3
|
10 |
+
anyio==4.8.0
|
11 |
+
argon2-cffi==23.1.0
|
12 |
+
argon2-cffi-bindings==21.2.0
|
13 |
+
arrow==1.3.0
|
14 |
+
asttokens==3.0.0
|
15 |
+
async-lru==2.0.4
|
16 |
+
attrs==25.1.0
|
17 |
+
audioread==3.0.1
|
18 |
+
autopage==0.5.2
|
19 |
+
babel==2.17.0
|
20 |
+
beautifulsoup4==4.13.3
|
21 |
+
bleach==6.2.0
|
22 |
+
certifi==2025.1.31
|
23 |
+
cffi==1.17.1
|
24 |
+
cfgv==3.4.0
|
25 |
+
charset-normalizer==3.4.1
|
26 |
+
click==8.1.8
|
27 |
+
cliff==4.9.1
|
28 |
+
cmaes==0.11.1
|
29 |
+
cmd2==2.5.11
|
30 |
+
colorama==0.4.6
|
31 |
+
coloredlogs==15.0.1
|
32 |
+
colorlog==6.9.0
|
33 |
+
comm==0.2.2
|
34 |
+
conformer==0.3.2
|
35 |
+
contourpy==1.3.1
|
36 |
+
csvw==3.5.1
|
37 |
+
cycler==0.12.1
|
38 |
+
Cython==3.0.12
|
39 |
+
datasets==3.3.2
|
40 |
+
debugpy==1.8.13
|
41 |
+
decorator==5.2.1
|
42 |
+
deepspeed==0.16.4
|
43 |
+
defusedxml==0.7.1
|
44 |
+
diffusers==0.32.2
|
45 |
+
dill==0.3.8
|
46 |
+
distlib==0.3.9
|
47 |
+
dlinfo==2.0.0
|
48 |
+
einops==0.8.1
|
49 |
+
executing==2.2.0
|
50 |
+
fastapi==0.115.11
|
51 |
+
fastjsonschema==2.21.1
|
52 |
+
ffmpy==0.5.0
|
53 |
+
filelock==3.17.0
|
54 |
+
flatbuffers==25.2.10
|
55 |
+
fonttools==4.56.0
|
56 |
+
fqdn==1.5.1
|
57 |
+
frozenlist==1.5.0
|
58 |
+
fsspec==2024.12.0
|
59 |
+
gdown==5.2.0
|
60 |
+
gradio==3.43.2
|
61 |
+
gradio_client==0.5.0
|
62 |
+
greenlet==3.1.1
|
63 |
+
grpcio==1.70.0
|
64 |
+
h11==0.14.0
|
65 |
+
hjson==3.1.0
|
66 |
+
httpcore==1.0.7
|
67 |
+
httpx==0.28.1
|
68 |
+
huggingface-hub==0.29.1
|
69 |
+
humanfriendly==10.0
|
70 |
+
hydra-colorlog==1.2.0
|
71 |
+
hydra-core==1.3.2
|
72 |
+
hydra-optuna-sweeper==1.2.0
|
73 |
+
HyperPyYAML==1.2.2
|
74 |
+
identify==2.6.8
|
75 |
+
idna==3.10
|
76 |
+
importlib_metadata==8.6.1
|
77 |
+
importlib_resources==6.5.2
|
78 |
+
inflect==7.5.0
|
79 |
+
iniconfig==2.0.0
|
80 |
+
ipykernel==6.29.5
|
81 |
+
ipython==9.0.1
|
82 |
+
ipython_pygments_lexers==1.1.1
|
83 |
+
ipywidgets==8.1.5
|
84 |
+
isodate==0.7.2
|
85 |
+
isoduration==20.11.0
|
86 |
+
jedi==0.19.2
|
87 |
+
Jinja2==3.1.5
|
88 |
+
joblib==1.4.2
|
89 |
+
json5==0.10.0
|
90 |
+
jsonpointer==3.0.0
|
91 |
+
jsonschema==4.23.0
|
92 |
+
jsonschema-specifications==2024.10.1
|
93 |
+
jupyter-events==0.12.0
|
94 |
+
jupyter-lsp==2.2.5
|
95 |
+
jupyter_client==8.6.3
|
96 |
+
jupyter_core==5.7.2
|
97 |
+
jupyter_server==2.15.0
|
98 |
+
jupyter_server_terminals==0.5.3
|
99 |
+
jupyterlab==4.3.5
|
100 |
+
jupyterlab_pygments==0.3.0
|
101 |
+
jupyterlab_server==2.27.3
|
102 |
+
jupyterlab_widgets==3.0.13
|
103 |
+
kiwisolver==1.4.8
|
104 |
+
language-tags==1.2.0
|
105 |
+
lazy_loader==0.4
|
106 |
+
librosa==0.10.2.post1
|
107 |
+
lightning==2.5.0.post0
|
108 |
+
lightning-utilities==0.13.1
|
109 |
+
llvmlite==0.44.0
|
110 |
+
Mako==1.3.9
|
111 |
+
Markdown==3.7
|
112 |
+
markdown-it-py==3.0.0
|
113 |
+
MarkupSafe==2.1.5
|
114 |
+
matcha-tts==0.0.7.2
|
115 |
+
matplotlib==3.10.1
|
116 |
+
matplotlib-inline==0.1.7
|
117 |
+
mdurl==0.1.2
|
118 |
+
mistune==3.1.2
|
119 |
+
modelscope==1.23.2
|
120 |
+
more-itertools==10.6.0
|
121 |
+
mpmath==1.3.0
|
122 |
+
msgpack==1.1.0
|
123 |
+
multidict==6.1.0
|
124 |
+
multiprocess==0.70.16
|
125 |
+
narwhals==1.29.0
|
126 |
+
nbclient==0.10.2
|
127 |
+
nbconvert==7.16.6
|
128 |
+
nbformat==5.10.4
|
129 |
+
nest-asyncio==1.6.0
|
130 |
+
networkx==3.4.2
|
131 |
+
ninja==1.11.1.3
|
132 |
+
nodeenv==1.9.1
|
133 |
+
notebook==7.3.2
|
134 |
+
notebook_shim==0.2.4
|
135 |
+
numba==0.61.0
|
136 |
+
numpy==1.26.4
|
137 |
+
nvidia-cublas-cu12==12.4.5.8
|
138 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
139 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
140 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
141 |
+
nvidia-cudnn-cu12==9.1.0.70
|
142 |
+
nvidia-cufft-cu12==11.2.1.3
|
143 |
+
nvidia-curand-cu12==10.3.5.147
|
144 |
+
nvidia-cusolver-cu12==11.6.1.9
|
145 |
+
nvidia-cusparse-cu12==12.3.1.170
|
146 |
+
nvidia-cusparselt-cu12==0.6.2
|
147 |
+
nvidia-nccl-cu12==2.21.5
|
148 |
+
nvidia-nvjitlink-cu12==12.4.127
|
149 |
+
nvidia-nvtx-cu12==12.4.127
|
150 |
+
omegaconf==2.3.0
|
151 |
+
onnx==1.17.0
|
152 |
+
onnxruntime-gpu==1.20.1
|
153 |
+
openai-whisper==20240930
|
154 |
+
optuna==2.10.1
|
155 |
+
orjson==3.10.15
|
156 |
+
overrides==7.7.0
|
157 |
+
packaging==24.2
|
158 |
+
pandas==2.2.3
|
159 |
+
pandocfilters==1.5.1
|
160 |
+
parso==0.8.4
|
161 |
+
pbr==6.1.1
|
162 |
+
pexpect==4.9.0
|
163 |
+
phonemizer==3.3.0
|
164 |
+
pillow==10.4.0
|
165 |
+
platformdirs==4.3.6
|
166 |
+
pluggy==1.5.0
|
167 |
+
pooch==1.8.2
|
168 |
+
pre_commit==4.1.0
|
169 |
+
prettytable==3.15.1
|
170 |
+
prometheus_client==0.21.1
|
171 |
+
prompt_toolkit==3.0.50
|
172 |
+
propcache==0.3.0
|
173 |
+
protobuf==6.30.0
|
174 |
+
psutil==7.0.0
|
175 |
+
ptyprocess==0.7.0
|
176 |
+
pure_eval==0.2.3
|
177 |
+
py-cpuinfo==9.0.0
|
178 |
+
pyarrow==19.0.1
|
179 |
+
pycparser==2.22
|
180 |
+
pydantic==2.10.6
|
181 |
+
pydantic_core==2.27.2
|
182 |
+
pydub==0.25.1
|
183 |
+
Pygments==2.19.1
|
184 |
+
pyparsing==3.2.1
|
185 |
+
pyperclip==1.9.0
|
186 |
+
PySocks==1.7.1
|
187 |
+
pytest==8.3.5
|
188 |
+
python-dateutil==2.9.0.post0
|
189 |
+
python-dotenv==1.0.1
|
190 |
+
python-json-logger==3.2.1
|
191 |
+
python-multipart==0.0.20
|
192 |
+
pytorch-lightning==2.5.0.post0
|
193 |
+
pytz==2025.1
|
194 |
+
pyworld==0.3.5
|
195 |
+
PyYAML==6.0.2
|
196 |
+
pyzmq==26.2.1
|
197 |
+
rdflib==7.1.3
|
198 |
+
referencing==0.36.2
|
199 |
+
regex==2024.11.6
|
200 |
+
requests==2.32.3
|
201 |
+
rfc3339-validator==0.1.4
|
202 |
+
rfc3986==1.5.0
|
203 |
+
rfc3986-validator==0.1.1
|
204 |
+
rich==13.9.4
|
205 |
+
rootutils==1.0.7
|
206 |
+
rpds-py==0.23.1
|
207 |
+
ruamel.yaml==0.18.10
|
208 |
+
ruamel.yaml.clib==0.2.12
|
209 |
+
rwkv-fla==0.7.202503020902
|
210 |
+
safetensors==0.5.3
|
211 |
+
scikit-learn==1.6.1
|
212 |
+
scipy==1.15.2
|
213 |
+
seaborn==0.13.2
|
214 |
+
segments==2.3.0
|
215 |
+
semantic-version==2.10.0
|
216 |
+
Send2Trash==1.8.3
|
217 |
+
six==1.17.0
|
218 |
+
sniffio==1.3.1
|
219 |
+
soundfile==0.13.1
|
220 |
+
soupsieve==2.6
|
221 |
+
soxr==0.5.0.post1
|
222 |
+
SQLAlchemy==2.0.38
|
223 |
+
stack-data==0.6.3
|
224 |
+
starlette==0.46.0
|
225 |
+
stevedore==5.4.1
|
226 |
+
sympy==1.13.1
|
227 |
+
tensorboard==2.19.0
|
228 |
+
tensorboard-data-server==0.7.2
|
229 |
+
terminado==0.18.1
|
230 |
+
threadpoolctl==3.5.0
|
231 |
+
tiktoken==0.9.0
|
232 |
+
tinycss2==1.4.0
|
233 |
+
tokenizers==0.21.0
|
234 |
+
torch==2.6.0
|
235 |
+
torchaudio==2.6.0
|
236 |
+
torchmetrics==1.6.2
|
237 |
+
torchvision==0.21.0
|
238 |
+
tornado==6.4.2
|
239 |
+
tqdm==4.67.1
|
240 |
+
traitlets==5.14.3
|
241 |
+
transformers==4.49.0
|
242 |
+
triton==3.2.0
|
243 |
+
typeguard==4.4.2
|
244 |
+
types-python-dateutil==2.9.0.20241206
|
245 |
+
typing_extensions==4.12.2
|
246 |
+
tzdata==2025.1
|
247 |
+
Unidecode==1.3.8
|
248 |
+
uri-template==1.3.0
|
249 |
+
uritemplate==4.1.1
|
250 |
+
urllib3==2.3.0
|
251 |
+
uvicorn==0.34.0
|
252 |
+
virtualenv==20.29.2
|
253 |
+
wcwidth==0.2.13
|
254 |
+
webcolors==24.11.1
|
255 |
+
webencodings==0.5.1
|
256 |
+
websocket-client==1.8.0
|
257 |
+
websockets==11.0.3
|
258 |
+
Werkzeug==3.1.3
|
259 |
+
WeTextProcessing==1.0.4.1
|
260 |
+
wget==3.2
|
261 |
+
widgetsnbextension==4.0.13
|
262 |
+
xxhash==3.5.0
|
263 |
+
yarl==1.18.3
|
264 |
+
zipp==3.21.0
|
third_party/cosyvoice/dataset/processor.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
|
17 |
+
import pyarrow.parquet as pq
|
18 |
+
from io import BytesIO
|
19 |
+
import torch
|
20 |
+
import torchaudio
|
21 |
+
from torch.nn.utils.rnn import pad_sequence
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import pyworld as pw
|
24 |
+
|
25 |
+
|
26 |
+
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
27 |
+
|
28 |
+
|
29 |
+
def parquet_opener(data, mode='train', tts_data={}):
|
30 |
+
""" Give url or local file, return file descriptor
|
31 |
+
Inplace operation.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
data(Iterable[str]): url or local file list
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
Iterable[{src, stream}]
|
38 |
+
"""
|
39 |
+
for sample in data:
|
40 |
+
assert 'src' in sample
|
41 |
+
url = sample['src']
|
42 |
+
try:
|
43 |
+
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
44 |
+
df = df.to_pandas()
|
45 |
+
for i in range(len(df)):
|
46 |
+
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
47 |
+
continue
|
48 |
+
sample.update(dict(df.loc[i]))
|
49 |
+
if mode == 'train':
|
50 |
+
# NOTE do not return sample directly, must initialize a new dict
|
51 |
+
yield {**sample}
|
52 |
+
else:
|
53 |
+
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
54 |
+
yield {**sample, 'tts_index': index, 'tts_text': text}
|
55 |
+
except Exception as ex:
|
56 |
+
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
57 |
+
|
58 |
+
|
59 |
+
def filter(data,
|
60 |
+
max_length=10240,
|
61 |
+
min_length=10,
|
62 |
+
token_max_length=200,
|
63 |
+
token_min_length=1,
|
64 |
+
min_output_input_ratio=0.0005,
|
65 |
+
max_output_input_ratio=1,
|
66 |
+
mode='train'):
|
67 |
+
""" Filter sample according to feature and label length
|
68 |
+
Inplace operation.
|
69 |
+
|
70 |
+
Args::
|
71 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
72 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
73 |
+
min_length: drop utterance which is less than min_length(10ms)
|
74 |
+
token_max_length: drop utterance which is greater than
|
75 |
+
token_max_length, especially when use char unit for
|
76 |
+
english modeling
|
77 |
+
token_min_length: drop utterance which is
|
78 |
+
less than token_max_length
|
79 |
+
min_output_input_ratio: minimal ration of
|
80 |
+
token_length / feats_length(10ms)
|
81 |
+
max_output_input_ratio: maximum ration of
|
82 |
+
token_length / feats_length(10ms)
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
Iterable[{key, wav, label, sample_rate}]
|
86 |
+
"""
|
87 |
+
for sample in data:
|
88 |
+
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
89 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
90 |
+
del sample['audio_data']
|
91 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
92 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
93 |
+
if num_frames < min_length:
|
94 |
+
continue
|
95 |
+
if num_frames > max_length:
|
96 |
+
continue
|
97 |
+
if len(sample['text_token']) < token_min_length:
|
98 |
+
continue
|
99 |
+
if len(sample['text_token']) > token_max_length:
|
100 |
+
continue
|
101 |
+
if len(sample['speech_token']) == 0:
|
102 |
+
continue
|
103 |
+
if num_frames != 0:
|
104 |
+
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
105 |
+
continue
|
106 |
+
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
107 |
+
continue
|
108 |
+
yield sample
|
109 |
+
|
110 |
+
|
111 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
112 |
+
""" Resample data.
|
113 |
+
Inplace operation.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
117 |
+
resample_rate: target resample rate
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Iterable[{key, wav, label, sample_rate}]
|
121 |
+
"""
|
122 |
+
for sample in data:
|
123 |
+
assert 'sample_rate' in sample
|
124 |
+
assert 'speech' in sample
|
125 |
+
sample_rate = sample['sample_rate']
|
126 |
+
waveform = sample['speech']
|
127 |
+
if sample_rate != resample_rate:
|
128 |
+
if sample_rate < min_sample_rate:
|
129 |
+
continue
|
130 |
+
sample['sample_rate'] = resample_rate
|
131 |
+
sample['speech'] = torchaudio.transforms.Resample(
|
132 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
133 |
+
max_val = sample['speech'].abs().max()
|
134 |
+
if max_val > 1:
|
135 |
+
sample['speech'] /= max_val
|
136 |
+
yield sample
|
137 |
+
|
138 |
+
|
139 |
+
def truncate(data, truncate_length=24576, mode='train'):
|
140 |
+
""" Truncate data.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
144 |
+
truncate_length: truncate length
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Iterable[{key, wav, label, sample_rate}]
|
148 |
+
"""
|
149 |
+
for sample in data:
|
150 |
+
waveform = sample['speech']
|
151 |
+
if waveform.shape[1] > truncate_length:
|
152 |
+
start = random.randint(0, waveform.shape[1] - truncate_length)
|
153 |
+
waveform = waveform[:, start: start + truncate_length]
|
154 |
+
else:
|
155 |
+
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
156 |
+
sample['speech'] = waveform
|
157 |
+
yield sample
|
158 |
+
|
159 |
+
|
160 |
+
def compute_fbank(data,
|
161 |
+
feat_extractor,
|
162 |
+
mode='train'):
|
163 |
+
""" Extract fbank
|
164 |
+
|
165 |
+
Args:
|
166 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Iterable[{key, feat, label}]
|
170 |
+
"""
|
171 |
+
for sample in data:
|
172 |
+
assert 'sample_rate' in sample
|
173 |
+
assert 'speech' in sample
|
174 |
+
assert 'utt' in sample
|
175 |
+
assert 'text_token' in sample
|
176 |
+
waveform = sample['speech']
|
177 |
+
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
178 |
+
sample['speech_feat'] = mat
|
179 |
+
yield sample
|
180 |
+
|
181 |
+
|
182 |
+
def compute_f0(data, sample_rate, hop_size, mode='train'):
|
183 |
+
""" Extract f0
|
184 |
+
|
185 |
+
Args:
|
186 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
Iterable[{key, feat, label}]
|
190 |
+
"""
|
191 |
+
frame_period = hop_size * 1000 / sample_rate
|
192 |
+
for sample in data:
|
193 |
+
assert 'sample_rate' in sample
|
194 |
+
assert 'speech' in sample
|
195 |
+
assert 'utt' in sample
|
196 |
+
assert 'text_token' in sample
|
197 |
+
waveform = sample['speech']
|
198 |
+
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
|
199 |
+
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
|
200 |
+
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
|
201 |
+
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
|
202 |
+
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
|
203 |
+
sample['pitch_feat'] = f0
|
204 |
+
yield sample
|
205 |
+
|
206 |
+
|
207 |
+
def parse_embedding(data, normalize, mode='train'):
|
208 |
+
""" Parse utt_embedding/spk_embedding
|
209 |
+
|
210 |
+
Args:
|
211 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
Iterable[{key, feat, label}]
|
215 |
+
"""
|
216 |
+
for sample in data:
|
217 |
+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
218 |
+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
219 |
+
if normalize:
|
220 |
+
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
221 |
+
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
222 |
+
yield sample
|
223 |
+
|
224 |
+
|
225 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
226 |
+
""" Decode text to chars or BPE
|
227 |
+
Inplace operation
|
228 |
+
|
229 |
+
Args:
|
230 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
234 |
+
"""
|
235 |
+
tokenizer = get_tokenizer()
|
236 |
+
for sample in data:
|
237 |
+
assert 'text' in sample
|
238 |
+
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
239 |
+
if mode == 'inference':
|
240 |
+
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
241 |
+
yield sample
|
242 |
+
|
243 |
+
|
244 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
|
245 |
+
""" Local shuffle the data
|
246 |
+
|
247 |
+
Args:
|
248 |
+
data: Iterable[{key, feat, label}]
|
249 |
+
shuffle_size: buffer size for shuffle
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
Iterable[{key, feat, label}]
|
253 |
+
"""
|
254 |
+
buf = []
|
255 |
+
for sample in data:
|
256 |
+
buf.append(sample)
|
257 |
+
if len(buf) >= shuffle_size:
|
258 |
+
random.shuffle(buf)
|
259 |
+
for x in buf:
|
260 |
+
yield x
|
261 |
+
buf = []
|
262 |
+
# The sample left over
|
263 |
+
random.shuffle(buf)
|
264 |
+
for x in buf:
|
265 |
+
yield x
|
266 |
+
|
267 |
+
|
268 |
+
def sort(data, sort_size=500, mode='train'):
|
269 |
+
""" Sort the data by feature length.
|
270 |
+
Sort is used after shuffle and before batch, so we can group
|
271 |
+
utts with similar lengths into a batch, and `sort_size` should
|
272 |
+
be less than `shuffle_size`
|
273 |
+
|
274 |
+
Args:
|
275 |
+
data: Iterable[{key, feat, label}]
|
276 |
+
sort_size: buffer size for sort
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
Iterable[{key, feat, label}]
|
280 |
+
"""
|
281 |
+
|
282 |
+
buf = []
|
283 |
+
for sample in data:
|
284 |
+
buf.append(sample)
|
285 |
+
if len(buf) >= sort_size:
|
286 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
287 |
+
for x in buf:
|
288 |
+
yield x
|
289 |
+
buf = []
|
290 |
+
# The sample left over
|
291 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
292 |
+
for x in buf:
|
293 |
+
yield x
|
294 |
+
|
295 |
+
|
296 |
+
def static_batch(data, batch_size=16):
|
297 |
+
""" Static batch the data by `batch_size`
|
298 |
+
|
299 |
+
Args:
|
300 |
+
data: Iterable[{key, feat, label}]
|
301 |
+
batch_size: batch size
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
Iterable[List[{key, feat, label}]]
|
305 |
+
"""
|
306 |
+
buf = []
|
307 |
+
for sample in data:
|
308 |
+
buf.append(sample)
|
309 |
+
if len(buf) >= batch_size:
|
310 |
+
yield buf
|
311 |
+
buf = []
|
312 |
+
if len(buf) > 0:
|
313 |
+
yield buf
|
314 |
+
|
315 |
+
|
316 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
317 |
+
""" Dynamic batch the data until the total frames in batch
|
318 |
+
reach `max_frames_in_batch`
|
319 |
+
|
320 |
+
Args:
|
321 |
+
data: Iterable[{key, feat, label}]
|
322 |
+
max_frames_in_batch: max_frames in one batch
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
Iterable[List[{key, feat, label}]]
|
326 |
+
"""
|
327 |
+
buf = []
|
328 |
+
longest_frames = 0
|
329 |
+
for sample in data:
|
330 |
+
assert 'speech_feat' in sample
|
331 |
+
assert isinstance(sample['speech_feat'], torch.Tensor)
|
332 |
+
new_sample_frames = sample['speech_feat'].size(0)
|
333 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
334 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
335 |
+
if frames_after_padding > max_frames_in_batch:
|
336 |
+
yield buf
|
337 |
+
buf = [sample]
|
338 |
+
longest_frames = new_sample_frames
|
339 |
+
else:
|
340 |
+
buf.append(sample)
|
341 |
+
if len(buf) > 0:
|
342 |
+
yield buf
|
343 |
+
|
344 |
+
|
345 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
346 |
+
""" Wrapper for static/dynamic batch
|
347 |
+
"""
|
348 |
+
if mode == 'inference':
|
349 |
+
return static_batch(data, 1)
|
350 |
+
else:
|
351 |
+
if batch_type == 'static':
|
352 |
+
return static_batch(data, batch_size)
|
353 |
+
elif batch_type == 'dynamic':
|
354 |
+
return dynamic_batch(data, max_frames_in_batch)
|
355 |
+
else:
|
356 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
357 |
+
|
358 |
+
|
359 |
+
def padding(data, use_spk_embedding, mode='train', gan=False):
|
360 |
+
""" Padding the data into training data
|
361 |
+
|
362 |
+
Args:
|
363 |
+
data: Iterable[List[{key, feat, label}]]
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
367 |
+
"""
|
368 |
+
for sample in data:
|
369 |
+
assert isinstance(sample, list)
|
370 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
371 |
+
dtype=torch.int32)
|
372 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
373 |
+
|
374 |
+
utts = [sample[i]['utt'] for i in order]
|
375 |
+
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
376 |
+
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
377 |
+
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
378 |
+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
379 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
380 |
+
speech_token = pad_sequence(speech_token,
|
381 |
+
batch_first=True,
|
382 |
+
padding_value=0)
|
383 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
384 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
385 |
+
speech_feat = pad_sequence(speech_feat,
|
386 |
+
batch_first=True,
|
387 |
+
padding_value=0)
|
388 |
+
text = [sample[i]['text'] for i in order]
|
389 |
+
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
390 |
+
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
391 |
+
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
392 |
+
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
393 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
394 |
+
batch = {
|
395 |
+
"utts": utts,
|
396 |
+
"speech": speech,
|
397 |
+
"speech_len": speech_len,
|
398 |
+
"speech_token": speech_token,
|
399 |
+
"speech_token_len": speech_token_len,
|
400 |
+
"speech_feat": speech_feat,
|
401 |
+
"speech_feat_len": speech_feat_len,
|
402 |
+
"text": text,
|
403 |
+
"text_token": text_token,
|
404 |
+
"text_token_len": text_token_len,
|
405 |
+
"utt_embedding": utt_embedding,
|
406 |
+
"spk_embedding": spk_embedding,
|
407 |
+
}
|
408 |
+
if gan is True:
|
409 |
+
# in gan train, we need pitch_feat
|
410 |
+
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
411 |
+
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
412 |
+
pitch_feat = pad_sequence(pitch_feat,
|
413 |
+
batch_first=True,
|
414 |
+
padding_value=0)
|
415 |
+
batch["pitch_feat"] = pitch_feat
|
416 |
+
batch["pitch_feat_len"] = pitch_feat_len
|
417 |
+
else:
|
418 |
+
# only gan train needs speech, delete it to save memory
|
419 |
+
del batch["speech"]
|
420 |
+
del batch["speech_len"]
|
421 |
+
if mode == 'inference':
|
422 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
423 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
424 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
425 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
426 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
427 |
+
batch.update({'tts_text': tts_text,
|
428 |
+
'tts_index': tts_index,
|
429 |
+
'tts_text_token': tts_text_token,
|
430 |
+
'tts_text_token_len': tts_text_token_len})
|
431 |
+
if use_spk_embedding is True:
|
432 |
+
batch["embedding"] = batch["spk_embedding"]
|
433 |
+
else:
|
434 |
+
batch["embedding"] = batch["utt_embedding"]
|
435 |
+
yield batch
|
third_party/cosyvoice/flow/decoder.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from einops import pack, rearrange, repeat
|
18 |
+
from cosyvoice.utils.common import mask_to_bias
|
19 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
20 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
21 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
22 |
+
|
23 |
+
|
24 |
+
class Transpose(torch.nn.Module):
|
25 |
+
def __init__(self, dim0: int, dim1: int):
|
26 |
+
super().__init__()
|
27 |
+
self.dim0 = dim0
|
28 |
+
self.dim1 = dim1
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor):
|
31 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class CausalBlock1D(Block1D):
|
36 |
+
def __init__(self, dim: int, dim_out: int):
|
37 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
38 |
+
self.block = torch.nn.Sequential(
|
39 |
+
CausalConv1d(dim, dim_out, 3),
|
40 |
+
Transpose(1, 2),
|
41 |
+
nn.LayerNorm(dim_out),
|
42 |
+
Transpose(1, 2),
|
43 |
+
nn.Mish(),
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
47 |
+
output = self.block(x * mask)
|
48 |
+
return output * mask
|
49 |
+
|
50 |
+
|
51 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
52 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
53 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
54 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
55 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
56 |
+
|
57 |
+
|
58 |
+
class CausalConv1d(torch.nn.Conv1d):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
in_channels: int,
|
62 |
+
out_channels: int,
|
63 |
+
kernel_size: int,
|
64 |
+
stride: int = 1,
|
65 |
+
dilation: int = 1,
|
66 |
+
groups: int = 1,
|
67 |
+
bias: bool = True,
|
68 |
+
padding_mode: str = 'zeros',
|
69 |
+
device=None,
|
70 |
+
dtype=None
|
71 |
+
) -> None:
|
72 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
73 |
+
kernel_size, stride,
|
74 |
+
padding=0, dilation=dilation,
|
75 |
+
groups=groups, bias=bias,
|
76 |
+
padding_mode=padding_mode,
|
77 |
+
device=device, dtype=dtype)
|
78 |
+
assert stride == 1
|
79 |
+
self.causal_padding = (kernel_size - 1, 0)
|
80 |
+
|
81 |
+
def forward(self, x: torch.Tensor):
|
82 |
+
x = F.pad(x, self.causal_padding)
|
83 |
+
x = super(CausalConv1d, self).forward(x)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class ConditionalDecoder(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
in_channels,
|
91 |
+
out_channels,
|
92 |
+
causal=False,
|
93 |
+
channels=(256, 256),
|
94 |
+
dropout=0.05,
|
95 |
+
attention_head_dim=64,
|
96 |
+
n_blocks=1,
|
97 |
+
num_mid_blocks=2,
|
98 |
+
num_heads=4,
|
99 |
+
act_fn="snake",
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
103 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
104 |
+
"""
|
105 |
+
super().__init__()
|
106 |
+
channels = tuple(channels)
|
107 |
+
self.in_channels = in_channels
|
108 |
+
self.out_channels = out_channels
|
109 |
+
self.causal = causal
|
110 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
111 |
+
time_embed_dim = channels[0] * 4
|
112 |
+
self.time_mlp = TimestepEmbedding(
|
113 |
+
in_channels=in_channels,
|
114 |
+
time_embed_dim=time_embed_dim,
|
115 |
+
act_fn="silu",
|
116 |
+
)
|
117 |
+
self.down_blocks = nn.ModuleList([])
|
118 |
+
self.mid_blocks = nn.ModuleList([])
|
119 |
+
self.up_blocks = nn.ModuleList([])
|
120 |
+
|
121 |
+
output_channel = in_channels
|
122 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
123 |
+
input_channel = output_channel
|
124 |
+
output_channel = channels[i]
|
125 |
+
is_last = i == len(channels) - 1
|
126 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
127 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
128 |
+
transformer_blocks = nn.ModuleList(
|
129 |
+
[
|
130 |
+
BasicTransformerBlock(
|
131 |
+
dim=output_channel,
|
132 |
+
num_attention_heads=num_heads,
|
133 |
+
attention_head_dim=attention_head_dim,
|
134 |
+
dropout=dropout,
|
135 |
+
activation_fn=act_fn,
|
136 |
+
)
|
137 |
+
for _ in range(n_blocks)
|
138 |
+
]
|
139 |
+
)
|
140 |
+
downsample = (
|
141 |
+
Downsample1D(output_channel) if not is_last else
|
142 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
143 |
+
)
|
144 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
145 |
+
|
146 |
+
for _ in range(num_mid_blocks):
|
147 |
+
input_channel = channels[-1]
|
148 |
+
out_channels = channels[-1]
|
149 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
150 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
151 |
+
|
152 |
+
transformer_blocks = nn.ModuleList(
|
153 |
+
[
|
154 |
+
BasicTransformerBlock(
|
155 |
+
dim=output_channel,
|
156 |
+
num_attention_heads=num_heads,
|
157 |
+
attention_head_dim=attention_head_dim,
|
158 |
+
dropout=dropout,
|
159 |
+
activation_fn=act_fn,
|
160 |
+
)
|
161 |
+
for _ in range(n_blocks)
|
162 |
+
]
|
163 |
+
)
|
164 |
+
|
165 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
166 |
+
|
167 |
+
channels = channels[::-1] + (channels[0],)
|
168 |
+
for i in range(len(channels) - 1):
|
169 |
+
input_channel = channels[i] * 2
|
170 |
+
output_channel = channels[i + 1]
|
171 |
+
is_last = i == len(channels) - 2
|
172 |
+
resnet = CausalResnetBlock1D(
|
173 |
+
dim=input_channel,
|
174 |
+
dim_out=output_channel,
|
175 |
+
time_emb_dim=time_embed_dim,
|
176 |
+
) if self.causal else ResnetBlock1D(
|
177 |
+
dim=input_channel,
|
178 |
+
dim_out=output_channel,
|
179 |
+
time_emb_dim=time_embed_dim,
|
180 |
+
)
|
181 |
+
transformer_blocks = nn.ModuleList(
|
182 |
+
[
|
183 |
+
BasicTransformerBlock(
|
184 |
+
dim=output_channel,
|
185 |
+
num_attention_heads=num_heads,
|
186 |
+
attention_head_dim=attention_head_dim,
|
187 |
+
dropout=dropout,
|
188 |
+
activation_fn=act_fn,
|
189 |
+
)
|
190 |
+
for _ in range(n_blocks)
|
191 |
+
]
|
192 |
+
)
|
193 |
+
upsample = (
|
194 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
195 |
+
if not is_last
|
196 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
197 |
+
)
|
198 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
199 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
200 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
201 |
+
self.initialize_weights()
|
202 |
+
|
203 |
+
def initialize_weights(self):
|
204 |
+
for m in self.modules():
|
205 |
+
if isinstance(m, nn.Conv1d):
|
206 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
207 |
+
if m.bias is not None:
|
208 |
+
nn.init.constant_(m.bias, 0)
|
209 |
+
elif isinstance(m, nn.GroupNorm):
|
210 |
+
nn.init.constant_(m.weight, 1)
|
211 |
+
nn.init.constant_(m.bias, 0)
|
212 |
+
elif isinstance(m, nn.Linear):
|
213 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
214 |
+
if m.bias is not None:
|
215 |
+
nn.init.constant_(m.bias, 0)
|
216 |
+
|
217 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
218 |
+
"""Forward pass of the UNet1DConditional model.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
222 |
+
mask (_type_): shape (batch_size, 1, time)
|
223 |
+
t (_type_): shape (batch_size)
|
224 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
225 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
226 |
+
|
227 |
+
Raises:
|
228 |
+
ValueError: _description_
|
229 |
+
ValueError: _description_
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
_type_: _description_
|
233 |
+
"""
|
234 |
+
|
235 |
+
t = self.time_embeddings(t).to(t.dtype)
|
236 |
+
t = self.time_mlp(t)
|
237 |
+
|
238 |
+
x = pack([x, mu], "b * t")[0]
|
239 |
+
|
240 |
+
if spks is not None:
|
241 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
242 |
+
x = pack([x, spks], "b * t")[0]
|
243 |
+
if cond is not None:
|
244 |
+
x = pack([x, cond], "b * t")[0]
|
245 |
+
|
246 |
+
hiddens = []
|
247 |
+
masks = [mask]
|
248 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
249 |
+
mask_down = masks[-1]
|
250 |
+
x = resnet(x, mask_down, t)
|
251 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
252 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
253 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
254 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
255 |
+
for transformer_block in transformer_blocks:
|
256 |
+
x = transformer_block(
|
257 |
+
hidden_states=x,
|
258 |
+
attention_mask=attn_mask,
|
259 |
+
timestep=t,
|
260 |
+
)
|
261 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
262 |
+
hiddens.append(x) # Save hidden states for skip connections
|
263 |
+
x = downsample(x * mask_down)
|
264 |
+
masks.append(mask_down[:, :, ::2])
|
265 |
+
masks = masks[:-1]
|
266 |
+
mask_mid = masks[-1]
|
267 |
+
|
268 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
269 |
+
x = resnet(x, mask_mid, t)
|
270 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
271 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
272 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
273 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
274 |
+
for transformer_block in transformer_blocks:
|
275 |
+
x = transformer_block(
|
276 |
+
hidden_states=x,
|
277 |
+
attention_mask=attn_mask,
|
278 |
+
timestep=t,
|
279 |
+
)
|
280 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
281 |
+
|
282 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
283 |
+
mask_up = masks.pop()
|
284 |
+
skip = hiddens.pop()
|
285 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
286 |
+
x = resnet(x, mask_up, t)
|
287 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
288 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
289 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
290 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
291 |
+
for transformer_block in transformer_blocks:
|
292 |
+
x = transformer_block(
|
293 |
+
hidden_states=x,
|
294 |
+
attention_mask=attn_mask,
|
295 |
+
timestep=t,
|
296 |
+
)
|
297 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
298 |
+
x = upsample(x * mask_up)
|
299 |
+
x = self.final_block(x, mask_up)
|
300 |
+
output = self.final_proj(x * mask_up)
|
301 |
+
return output * mask
|
third_party/cosyvoice/flow/flow.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
from typing import Dict, Optional
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
from cosyvoice.utils.mask import make_pad_mask
|
22 |
+
|
23 |
+
|
24 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
25 |
+
def __init__(self,
|
26 |
+
input_size: int = 512,
|
27 |
+
output_size: int = 80,
|
28 |
+
spk_embed_dim: int = 192,
|
29 |
+
output_type: str = "mel",
|
30 |
+
vocab_size: int = 4096,
|
31 |
+
input_frame_rate: int = 50,
|
32 |
+
only_mask_loss: bool = True,
|
33 |
+
encoder: torch.nn.Module = None,
|
34 |
+
length_regulator: torch.nn.Module = None,
|
35 |
+
decoder: torch.nn.Module = None,
|
36 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
37 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
38 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
39 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
40 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
41 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
42 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
43 |
+
super().__init__()
|
44 |
+
self.input_size = input_size
|
45 |
+
self.output_size = output_size
|
46 |
+
self.decoder_conf = decoder_conf
|
47 |
+
self.mel_feat_conf = mel_feat_conf
|
48 |
+
self.vocab_size = vocab_size
|
49 |
+
self.output_type = output_type
|
50 |
+
self.input_frame_rate = input_frame_rate
|
51 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
52 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
53 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
54 |
+
self.encoder = encoder
|
55 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
56 |
+
self.decoder = decoder
|
57 |
+
self.length_regulator = length_regulator
|
58 |
+
self.only_mask_loss = only_mask_loss
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
batch: dict,
|
63 |
+
device: torch.device,
|
64 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
65 |
+
token = batch['speech_token'].to(device)
|
66 |
+
token_len = batch['speech_token_len'].to(device)
|
67 |
+
feat = batch['speech_feat'].to(device)
|
68 |
+
feat_len = batch['speech_feat_len'].to(device)
|
69 |
+
embedding = batch['embedding'].to(device)
|
70 |
+
|
71 |
+
# xvec projection
|
72 |
+
embedding = F.normalize(embedding, dim=1)
|
73 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
74 |
+
|
75 |
+
# concat text and prompt_text
|
76 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
77 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
78 |
+
|
79 |
+
# text encode
|
80 |
+
h, h_lengths = self.encoder(token, token_len)
|
81 |
+
h = self.encoder_proj(h)
|
82 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
83 |
+
|
84 |
+
# get conditions
|
85 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
86 |
+
for i, j in enumerate(feat_len):
|
87 |
+
if random.random() < 0.5:
|
88 |
+
continue
|
89 |
+
index = random.randint(0, int(0.3 * j))
|
90 |
+
conds[i, :index] = feat[i, :index]
|
91 |
+
conds = conds.transpose(1, 2)
|
92 |
+
|
93 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
94 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
95 |
+
loss, _ = self.decoder.compute_loss(
|
96 |
+
feat.transpose(1, 2).contiguous(),
|
97 |
+
mask.unsqueeze(1),
|
98 |
+
h.transpose(1, 2).contiguous(),
|
99 |
+
embedding,
|
100 |
+
cond=conds
|
101 |
+
)
|
102 |
+
return {'loss': loss}
|
103 |
+
|
104 |
+
@torch.inference_mode()
|
105 |
+
def inference(self,
|
106 |
+
token,
|
107 |
+
token_len,
|
108 |
+
prompt_token,
|
109 |
+
prompt_token_len,
|
110 |
+
prompt_feat,
|
111 |
+
prompt_feat_len,
|
112 |
+
embedding,
|
113 |
+
flow_cache):
|
114 |
+
if self.fp16 is True:
|
115 |
+
prompt_feat = prompt_feat.half()
|
116 |
+
embedding = embedding.half()
|
117 |
+
|
118 |
+
assert token.shape[0] == 1
|
119 |
+
# xvec projection
|
120 |
+
embedding = F.normalize(embedding, dim=1)
|
121 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
122 |
+
|
123 |
+
# concat text and prompt_text
|
124 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
125 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
126 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
127 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
128 |
+
|
129 |
+
# text encode
|
130 |
+
h, h_lengths = self.encoder(token, token_len)
|
131 |
+
h = self.encoder_proj(h)
|
132 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
133 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
134 |
+
|
135 |
+
# get conditions
|
136 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
137 |
+
conds[:, :mel_len1] = prompt_feat
|
138 |
+
conds = conds.transpose(1, 2)
|
139 |
+
|
140 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
141 |
+
feat, flow_cache = self.decoder(
|
142 |
+
mu=h.transpose(1, 2).contiguous(),
|
143 |
+
mask=mask.unsqueeze(1),
|
144 |
+
spks=embedding,
|
145 |
+
cond=conds,
|
146 |
+
n_timesteps=10,
|
147 |
+
prompt_len=mel_len1,
|
148 |
+
flow_cache=flow_cache
|
149 |
+
)
|
150 |
+
feat = feat[:, :, mel_len1:]
|
151 |
+
assert feat.shape[2] == mel_len2
|
152 |
+
return feat.float(), flow_cache
|
153 |
+
|
154 |
+
|
155 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
156 |
+
def __init__(self,
|
157 |
+
input_size: int = 512,
|
158 |
+
output_size: int = 80,
|
159 |
+
spk_embed_dim: int = 192,
|
160 |
+
output_type: str = "mel",
|
161 |
+
vocab_size: int = 4096,
|
162 |
+
input_frame_rate: int = 50,
|
163 |
+
only_mask_loss: bool = True,
|
164 |
+
token_mel_ratio: int = 2,
|
165 |
+
pre_lookahead_len: int = 3,
|
166 |
+
encoder: torch.nn.Module = None,
|
167 |
+
decoder: torch.nn.Module = None,
|
168 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
169 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
170 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
171 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
172 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
173 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
174 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
175 |
+
super().__init__()
|
176 |
+
self.input_size = input_size
|
177 |
+
self.output_size = output_size
|
178 |
+
self.decoder_conf = decoder_conf
|
179 |
+
self.mel_feat_conf = mel_feat_conf
|
180 |
+
self.vocab_size = vocab_size
|
181 |
+
self.output_type = output_type
|
182 |
+
self.input_frame_rate = input_frame_rate
|
183 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
184 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
185 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
186 |
+
self.encoder = encoder
|
187 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
188 |
+
self.decoder = decoder
|
189 |
+
self.only_mask_loss = only_mask_loss
|
190 |
+
self.token_mel_ratio = token_mel_ratio
|
191 |
+
self.pre_lookahead_len = pre_lookahead_len
|
192 |
+
|
193 |
+
@torch.inference_mode()
|
194 |
+
def inference(self,
|
195 |
+
token,
|
196 |
+
token_len,
|
197 |
+
prompt_token,
|
198 |
+
prompt_token_len,
|
199 |
+
prompt_feat,
|
200 |
+
prompt_feat_len,
|
201 |
+
embedding,
|
202 |
+
finalize):
|
203 |
+
if self.fp16 is True:
|
204 |
+
prompt_feat = prompt_feat.half()
|
205 |
+
embedding = embedding.half()
|
206 |
+
|
207 |
+
assert token.shape[0] == 1
|
208 |
+
# xvec projection
|
209 |
+
embedding = F.normalize(embedding, dim=1)
|
210 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
211 |
+
|
212 |
+
# concat text and prompt_text
|
213 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
214 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
215 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
216 |
+
|
217 |
+
# text encode
|
218 |
+
h, h_lengths = self.encoder(token, token_len)
|
219 |
+
if finalize is False:
|
220 |
+
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
221 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
222 |
+
h = self.encoder_proj(h)
|
223 |
+
|
224 |
+
# get conditions
|
225 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
226 |
+
conds[:, :mel_len1] = prompt_feat
|
227 |
+
conds = conds.transpose(1, 2)
|
228 |
+
|
229 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
230 |
+
feat, _ = self.decoder(
|
231 |
+
mu=h.transpose(1, 2).contiguous(),
|
232 |
+
mask=mask.unsqueeze(1),
|
233 |
+
spks=embedding,
|
234 |
+
cond=conds,
|
235 |
+
n_timesteps=10
|
236 |
+
)
|
237 |
+
feat = feat[:, :, mel_len1:]
|
238 |
+
assert feat.shape[2] == mel_len2
|
239 |
+
return feat.float(), None
|
third_party/cosyvoice/flow/flow_matching.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import threading
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from matcha.models.components.flow_matching import BASECFM
|
18 |
+
|
19 |
+
|
20 |
+
class ConditionalCFM(BASECFM):
|
21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
22 |
+
super().__init__(
|
23 |
+
n_feats=in_channels,
|
24 |
+
cfm_params=cfm_params,
|
25 |
+
n_spks=n_spks,
|
26 |
+
spk_emb_dim=spk_emb_dim,
|
27 |
+
)
|
28 |
+
self.t_scheduler = cfm_params.t_scheduler
|
29 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
30 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
31 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
32 |
+
# Just change the architecture of the estimator here
|
33 |
+
self.estimator = estimator
|
34 |
+
self.lock = threading.Lock()
|
35 |
+
|
36 |
+
@torch.inference_mode()
|
37 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
38 |
+
"""Forward diffusion
|
39 |
+
|
40 |
+
Args:
|
41 |
+
mu (torch.Tensor): output of encoder
|
42 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
43 |
+
mask (torch.Tensor): output_mask
|
44 |
+
shape: (batch_size, 1, mel_timesteps)
|
45 |
+
n_timesteps (int): number of diffusion steps
|
46 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
47 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
48 |
+
shape: (batch_size, spk_emb_dim)
|
49 |
+
cond: Not used but kept for future purposes
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
sample: generated mel-spectrogram
|
53 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
54 |
+
"""
|
55 |
+
|
56 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
57 |
+
cache_size = flow_cache.shape[2]
|
58 |
+
# fix prompt and overlap part mu and z
|
59 |
+
if cache_size != 0:
|
60 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
61 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
62 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
63 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
64 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
65 |
+
|
66 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
67 |
+
if self.t_scheduler == 'cosine':
|
68 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
69 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
70 |
+
|
71 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
72 |
+
"""
|
73 |
+
Fixed euler solver for ODEs.
|
74 |
+
Args:
|
75 |
+
x (torch.Tensor): random noise
|
76 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
77 |
+
shape: (n_timesteps + 1,)
|
78 |
+
mu (torch.Tensor): output of encoder
|
79 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
80 |
+
mask (torch.Tensor): output_mask
|
81 |
+
shape: (batch_size, 1, mel_timesteps)
|
82 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
83 |
+
shape: (batch_size, spk_emb_dim)
|
84 |
+
cond: Not used but kept for future purposes
|
85 |
+
"""
|
86 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
87 |
+
t = t.unsqueeze(dim=0)
|
88 |
+
|
89 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
90 |
+
# Or in future might add like a return_all_steps flag
|
91 |
+
sol = []
|
92 |
+
|
93 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
94 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
95 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
96 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
97 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
98 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
99 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
100 |
+
for step in range(1, len(t_span)):
|
101 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
102 |
+
x_in[:] = x
|
103 |
+
mask_in[:] = mask
|
104 |
+
mu_in[0] = mu
|
105 |
+
t_in[:] = t.unsqueeze(0)
|
106 |
+
spks_in[0] = spks
|
107 |
+
cond_in[0] = cond
|
108 |
+
dphi_dt = self.forward_estimator(
|
109 |
+
x_in, mask_in,
|
110 |
+
mu_in, t_in,
|
111 |
+
spks_in,
|
112 |
+
cond_in
|
113 |
+
)
|
114 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
115 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
116 |
+
x = x + dt * dphi_dt
|
117 |
+
t = t + dt
|
118 |
+
sol.append(x)
|
119 |
+
if step < len(t_span) - 1:
|
120 |
+
dt = t_span[step + 1] - t
|
121 |
+
|
122 |
+
return sol[-1].float()
|
123 |
+
|
124 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
125 |
+
if isinstance(self.estimator, torch.nn.Module):
|
126 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
127 |
+
else:
|
128 |
+
with self.lock:
|
129 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
130 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
131 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
132 |
+
self.estimator.set_input_shape('t', (2,))
|
133 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
134 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
135 |
+
# run trt engine
|
136 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
137 |
+
mask.contiguous().data_ptr(),
|
138 |
+
mu.contiguous().data_ptr(),
|
139 |
+
t.contiguous().data_ptr(),
|
140 |
+
spks.contiguous().data_ptr(),
|
141 |
+
cond.contiguous().data_ptr(),
|
142 |
+
x.data_ptr()])
|
143 |
+
return x
|
144 |
+
|
145 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
146 |
+
"""Computes diffusion loss
|
147 |
+
|
148 |
+
Args:
|
149 |
+
x1 (torch.Tensor): Target
|
150 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
151 |
+
mask (torch.Tensor): target mask
|
152 |
+
shape: (batch_size, 1, mel_timesteps)
|
153 |
+
mu (torch.Tensor): output of encoder
|
154 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
155 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
156 |
+
shape: (batch_size, spk_emb_dim)
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
loss: conditional flow matching loss
|
160 |
+
y: conditional flow
|
161 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
162 |
+
"""
|
163 |
+
b, _, t = mu.shape
|
164 |
+
|
165 |
+
# random timestep
|
166 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
167 |
+
if self.t_scheduler == 'cosine':
|
168 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
169 |
+
# sample noise p(x_0)
|
170 |
+
z = torch.randn_like(x1)
|
171 |
+
|
172 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
173 |
+
u = x1 - (1 - self.sigma_min) * z
|
174 |
+
|
175 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
176 |
+
if self.training_cfg_rate > 0:
|
177 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
178 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
179 |
+
spks = spks * cfg_mask.view(-1, 1)
|
180 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
181 |
+
|
182 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
183 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
184 |
+
return loss, y
|
185 |
+
|
186 |
+
|
187 |
+
class CausalConditionalCFM(ConditionalCFM):
|
188 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
189 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
190 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
191 |
+
|
192 |
+
@torch.inference_mode()
|
193 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
194 |
+
"""Forward diffusion
|
195 |
+
|
196 |
+
Args:
|
197 |
+
mu (torch.Tensor): output of encoder
|
198 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
199 |
+
mask (torch.Tensor): output_mask
|
200 |
+
shape: (batch_size, 1, mel_timesteps)
|
201 |
+
n_timesteps (int): number of diffusion steps
|
202 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
203 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
204 |
+
shape: (batch_size, spk_emb_dim)
|
205 |
+
cond: Not used but kept for future purposes
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
sample: generated mel-spectrogram
|
209 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
210 |
+
"""
|
211 |
+
|
212 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
213 |
+
# fix prompt and overlap part mu and z
|
214 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
215 |
+
if self.t_scheduler == 'cosine':
|
216 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
217 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
third_party/cosyvoice/flow/length_regulator.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Tuple
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch
|
17 |
+
from torch.nn import functional as F
|
18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
19 |
+
|
20 |
+
|
21 |
+
class InterpolateRegulator(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
channels: int,
|
25 |
+
sampling_ratios: Tuple,
|
26 |
+
out_channels: int = None,
|
27 |
+
groups: int = 1,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.sampling_ratios = sampling_ratios
|
31 |
+
out_channels = out_channels or channels
|
32 |
+
model = nn.ModuleList([])
|
33 |
+
if len(sampling_ratios) > 0:
|
34 |
+
for _ in sampling_ratios:
|
35 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
36 |
+
norm = nn.GroupNorm(groups, channels)
|
37 |
+
act = nn.Mish()
|
38 |
+
model.extend([module, norm, act])
|
39 |
+
model.append(
|
40 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
41 |
+
)
|
42 |
+
self.model = nn.Sequential(*model)
|
43 |
+
|
44 |
+
def forward(self, x, ylens=None):
|
45 |
+
# x in (B, T, D)
|
46 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
47 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
48 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
49 |
+
olens = ylens
|
50 |
+
return out * mask, olens
|
51 |
+
|
52 |
+
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
53 |
+
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
54 |
+
# x in (B, T, D)
|
55 |
+
if x2.shape[1] > 40:
|
56 |
+
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
57 |
+
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
58 |
+
mode='linear')
|
59 |
+
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
60 |
+
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
61 |
+
else:
|
62 |
+
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
63 |
+
if x1.shape[1] != 0:
|
64 |
+
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
65 |
+
x = torch.concat([x1, x2], dim=2)
|
66 |
+
else:
|
67 |
+
x = x2
|
68 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
69 |
+
return out, mel_len1 + mel_len2
|