Instructions to use OpenMOSS-Team/MOSS-TTS-Realtime with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use OpenMOSS-Team/MOSS-TTS-Realtime with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-to-speech", model="OpenMOSS-Team/MOSS-TTS-Realtime")# Load model directly from transformers import MossTTSRealtime model = MossTTSRealtime.from_pretrained("OpenMOSS-Team/MOSS-TTS-Realtime", dtype="auto") - Notebooks
- Google Colab
- Kaggle
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Streaming inference utilities for MossTTSRealtime.""" | |
| from __future__ import annotations | |
| import contextlib | |
| import re | |
| from typing import Iterable, Iterator, List, Optional, Sequence | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers.cache_utils import StaticCache | |
| from transformers.utils import is_torchaudio_available, requires_backends | |
| from transformers.utils.import_utils import requires | |
| if is_torchaudio_available(): | |
| import torchaudio | |
| class MossTTSRealtimeInference: | |
| """Step-wise inference wrapper for MossTTSRealtime. | |
| This class mirrors the non-streaming inference logic but exposes a | |
| prefill/step/finish API for streaming usage. | |
| """ | |
| def __init__( | |
| self, | |
| model, | |
| tokenizer, | |
| max_length: int = 1000, | |
| channels: int = 16, | |
| audio_channel_pad: int = 1024, | |
| audio_bos_token: int = 1025, | |
| audio_eos_token: int = 1026, | |
| text_pad_id: int = 151655, | |
| aud_pad_id: int = 151654, | |
| ): | |
| self.model = model | |
| self.tokenizer = tokenizer | |
| self.max_length = max_length | |
| self.channels = channels | |
| self.audio_channel_pad = audio_channel_pad | |
| self.audio_bos_token = audio_bos_token | |
| self.audio_eos_token = audio_eos_token | |
| self.text_pad_id = text_pad_id | |
| self.aud_pad_id = aud_pad_id | |
| self.past_key_values = None | |
| self.attention_mask = None | |
| self._generated_tokens: List[torch.Tensor] = [] | |
| self._is_stopping = None | |
| self._last_audio_tokens = None | |
| self._step_idx = 0 | |
| def device(self): | |
| return next(self.model.parameters()).device | |
| def is_finished(self) -> bool: | |
| return self._is_stopping is not None and bool(self._is_stopping.all()) | |
| def reset_generation_state(self, keep_cache: bool = True): | |
| if not keep_cache: | |
| self.past_key_values = None | |
| self.attention_mask = None | |
| # Keep the mask when reusing cache so it stays aligned with past_key_values. | |
| # This allows concatenation with the next turn prefill mask. | |
| self._generated_tokens = [] | |
| self._is_stopping = None | |
| self._last_audio_tokens = None | |
| self._step_idx = 0 | |
| def _normalize_input_ids(self, input_ids): | |
| if isinstance(input_ids, torch.Tensor): | |
| input_ids = input_ids.detach().cpu().numpy() | |
| if isinstance(input_ids, np.ndarray): | |
| if input_ids.ndim == 2: | |
| return [input_ids] | |
| if input_ids.ndim == 3: | |
| return [input_ids[i] for i in range(input_ids.shape[0])] | |
| if isinstance(input_ids, (list, tuple)): | |
| return [np.array(item) for item in input_ids] | |
| raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].") | |
| def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]: | |
| if text_prefix_ids is None: | |
| raise ValueError("text_prefix_ids must be provided for prefill.") | |
| if isinstance(text_prefix_ids, torch.Tensor): | |
| text_prefix_ids = text_prefix_ids.detach().cpu().tolist() | |
| if isinstance(text_prefix_ids, np.ndarray): | |
| text_prefix_ids = text_prefix_ids.tolist() | |
| if isinstance(text_prefix_ids, list): | |
| if len(text_prefix_ids) == 0: | |
| return [[] for _ in range(batch_size)] | |
| if isinstance(text_prefix_ids[0], (int, np.integer)): | |
| return [list(text_prefix_ids)] | |
| if len(text_prefix_ids) == 1 and batch_size > 1: | |
| return [list(text_prefix_ids[0]) for _ in range(batch_size)] | |
| if len(text_prefix_ids) != batch_size: | |
| raise ValueError( | |
| f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}." | |
| ) | |
| return [list(item) for item in text_prefix_ids] | |
| raise ValueError("text_prefix_ids must be list-like or tensor-like.") | |
| def prefill( | |
| self, | |
| input_ids, | |
| text_prefix_ids, | |
| max_prefill_len: Optional[int] = None, | |
| past_key_values=None, | |
| device: Optional[torch.device] = None, | |
| temperature: float = 0.8, | |
| top_p: float = 0.6, | |
| top_k: int = 30, | |
| do_sample: bool = True, | |
| repetition_penalty: Optional[float] = 1.1, | |
| repetition_window: Optional[int] = 50, | |
| ) -> torch.Tensor: | |
| if device is None: | |
| device = self.device | |
| if past_key_values is not None: | |
| self.past_key_values = past_key_values | |
| input_ids_list = self._normalize_input_ids(input_ids) | |
| batch_size = len(input_ids_list) | |
| text_prefix_list = self._normalize_text_prefix(text_prefix_ids, batch_size) | |
| concat_inputs_id_list = [] | |
| for i in range(batch_size): | |
| prefix = text_prefix_list[i] | |
| if max_prefill_len is not None: | |
| prefix = prefix[:max_prefill_len] | |
| if len(prefix) == 0: | |
| raise ValueError("Prefill requires at least one text token.") | |
| text_seg = np.full((len(prefix), self.channels + 1), self.audio_channel_pad, dtype=np.int64) | |
| text_seg[:, 0] = np.array(prefix, dtype=np.int64) | |
| text_seg[len(prefix) - 1, 1] = self.audio_bos_token | |
| concat_inputs_id = np.concatenate([input_ids_list[i], text_seg], axis=0) | |
| concat_inputs_id_list.append(concat_inputs_id) | |
| attention_masks = [np.ones(ids.shape[0], dtype=np.bool_) for ids in concat_inputs_id_list] | |
| max_len = max(ids.shape[0] for ids in concat_inputs_id_list) | |
| padded_input_ids, padded_attns = [], [] | |
| pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.text_pad_id | |
| for ids, attn in zip(concat_inputs_id_list, attention_masks): | |
| pad_len = max_len - ids.shape[0] | |
| input_pad = np.full((pad_len, self.channels + 1), self.audio_channel_pad, dtype=np.int64) | |
| input_pad[:, 0] = pad_token_id | |
| padded_input_ids.append(np.concatenate([input_pad, ids])) | |
| attn_pad = np.zeros(pad_len, dtype=np.bool_) | |
| padded_attns.append(np.concatenate([attn_pad, attn])) | |
| current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device) | |
| current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device) | |
| # For multi-turn continuation, concatenate the cached mask and the current prefill mask. | |
| if self.attention_mask is not None and self.past_key_values is not None: | |
| current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1) | |
| outputs = self.model( | |
| input_ids=current_input_ids, | |
| attention_mask=current_attention_mask, | |
| past_key_values=self.past_key_values, | |
| use_cache=True, | |
| return_dict=True, | |
| ) | |
| self.past_key_values = outputs.past_key_values | |
| self.attention_mask = current_attention_mask | |
| backbone_hidden_states = outputs.last_hidden_state[:, -1:, :] | |
| audio_tokens = self.generate_local_transformer( | |
| hidden_states=backbone_hidden_states, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| repetition_penalty=repetition_penalty, | |
| repetition_window=repetition_window, | |
| generated_tokens=None, | |
| gen_step=0, | |
| ) | |
| self._generated_tokens = [audio_tokens] | |
| self._last_audio_tokens = audio_tokens | |
| self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token | |
| self._step_idx = 1 | |
| return audio_tokens | |
| def step( | |
| self, | |
| text_token: Optional[Iterable[int] | torch.Tensor | int], | |
| temperature: float = 0.8, | |
| top_p: float = 0.6, | |
| top_k: int = 30, | |
| do_sample: bool = True, | |
| repetition_penalty: Optional[float] = 1.1, | |
| repetition_window: Optional[int] = 50, | |
| ) -> torch.Tensor: | |
| if self._last_audio_tokens is None or self.attention_mask is None: | |
| raise ValueError("You must call prefill() before step().") | |
| if self.is_finished: | |
| return self._last_audio_tokens | |
| batch_size = self._last_audio_tokens.shape[0] | |
| if text_token is None: | |
| text_tokens = [self.text_pad_id] * batch_size | |
| elif isinstance(text_token, torch.Tensor): | |
| text_tokens = text_token.detach().cpu().tolist() | |
| elif isinstance(text_token, (list, tuple, np.ndarray)): | |
| text_tokens = list(text_token) | |
| else: | |
| text_tokens = [int(text_token)] | |
| if len(text_tokens) != batch_size: | |
| raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.") | |
| device = self._last_audio_tokens.device | |
| text_t = torch.tensor(text_tokens, device=device, dtype=torch.long) | |
| step_ids = torch.cat([text_t[:, None, None], self._last_audio_tokens.unsqueeze(1)], dim=2) | |
| self.attention_mask = torch.cat([self.attention_mask, (~self._is_stopping).unsqueeze(-1)], dim=-1) | |
| outputs = self.model( | |
| input_ids=step_ids, | |
| attention_mask=self.attention_mask, | |
| past_key_values=self.past_key_values, | |
| use_cache=True, | |
| return_dict=True, | |
| ) | |
| self.past_key_values = outputs.past_key_values | |
| backbone_hidden_states = outputs.last_hidden_state[:, -1:, :] | |
| history = torch.stack(self._generated_tokens, dim=1) if self._generated_tokens else None | |
| audio_tokens = self.generate_local_transformer( | |
| hidden_states=backbone_hidden_states, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| repetition_penalty=repetition_penalty, | |
| repetition_window=repetition_window, | |
| generated_tokens=history, | |
| gen_step=self._step_idx, | |
| ) | |
| self._generated_tokens.append(audio_tokens) | |
| self._last_audio_tokens = audio_tokens | |
| self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token | |
| self._step_idx += 1 | |
| return audio_tokens | |
| def finish( | |
| self, | |
| max_steps: Optional[int] = None, | |
| temperature: float = 0.8, | |
| top_p: float = 0.6, | |
| top_k: int = 30, | |
| do_sample: bool = True, | |
| repetition_penalty: Optional[float] = 1.1, | |
| repetition_window: Optional[int] = 50, | |
| ) -> list[torch.Tensor]: | |
| outputs = [] | |
| steps_left = max_steps if max_steps is not None else self.max_length | |
| while steps_left > 0 and not self.is_finished: | |
| outputs.append( | |
| self.step( | |
| text_token=None, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| repetition_penalty=repetition_penalty, | |
| repetition_window=repetition_window, | |
| ) | |
| ) | |
| steps_left -= 1 | |
| return outputs | |
| def generate_local_transformer( | |
| self, | |
| hidden_states: torch.Tensor, | |
| temperature: float, | |
| top_p: float, | |
| top_k: int, | |
| do_sample: bool, | |
| repetition_penalty: Optional[float], | |
| repetition_window: Optional[int], | |
| generated_tokens: Optional[torch.Tensor], | |
| gen_step: int, | |
| ) -> torch.Tensor: | |
| batch_size = hidden_states.shape[0] | |
| device = hidden_states.device | |
| local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size) | |
| output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device) | |
| past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels) | |
| local_token = None | |
| cache_pos_t = torch.zeros(1, dtype=torch.long, device=device) | |
| for i in range(self.channels): | |
| cache_pos_t.fill_(i) | |
| local_outputs = self.model.local_transformer( | |
| input_ids=local_token, | |
| inputs_embeds=local_inputs, | |
| past_key_values=past_key_values, | |
| cache_position=cache_pos_t, | |
| codebook_idx=i, | |
| use_cache=True, | |
| logits_to_keep=1, | |
| ) | |
| logits = local_outputs.logits | |
| if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None: | |
| logits = self.apply_repetition_penalty( | |
| scores=logits, | |
| history_tokens=generated_tokens[:, :gen_step, i], | |
| penalty=float(repetition_penalty), | |
| repetition_window=repetition_window, | |
| ) | |
| local_token = self.sample_token( | |
| logits=logits, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| ) | |
| output_token[:, i] = local_token.squeeze(-1) | |
| if i == 0: | |
| local_inputs = None | |
| return output_token | |
| def apply_repetition_penalty( | |
| self, | |
| scores: torch.Tensor, | |
| history_tokens: torch.Tensor, | |
| penalty: float = 1.1, | |
| repetition_window: Optional[int] = None, | |
| ): | |
| scores_ = scores[:, 0, :] | |
| ht = history_tokens | |
| if repetition_window is not None and repetition_window > 0: | |
| ht = ht[:, -repetition_window:] | |
| cur = scores_.gather(1, ht) | |
| new = torch.where(cur < 0, cur * penalty, cur / penalty) | |
| scores_.scatter_(1, ht, new) | |
| return scores_ | |
| def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True): | |
| if not do_sample or temperature == 0: | |
| return torch.argmax(logits, dim=-1) | |
| logits = logits / temperature | |
| original_shape = logits.shape | |
| vocab_size = original_shape[-1] | |
| reshaped_logits = logits.reshape(-1, vocab_size) | |
| if top_k is not None: | |
| reshaped_logits = self.apply_top_k(reshaped_logits, top_k) | |
| if top_p is not None: | |
| reshaped_logits = self.apply_top_p(reshaped_logits, top_p) | |
| probs = F.softmax(reshaped_logits, dim=-1) | |
| next_tokens_flat = torch.multinomial(probs, num_samples=1) | |
| output_shape = original_shape[:-1] | |
| return next_tokens_flat.view(output_shape) | |
| def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1): | |
| if not isinstance(top_k, int) or top_k <= 0: | |
| raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") | |
| batch_size, vocab_size = logits.shape | |
| top_k = max(top_k, min_tokens_to_keep) | |
| top_k = min(top_k, vocab_size) | |
| indices_to_remove = torch.topk(logits, top_k, dim=-1).values[..., -1, None] | |
| return logits.masked_fill(logits < indices_to_remove, filter_value) | |
| def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1): | |
| top_p = float(top_p) | |
| if top_p < 0 or top_p > 1.0: | |
| raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") | |
| sorted_logits, sorted_indices = torch.sort(logits, descending=False) | |
| cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) | |
| sorted_indices_to_remove = cumulative_probs <= (1 - top_p) | |
| sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0 | |
| indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(1, sorted_indices, sorted_indices_to_remove) | |
| logits_processed = logits.masked_fill(indices_to_remove, filter_value) | |
| return logits_processed | |
| class MossTTSRealtimeStreamingSession: | |
| """Manage text-to-audio streaming for a single conversation.""" | |
| _split_pattern = re.compile( | |
| r"[。!?!?\.\u2026]\s*" # sentence boundaries: 。!? ! ? . … | |
| r"|[,,;;::\u2014\u2013\-]\s*" # short pauses: , , ; ; : : — – - | |
| r"|\)\s*|\]\s*" # closing brackets: ) ] | |
| r"|\n" | |
| ) | |
| def __init__( | |
| self, | |
| inferencer: MossTTSRealtimeInference, | |
| processor, | |
| codec=None, | |
| codec_sample_rate: int = 24000, | |
| codec_encode_kwargs: Optional[dict] = None, | |
| prefill_text_len: int = 12, | |
| text_buffer_size: int = 32, | |
| min_text_chunk_chars: int = 8, | |
| temperature: float = 0.8, | |
| top_p: float = 0.6, | |
| top_k: int = 30, | |
| do_sample: bool = True, | |
| repetition_penalty: Optional[float] = 1.1, | |
| repetition_window: Optional[int] = 50, | |
| ): | |
| self.inferencer = inferencer | |
| self.processor = processor | |
| self.tokenizer = processor.tokenizer | |
| self.codec = codec | |
| self.codec_sample_rate = codec_sample_rate | |
| self.codec_encode_kwargs = codec_encode_kwargs or {} | |
| self.prefill_text_len = prefill_text_len | |
| self.text_buffer_size = text_buffer_size | |
| self.min_text_chunk_chars = min_text_chunk_chars | |
| self.temperature = temperature | |
| self.top_p = top_p | |
| self.top_k = top_k | |
| self.do_sample = do_sample | |
| self.repetition_penalty = repetition_penalty | |
| self.repetition_window = repetition_window | |
| self._voice_prompt_tokens = None | |
| self._turn_input_ids = None | |
| self._turn_idx = 0 | |
| self._text_cache = "" | |
| self._pending_tokens: list[int] = [] | |
| self._prefilled = False | |
| self._text_ended = False | |
| def set_voice_prompt_tokens(self, audio_tokens: np.ndarray): | |
| self._voice_prompt_tokens = audio_tokens | |
| def set_voice_prompt(self, audio, sample_rate: Optional[int] = None): | |
| """Set voice prompt from either audio tokens or waveform. | |
| If `audio` is a 2D array whose shape matches the codebook channels, it is | |
| treated as audio tokens. Otherwise a codec is required to encode waveform | |
| prompts into tokens. | |
| """ | |
| if isinstance(audio, np.ndarray) and audio.ndim == 2: | |
| if self.processor.channels in audio.shape: | |
| self._voice_prompt_tokens = audio | |
| return | |
| if isinstance(audio, torch.Tensor) and audio.dim() == 2: | |
| if self.processor.channels in audio.shape: | |
| self._voice_prompt_tokens = audio.detach().cpu().numpy() | |
| return | |
| if self.codec is None: | |
| raise ValueError("codec is required to encode waveform prompts.") | |
| waveform = audio | |
| if isinstance(audio, (str, bytes)): | |
| requires_backends(self, ["torchaudio"]) | |
| wav, sr = torchaudio.load(audio) | |
| if wav.shape[0] > 1: | |
| wav = wav.mean(dim=0, keepdim=True) | |
| waveform = wav.squeeze(0) | |
| sample_rate = sr | |
| if isinstance(waveform, np.ndarray): | |
| waveform = torch.from_numpy(waveform) | |
| if not isinstance(waveform, torch.Tensor): | |
| raise ValueError("Unsupported audio type for voice prompt.") | |
| if sample_rate is not None and sample_rate != self.codec_sample_rate: | |
| requires_backends(self, ["torchaudio"]) | |
| waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate) | |
| waveform = waveform.to(self.inferencer.device) | |
| encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs) | |
| if isinstance(encode_out, dict): | |
| if "codes_list" in encode_out: | |
| tokens = encode_out["codes_list"][0] | |
| elif "audio_codes" in encode_out: | |
| tokens = encode_out["audio_codes"][0] | |
| else: | |
| raise ValueError("codec.encode output missing audio codes.") | |
| else: | |
| tokens = encode_out | |
| if isinstance(tokens, torch.Tensor): | |
| tokens = tokens.detach().cpu().numpy() | |
| self._voice_prompt_tokens = tokens | |
| def clear_voice_prompt(self): | |
| self._voice_prompt_tokens = None | |
| def reset_turn( | |
| self, | |
| user_text: Optional[str] = None, | |
| user_audio_tokens: Optional[np.ndarray] = None, | |
| input_ids: Optional[np.ndarray] = None, | |
| include_system_prompt: Optional[bool] = None, | |
| reset_cache: bool = False, | |
| ): | |
| if include_system_prompt is None: | |
| include_system_prompt = self._turn_idx == 0 | |
| if input_ids is None: | |
| if user_text is None or user_audio_tokens is None: | |
| raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.") | |
| user_prompt = self.processor.make_user_prompt(user_text, user_audio_tokens) | |
| if include_system_prompt: | |
| system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens) | |
| input_ids = np.concatenate([system_prompt, user_prompt], axis=0) | |
| else: | |
| input_ids = user_prompt | |
| self._turn_input_ids = input_ids | |
| self._turn_idx += 1 | |
| self._text_cache = "" | |
| self._pending_tokens = [] | |
| self._prefilled = False | |
| self._text_ended = False | |
| self.inferencer.reset_generation_state(keep_cache=not reset_cache) | |
| def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]: | |
| self._pending_tokens.extend([int(t) for t in tokens]) | |
| return self._drain_pending_tokens() | |
| def push_text(self, text_fragment: str) -> list[torch.Tensor]: | |
| self._text_cache += text_fragment | |
| segments = self._extract_text_segments(force=False) | |
| for segment in segments: | |
| self._pending_tokens.extend(self._tokenize(segment)) | |
| return self._drain_pending_tokens() | |
| def end_text(self) -> list[torch.Tensor]: | |
| self._text_ended = True | |
| if self._text_cache: | |
| self._pending_tokens.extend(self._tokenize(self._text_cache)) | |
| self._text_cache = "" | |
| return self._drain_pending_tokens() | |
| def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]: | |
| if not self._prefilled: | |
| return [] | |
| return self.inferencer.finish( | |
| max_steps=max_steps, | |
| temperature=self.temperature, | |
| top_p=self.top_p, | |
| top_k=self.top_k, | |
| do_sample=self.do_sample, | |
| repetition_penalty=self.repetition_penalty, | |
| repetition_window=self.repetition_window, | |
| ) | |
| def _tokenize(self, text: str) -> list[int]: | |
| return self.tokenizer.encode(text, add_special_tokens=False) | |
| def _extract_text_segments(self, force: bool) -> list[str]: | |
| segments = [] | |
| if force: | |
| if self._text_cache: | |
| segments.append(self._text_cache) | |
| self._text_cache = "" | |
| return segments | |
| while self._text_cache: | |
| cut_idx = None | |
| if len(self._text_cache) >= self.min_text_chunk_chars: | |
| matches = list(self._split_pattern.finditer(self._text_cache)) | |
| for match in matches: | |
| if match.end() >= self.min_text_chunk_chars: | |
| cut_idx = match.end() | |
| break | |
| if cut_idx is None and len(self._text_cache) >= self.text_buffer_size: | |
| whitespace_idx = self._text_cache.rfind(" ") | |
| if whitespace_idx != -1: | |
| cut_idx = whitespace_idx + 1 | |
| if cut_idx is None: | |
| break | |
| segments.append(self._text_cache[:cut_idx]) | |
| self._text_cache = self._text_cache[cut_idx:] | |
| return segments | |
| def _prefill_if_needed(self) -> list[torch.Tensor]: | |
| if self._prefilled: | |
| return [] | |
| if not self._pending_tokens and not self._text_ended: | |
| return [] | |
| if len(self._pending_tokens) < self.prefill_text_len and not self._text_ended: | |
| return [] | |
| if self._turn_input_ids is None: | |
| raise ValueError("reset_turn must be called before streaming text.") | |
| if self._text_ended: | |
| prefill_len = len(self._pending_tokens) | |
| else: | |
| prefill_len = min(len(self._pending_tokens), self.prefill_text_len) | |
| if prefill_len == 0: | |
| return [] | |
| prefix_tokens = [self._pending_tokens.pop(0) for _ in range(prefill_len)] | |
| audio_tokens = self.inferencer.prefill( | |
| input_ids=[self._turn_input_ids], | |
| text_prefix_ids=[prefix_tokens], | |
| temperature=self.temperature, | |
| top_p=self.top_p, | |
| top_k=self.top_k, | |
| do_sample=self.do_sample, | |
| repetition_penalty=None, | |
| repetition_window=self.repetition_window, | |
| ) | |
| self._prefilled = True | |
| return [audio_tokens] | |
| def _drain_pending_tokens(self) -> list[torch.Tensor]: | |
| outputs: list[torch.Tensor] = [] | |
| outputs.extend(self._prefill_if_needed()) | |
| if not self._prefilled: | |
| return outputs | |
| while self._pending_tokens and not self.inferencer.is_finished: | |
| token = self._pending_tokens.pop(0) | |
| outputs.append( | |
| self.inferencer.step( | |
| token, | |
| temperature=self.temperature, | |
| top_p=self.top_p, | |
| top_k=self.top_k, | |
| do_sample=self.do_sample, | |
| repetition_penalty=self.repetition_penalty, | |
| repetition_window=self.repetition_window, | |
| ) | |
| ) | |
| return outputs | |
| class AudioStreamDecoder: | |
| """Decode audio tokens into waveform chunks with optional crossfade.""" | |
| def __init__( | |
| self, | |
| codec, | |
| chunk_frames: int = 40, | |
| overlap_frames: int = 4, | |
| decode_kwargs: Optional[dict] = None, | |
| device: Optional[torch.device] = None, | |
| ): | |
| self.codec = codec | |
| self.chunk_frames = chunk_frames | |
| self.overlap_frames = overlap_frames | |
| self.decode_kwargs = decode_kwargs or {} | |
| self.device = device | |
| self._buffer: list[torch.Tensor] = [] | |
| self._buffer_len = 0 | |
| self._prev_tail: Optional[torch.Tensor] = None | |
| def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor): | |
| if isinstance(audio_tokens, np.ndarray): | |
| audio_tokens = torch.from_numpy(audio_tokens) | |
| if audio_tokens.dim() != 2: | |
| raise ValueError(f"Expected [T, C] audio tokens, got {tuple(audio_tokens.shape)}") | |
| self._buffer.append(audio_tokens) | |
| self._buffer_len += audio_tokens.shape[0] | |
| def audio_chunks(self) -> Iterable[torch.Tensor]: | |
| while self._buffer_len >= self.chunk_frames: | |
| chunk_tokens = self._consume_frames(self.chunk_frames) | |
| wav = self._decode(chunk_tokens, chunk_duration=0.32) | |
| yield self._apply_crossfade(wav) | |
| def flush(self) -> Optional[torch.Tensor]: | |
| if self._buffer_len == 0: | |
| return None | |
| chunk_tokens = self._consume_frames(self._buffer_len) | |
| wav = self._decode(chunk_tokens) | |
| return self._apply_crossfade(wav, final_chunk=True) | |
| def _consume_frames(self, num_frames: int) -> torch.Tensor: | |
| frames = [] | |
| remaining = num_frames | |
| while remaining > 0 and self._buffer: | |
| head = self._buffer[0] | |
| if head.shape[0] <= remaining: | |
| frames.append(head) | |
| remaining -= head.shape[0] | |
| self._buffer.pop(0) | |
| else: | |
| frames.append(head[:remaining]) | |
| self._buffer[0] = head[remaining:] | |
| remaining = 0 | |
| self._buffer_len -= num_frames - remaining | |
| return torch.cat(frames, dim=0) | |
| def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor: | |
| device = self.device | |
| if device is None: | |
| if hasattr(self.codec, "device"): | |
| device = self.codec.device | |
| else: | |
| try: | |
| device = next(self.codec.parameters()).device | |
| except Exception: | |
| device = None | |
| if device is not None: | |
| tokens = tokens.to(device) | |
| tokens_t = tokens.permute(1, 0) | |
| # allow callers to override decode settings (e.g. chunk_duration=-1 to disable internal streaming) | |
| decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {} | |
| if "chunk_duration" in decode_kwargs: | |
| override = decode_kwargs.pop("chunk_duration") | |
| if override is None: | |
| chunk_duration_arg = None | |
| else: | |
| try: | |
| override_f = float(override) | |
| except Exception: | |
| override_f = None | |
| chunk_duration_arg = None if override_f is None or override_f <= 0 else override_f | |
| else: | |
| chunk_duration_arg = chunk_duration | |
| decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs) | |
| if isinstance(decoded, dict): | |
| wav = decoded["audio"][0] | |
| else: | |
| wav = decoded | |
| if isinstance(wav, np.ndarray): | |
| wav = torch.from_numpy(wav) | |
| if wav.dim() > 1: | |
| wav = wav.squeeze(0) | |
| return wav | |
| def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor: | |
| if self.overlap_frames <= 0: | |
| return wav | |
| if self._prev_tail is None: | |
| self._prev_tail = wav[-self._overlap_samples(wav) :].clone() if not final_chunk else None | |
| return wav | |
| overlap = self._overlap_samples(wav) | |
| if overlap == 0: | |
| return wav | |
| prev_tail = self._prev_tail | |
| if prev_tail.numel() < overlap: | |
| overlap = prev_tail.numel() | |
| if overlap == 0: | |
| return wav | |
| fade_out = torch.linspace(1.0, 0.0, overlap, device=wav.device) | |
| fade_in = 1.0 - fade_out | |
| cross = prev_tail[-overlap:] * fade_out + wav[:overlap] * fade_in | |
| merged = torch.cat([prev_tail[:-overlap], cross, wav[overlap:]], dim=-1) | |
| self._prev_tail = None if final_chunk else wav[-overlap:].clone() | |
| return merged | |
| def _overlap_samples(self, wav: torch.Tensor) -> int: | |
| if self.chunk_frames <= 0: | |
| return 0 | |
| return int(wav.numel() * (self.overlap_frames / self.chunk_frames)) | |
| class TextDeltaTokenizer: | |
| """ | |
| Convert LLM streaming text (delta) into “incremental token IDs”. | |
| Notes: | |
| - The input is a delta that is progressively appended to the same string | |
| (consistent with the common delta output behavior in vLLM). | |
| - Each time, re-encode the *full text* with the tokenizer, then take only | |
| the newly added token IDs. | |
| - This guarantees that tokenization is consistent with the final complete | |
| text, avoiding boundary mismatches caused by tokenizing partial segments. | |
| """ | |
| def __init__(self, tokenizer, *, hold_back: int = 3): | |
| self.tokenizer = tokenizer | |
| self.hold_back = max(0, int(hold_back)) | |
| self._text = "" | |
| self._all_ids: list[int] = [] | |
| self._emitted_count: int = 0 | |
| def text(self) -> str: | |
| return self._text | |
| def token_ids(self) -> list[int]: | |
| return list(self._all_ids) | |
| def push_delta(self, delta: str) -> list[int]: | |
| """Append a text delta and return newly stable token ids (may be empty).""" | |
| if not delta: | |
| return [] | |
| self._text += str(delta) | |
| self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False) | |
| # Keep the tail un-emitted because the latest tokens can still change. | |
| stable_count = max(self._emitted_count, len(self._all_ids) - self.hold_back) | |
| new_ids = self._all_ids[self._emitted_count : stable_count] | |
| self._emitted_count = stable_count | |
| return new_ids | |
| def flush(self) -> list[int]: | |
| """Emit all remaining token ids at end of stream.""" | |
| self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False) | |
| remaining = self._all_ids[self._emitted_count :] | |
| self._emitted_count = len(self._all_ids) | |
| return remaining | |
| def _sanitize_audio_tokens( | |
| tokens: torch.Tensor, | |
| *, | |
| codebook_size: int, | |
| audio_eos_token: int, | |
| ) -> tuple[torch.Tensor, bool]: | |
| """Trim rows after EOS/invalid tokens and return whether decoding should stop.""" | |
| if tokens.dim() == 1: | |
| tokens = tokens.unsqueeze(0) | |
| if tokens.numel() == 0: | |
| return tokens, False | |
| eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False) | |
| invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1) | |
| stop_idx = None | |
| if eos_rows.numel() > 0: | |
| stop_idx = int(eos_rows[0].item()) | |
| if invalid_rows.any(): | |
| invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item()) | |
| stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx) | |
| if stop_idx is not None: | |
| return tokens[:stop_idx], True | |
| return tokens, False | |
| def _maybe_codec_streaming(codec, *, batch_size: int): | |
| if codec is None or not hasattr(codec, "streaming"): | |
| return contextlib.nullcontext() | |
| return codec.streaming(batch_size=batch_size) | |
| class MossTTSRealtimeTextStreamBridge: | |
| """ | |
| Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks. | |
| Usage overview: | |
| - First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`). | |
| - Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via | |
| `push_text_delta()`. | |
| - Once the accumulated token count reaches `prefill_text_len`, the session will | |
| start generating audio tokens; the bridge will immediately decode them into WAV | |
| chunks and yield them. | |
| """ | |
| def __init__( | |
| self, | |
| session: MossTTSRealtimeStreamingSession, | |
| decoder: AudioStreamDecoder, | |
| *, | |
| codebook_size: Optional[int] = None, | |
| audio_eos_token: Optional[int] = None, | |
| batch_size: int = 1, | |
| ): | |
| self.session = session | |
| self.decoder = decoder | |
| self.batch_size = int(batch_size) | |
| if codebook_size is None: | |
| codebook_size = int(getattr(getattr(session, "codec", None), "codebook_size", 1024)) | |
| if audio_eos_token is None: | |
| audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026)) | |
| self.codebook_size = int(codebook_size) | |
| self.audio_eos_token = int(audio_eos_token) | |
| def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]: | |
| """ | |
| Push a chunk of incremental text output from the LLM and return newly generated WAV chunks. | |
| Internally, this directly calls `session.push_text()`, which segments the text | |
| based on punctuation/length and then tokenizes the *entire segment* at once, | |
| avoiding the prefix instability issues of incremental BPE tokenization. | |
| """ | |
| audio_frames = self.session.push_text(delta) | |
| yield from self._decode_audio_frames(audio_frames) | |
| def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]: | |
| """Push token ids directly (for sources that stream token ids).""" | |
| if not token_ids: | |
| return | |
| audio_frames = self.session.push_text_tokens(token_ids) | |
| yield from self._decode_audio_frames(audio_frames) | |
| def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]: | |
| """Mark text stream end and emit all remaining audio chunks (including flush).""" | |
| audio_frames = self.session.end_text() | |
| yield from self._decode_audio_frames(audio_frames) | |
| while True: | |
| more_frames = self.session.drain(max_steps=drain_step) | |
| if not more_frames: | |
| break | |
| yield from self._decode_audio_frames(more_frames) | |
| if self.session.inferencer.is_finished: | |
| break | |
| final = self.decoder.flush() | |
| if final is not None and final.numel() > 0: | |
| yield final.detach().cpu() | |
| def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]: | |
| """Consume a full delta iterator and continuously yield waveform chunks.""" | |
| with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size): | |
| for delta in deltas: | |
| yield from self.push_text_delta(delta) | |
| yield from self.finish(drain_step=drain_step) | |
| def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]: | |
| for frame in audio_frames: | |
| tokens = frame | |
| if tokens.dim() == 3: | |
| tokens = tokens[0] | |
| if tokens.dim() != 2: | |
| raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}") | |
| if tokens.shape[0] != 1: | |
| raise ValueError( | |
| f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}." | |
| ) | |
| tokens, stop = _sanitize_audio_tokens( | |
| tokens, | |
| codebook_size=self.codebook_size, | |
| audio_eos_token=self.audio_eos_token, | |
| ) | |
| if tokens.numel() == 0: | |
| if stop: | |
| break | |
| continue | |
| self.decoder.push_tokens(tokens.detach()) | |
| for wav in self.decoder.audio_chunks(): | |
| if wav.numel() == 0: | |
| continue | |
| yield wav.detach().cpu() | |
| if stop: | |
| break | |
| __all__ = [ | |
| "AudioStreamDecoder", | |
| "MossTTSRealtimeInference", | |
| "MossTTSRealtimeStreamingSession", | |
| "MossTTSRealtimeTextStreamBridge", | |
| "TextDeltaTokenizer", | |
| ] | |