library_tag: spark-tts
---

### Usage

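The generation function below relies on three globals: `model`, `tokenizer`, and `audio_tokenizer`. A minimal setup sketch, assuming the checkpoint is loaded with `transformers` and that `BiCodecTokenizer` comes from the Spark-TTS codebase (the path and import here are assumptions; adjust to your install):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed import path for BiCodecTokenizer from the Spark-TTS codebase;
# it may differ depending on how the package is installed.
from sparktts.models.audio_tokenizer import BiCodecTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir = "path/to/spark-tts-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir).to(device)
audio_tokenizer = BiCodecTokenizer(model_dir, device=device)
```

With those in place, the generation function itself:
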
```python
import torch
import re
import numpy as np


@torch.inference_mode()
def generate_speech_from_text(
    text: str,
    temperature: float = 0.8,          # Generation temperature
    top_k: int = 50,                   # Generation top_k
    top_p: float = 1.0,                # Generation top_p
    max_new_audio_tokens: int = 2048,  # Max tokens for the audio part
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> np.ndarray:
    """
    Generates speech audio from text using default voice control parameters.

    Args:
        text (str): The text input to be converted to speech.
        temperature (float): Sampling temperature for generation.
        top_k (int): Top-k sampling parameter.
        top_p (float): Top-p (nucleus) sampling parameter.
        max_new_audio_tokens (int): Max number of new tokens to generate (limits audio length).
        device (torch.device): Device to run inference on.

    Returns:
        np.ndarray: Generated waveform as a NumPy array.
    """
    # Build the Spark-TTS prompt around the input text.
    prompt = "".join([
        "<|task_tts|>",
        "<|start_content|>",
        text,
        "<|end_content|>",
        "<|start_global_token|>",
    ])

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    print("Generating token sequence...")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_audio_tokens,  # Limit generation length
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,  # Stop token
        pad_token_id=tokenizer.pad_token_id,  # Use the model's pad token id
    )
    print("Token sequence generated.")

    # Keep only the newly generated tokens (drop the prompt).
    generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]

    predicts_text = tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]
    # print(f"\nGenerated Text (for parsing):\n{predicts_text}\n")  # Debugging

    # Extract semantic token IDs using regex.
    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicts_text)
    if not semantic_matches:
        print("Warning: No semantic tokens found in the generated output.")
        # Handle appropriately - perhaps return silence or raise an error.
        return np.array([], dtype=np.float32)

    pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0)  # Add batch dim

    # Extract global token IDs using regex (assuming controllable mode also generates these).
    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicts_text)
    if not global_matches:
        print("Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.")
        pred_global_ids = torch.zeros((1, 1), dtype=torch.long)
    else:
        pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0)  # Add batch dim

    pred_global_ids = pred_global_ids.unsqueeze(0)  # Shape becomes (1, 1, N_global)

    print(f"Found {pred_semantic_ids.shape[1]} semantic tokens.")
    print(f"Found {pred_global_ids.shape[2]} global tokens.")

    # Detokenize using BiCodecTokenizer.
    # Ensure audio_tokenizer and its internal model are on the correct device.
    print("Detokenizing audio tokens...")
    # Squeeze the extra dimension from the global tokens, as in the SparkTTS example.
    wav_np = audio_tokenizer.detokenize(
        pred_global_ids.squeeze(0),  # Shape (1, N_global)
        pred_semantic_ids,           # Shape (1, N_semantic)
    )
    print("Detokenization complete.")

    return wav_np


if __name__ == "__main__":
    input_text = "Hello, this is a test of Spark-TTS."  # Example input; replace with your own text.
    chosen_voice = None  # Optional voice prefix; leave None for the default voice.

    text = f"{chosen_voice}: " + input_text if chosen_voice else input_text
    print(f"Generating speech for: '{text}'")
    generated_waveform = generate_speech_from_text(text)

    if generated_waveform.size > 0:
        import soundfile as sf

        output_filename = "generated_speech_controllable.wav"
        sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
        sf.write(output_filename, generated_waveform, sample_rate)
        print(f"Audio saved to {output_filename}")

        # Optional: play in a notebook
        from IPython.display import Audio, display
        display(Audio(generated_waveform, rate=sample_rate))
    else:
        print("Audio generation failed (no tokens found?).")
```
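
The original listing also imported `torchaudio.transforms`, which is only needed if you want the waveform at a rate other than the codec's native one (16 kHz by default above). A minimal resampling sketch; the function name and the 44.1 kHz target are illustrative choices, not part of the Spark-TTS API:

```python
import torch
import torchaudio.transforms as T

def resample_waveform(wav_np, orig_sr: int = 16000, target_sr: int = 44100):
    """Resample a mono NumPy waveform from orig_sr to target_sr using torchaudio."""
    resampler = T.Resample(orig_freq=orig_sr, new_freq=target_sr)
    wav = torch.from_numpy(wav_np).float().unsqueeze(0)  # (1, num_samples)
    return resampler(wav).squeeze(0).numpy()
```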