danielhanchen committed
Commit e92902d · verified · Parent: 6cbe4f2

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
checkpoint-*/
README.md ADDED
@@ -0,0 +1,238 @@
---
license: cc-by-nc-4.0
language:
- zh
- en
base_model:
- meta-llama/Llama-3.2-1B-Instruct
tags:
- Text-to-Speech
pipeline_tag: text-to-speech
---
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2502.04128)

**Update (2025-05-10):** I sometimes find that top_p=0.95 and temperature=0.9 produce more stable results.

**Update (2025-02-13):** Added [Llasa fine-tuning instructions](https://github.com/zhenye234/LLaSA_training/tree/main/finetune).

**Update (2025-02-07):** Our paper has been released!
LLaSA: Scaling Train-Time and Inference-Time Compute for LLaMA-based Speech Synthesis

- **Train from Scratch**: To train the model from scratch, use the [LLaSA Training Repository](https://github.com/zhenye234/LLaSA_training).

- **Scale for Test-Time Computation**: To experiment with scaling test-time compute, use the [LLaSA Testing Repository](https://github.com/zhenye234/LLaSA_inference).
## Model Information
Our model, Llasa, is a text-to-speech (TTS) system that extends the text-based LLaMA (1B, 3B, and 8B) language model by incorporating speech tokens from the XCodec2 codebook, which contains 65,536 tokens. We trained Llasa on a dataset comprising 250,000 hours of Chinese-English speech data. The model can generate speech **either solely from input text or by utilizing a given speech prompt.**
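As a rough sanity check (my own arithmetic, not a claim from the paper), the `vocab_size` of 193800 in the `config.json` below is consistent with the base Llama 3.2 vocabulary plus the 65,536 XCodec2 speech tokens and a handful of task-specific control tokens:

```python
# Back-of-the-envelope vocabulary accounting (the breakdown is an assumption)
llama3_vocab   = 128_256  # Llama 3 / 3.2 tokenizer size
speech_tokens  = 65_536   # XCodec2 codebook: <|s_0|> ... <|s_65535|>
control_tokens = 8        # assumed: <|TEXT_UNDERSTANDING_START|>, <|TEXT_UNDERSTANDING_END|>,
                          # <|SPEECH_GENERATION_START|>, <|SPEECH_GENERATION_END|>, plus reserved slots

print(llama3_vocab + speech_tokens + control_tokens)  # 193800, matching config.json
```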
## How to use
Install [XCodec2](https://huggingface.co/HKUSTAudio/xcodec2).

**1. Speech synthesis solely from input text**
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf

llasa_1b = 'HKUSTAudio/Llasa-1B'

tokenizer = AutoTokenizer.from_pretrained(llasa_1b)
model = AutoModelForCausalLM.from_pretrained(llasa_1b)
model.eval()
model.to('cuda')

from xcodec2.modeling_xcodec2 import XCodec2Model

model_path = "HKUSTAudio/xcodec2"

Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()

input_text = 'Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me.'
# input_text = '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"'

def ids_to_speech_tokens(speech_ids):
    # Convert an integer id like 12345 to the token string <|s_12345|>
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

def extract_speech_ids(speech_tokens_str):
    # Convert a token string like <|s_23456|> back to the integer 23456
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            speech_ids.append(int(num_str))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

# TTS start!
with torch.no_grad():
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_tensors='pt',
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

    # Generate the speech tokens autoregressively
    outputs = model.generate(
        input_ids,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,          # Adjusts the diversity of generated content
        temperature=0.8,  # Controls randomness in output
    )
    # Extract the speech tokens (drop the prompt and the trailing end token)
    generated_ids = outputs[0][input_ids.shape[1]:-1]

    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Convert token <|s_23456|> to int 23456
    speech_tokens = extract_speech_ids(speech_tokens)

    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

    # Decode the speech tokens to a speech waveform
    gen_wav = Codec_model.decode_code(speech_tokens)

sf.write("gen.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
```
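For repeated use, the steps above can be folded into a single helper. This is my own convenience wrapper around the exact calls shown above, not an API provided by the repository; it reuses the `tokenizer`, `model`, `Codec_model`, and `extract_speech_ids` names already defined:

```python
def text_to_speech(input_text, out_path="gen.wav", top_p=1.0, temperature=0.8):
    # Same pipeline as above, packaged as a function
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]
    input_ids = tokenizer.apply_chat_template(
        chat, tokenize=True, return_tensors='pt', continue_final_message=True
    ).to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
    with torch.no_grad():
        outputs = model.generate(
            input_ids, max_length=2048, eos_token_id=speech_end_id,
            do_sample=True, top_p=top_p, temperature=temperature,
        )
        generated_ids = outputs[0][input_ids.shape[1]:-1]
        token_strs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        codes = torch.tensor(extract_speech_ids(token_strs)).cuda().unsqueeze(0).unsqueeze(0)
        gen_wav = Codec_model.decode_code(codes)
    sf.write(out_path, gen_wav[0, 0, :].cpu().numpy(), 16000)

# Per the 2025-05-10 note above, top_p=0.95 / temperature=0.9 can be more stable:
text_to_speech("Hello from Llasa!", top_p=0.95, temperature=0.9)
```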
**2. Speech synthesis utilizing a given speech prompt**

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf

llasa_1b = 'HKUSTAudio/Llasa-1B'

tokenizer = AutoTokenizer.from_pretrained(llasa_1b)
model = AutoModelForCausalLM.from_pretrained(llasa_1b)
model.eval()
model.to('cuda')

from xcodec2.modeling_xcodec2 import XCodec2Model

model_path = "HKUSTAudio/xcodec2"

Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()

# Only 16 kHz speech is supported!
prompt_wav, sr = sf.read("太乙真人.wav")  # you can find this wav in the Files tab
# prompt_wav, sr = sf.read("Anna.wav")  # English prompt
prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)

prompt_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
# prompt_text = "A chance to leave him alone, but... No. She just wanted to see him again. Anna, you don't know how it feels to lose a sister. Anna, I'm sorry, but your father asked me not to tell you anything."
target_text = '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"'
# target_text = "Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me."
input_text = prompt_text + ' ' + target_text

def ids_to_speech_tokens(speech_ids):
    # Convert an integer id like 12345 to the token string <|s_12345|>
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

def extract_speech_ids(speech_tokens_str):
    # Convert a token string like <|s_23456|> back to the integer 23456
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            speech_ids.append(int(num_str))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

# TTS start!
with torch.no_grad():
    # Encode the prompt wav into XCodec2 speech codes
    vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
    print("Prompt VQ code shape:", vq_code_prompt.shape)

    vq_code_prompt = vq_code_prompt[0, 0, :]
    # Convert int 12345 to token <|s_12345|>
    speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt.tolist())

    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text and the speech prefix
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_tensors='pt',
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

    # Generate the speech tokens autoregressively
    outputs = model.generate(
        input_ids,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,
        temperature=0.8,
    )
    # Extract the speech tokens, keeping the prompt's speech prefix
    generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]

    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Convert token <|s_23456|> to int 23456
    speech_tokens = extract_speech_ids(speech_tokens)

    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

    # Decode the speech tokens to a speech waveform (prompt + continuation)
    gen_wav = Codec_model.decode_code(speech_tokens)

    # If you only need the newly generated part:
    # gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]

sf.write("gen.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
```
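The codec expects 16 kHz input, but `sf.read` returns whatever rate the file was stored at. A minimal resampling guard, assuming `torchaudio` is installed (this step is my addition, not part of the original README):

```python
import torchaudio

prompt_wav, sr = sf.read("太乙真人.wav")
prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
if sr != 16000:
    # Resample the prompt to the 16 kHz rate XCodec2 expects
    prompt_wav = torchaudio.functional.resample(prompt_wav, orig_freq=sr, new_freq=16000)
    sr = 16000
```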
## Disclaimer

This model is licensed under the CC BY-NC 4.0 License, which prohibits commercial use due to ethics and privacy concerns; detected violations will result in legal consequences.

Use of this codebase for any illegal purpose, in any country or region, is strictly prohibited. Please refer to your local laws regarding the DMCA and other related laws.
config.json ADDED
@@ -0,0 +1,40 @@
{
  "_name_or_path": "/aifs4su/data/zheny/speechllm/logs_llasa/data_v12_25_tts_25w_mix_final_1b/checkpoint-240000",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 193800
}
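A quick consistency check I did by hand (an assumed breakdown, not stated anywhere in the repo): the architecture above implies roughly 1.37B parameters, which at bfloat16 (2 bytes each) matches the ~2.74 GB `model.safetensors` below:

```python
# Rough parameter count from config.json (norms omitted, so this is approximate)
hidden, inter, layers, vocab = 2048, 8192, 16, 193800
kv_dim = 8 * 64                                       # num_key_value_heads * head_dim

embed = vocab * hidden                                # tied with the output head, counted once
attn  = 2 * hidden * hidden + 2 * hidden * kv_dim     # q/o plus grouped k/v projections
mlp   = 3 * hidden * inter                            # gate, up, down projections
total = embed + layers * (attn + mlp)

print(total)            # ~1.37e9 parameters
print(total * 2 / 1e9)  # ~2.74 GB in bfloat16, matching the safetensors size
```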
generation_config.json ADDED
@@ -0,0 +1,12 @@
{
  "bos_token_id": 128000,
  "do_sample": true,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "temperature": 0.6,
  "top_p": 0.9,
  "transformers_version": "4.46.1"
}
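Note that these repo-level defaults (temperature 0.6, top_p 0.9) differ from the values used in the README examples; anything passed explicitly to `model.generate(...)` overrides them per call. A small sketch to inspect the shipped defaults (my addition):

```python
from transformers import GenerationConfig

# Load the generation defaults shipped with the model
gen_cfg = GenerationConfig.from_pretrained('HKUSTAudio/Llasa-1B')
print(gen_cfg.temperature, gen_cfg.top_p)  # 0.6 0.9
```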
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:042cc86de450e178c1d177a49aa5b570164f3e286ccbdb3455c1192d4e3cfbf6
size 2740113872
special_tokens_map.json ADDED
@@ -0,0 +1,17 @@
{
  "bos_token": {
    "content": "<|begin_of_text|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|eot_id|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<|eot_id|>"
}
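The TTS control tokens and `<|s_NNN|>` speech tokens live in the tokenizer itself rather than in this map. A quick way to confirm they resolve to real vocabulary ids (my own check, not from the repo):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('HKUSTAudio/Llasa-1B')
for t in ['<|TEXT_UNDERSTANDING_START|>', '<|SPEECH_GENERATION_END|>', '<|s_0|>', '<|s_65535|>']:
    # A valid id (not the unknown-token id) means the token is in the vocabulary
    print(t, tok.convert_tokens_to_ids(t))
```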
tokenizer.json ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:71d92f3dbf3c23d734e6356241cef149b42fe79848176a54145b6f9a886fd73b
size 29521206
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4003bfd44e3c1e936f97823308c868d275d93f0c31e108a4f26ad3d2e3703fbf
size 11710428