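"""Gradio demo for MOSS-TTSD: multi-speaker dialogue speech synthesis.

Dialogue text uses [S1]..[S5] speaker tags. Each speaker may optionally be
given reference audio plus matching prompt text for voice cloning; without
any reference audio the model runs in plain generation mode, with reference
audio it runs in voice-clone continuation mode.

Typical local launch (flags are defined in main() below):

    python app.py --device cuda:0 --preload_backend
"""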
import argparse
import functools
import importlib.util
import os
import re
import time
from pathlib import Path
from typing import Optional
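# `spaces` provides the Hugging Face Spaces ZeroGPU decorator. When the app runs
# outside Spaces the import fails, so fall back to a no-op `GPU` decorator.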
try:
import spaces
except ImportError:
class _SpacesFallback:
@staticmethod
def GPU(*_args, **_kwargs):
def _decorator(func):
return func
return _decorator
spaces = _SpacesFallback()
import gradio as gr
import numpy as np
import torch
import soundfile as sf
from transformers import AutoModel, AutoProcessor
# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
MODEL_PATH = "OpenMOSS-Team/MOSS-TTSD-v1.0"
CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
DEFAULT_ATTN_IMPLEMENTATION = "auto"
DEFAULT_MAX_NEW_TOKENS = 2000
MIN_SPEAKERS = 1
MAX_SPEAKERS = 5
def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None:
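    """Resolve the attention implementation to request from `from_pretrained`.

    "none" disables the override, any other explicit value is passed through
    unchanged, and ""/"auto" auto-selects: flash_attention_2 when flash-attn is
    installed on an Ampere-or-newer GPU with fp16/bf16, otherwise "sdpa" on
    CUDA, otherwise "eager" on CPU.
    """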
requested_norm = (requested or "").strip().lower()
    if requested_norm == "none":
        return None
if requested_norm not in {"", "auto"}:
return requested
# Prefer FlashAttention 2 when package + device conditions are met.
if (
device.type == "cuda"
and importlib.util.find_spec("flash_attn") is not None
and dtype in {torch.float16, torch.bfloat16}
):
major, _ = torch.cuda.get_device_capability(device)
if major >= 8:
return "flash_attention_2"
# CUDA fallback: use PyTorch SDPA kernels.
if device.type == "cuda":
return "sdpa"
# CPU fallback.
return "eager"
@functools.lru_cache(maxsize=1)
def load_backend(model_path: str, codec_path: str, device_str: str, attn_implementation: str):
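    """Load and cache the TTSD model, processor, and audio tokenizer.

    lru_cache(maxsize=1) keeps a single backend alive, so repeated calls with
    the same arguments reuse it. Falls back to CPU/float32 when CUDA is not
    available. Returns (model, processor, device, sample_rate).
    """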
device = torch.device(device_str if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
resolved_attn_implementation = resolve_attn_implementation(
requested=attn_implementation,
device=device,
dtype=dtype,
)
processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=True,
codec_path=codec_path,
)
if hasattr(processor, "audio_tokenizer"):
processor.audio_tokenizer = processor.audio_tokenizer.to(device)
processor.audio_tokenizer.eval()
model_kwargs = {
"trust_remote_code": True,
"torch_dtype": dtype,
}
if resolved_attn_implementation:
model_kwargs["attn_implementation"] = resolved_attn_implementation
model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device)
model.eval()
sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000))
return model, processor, device, sample_rate
def _resample_wav(wav: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
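    """Resample a (channels, samples) waveform via linear interpolation.

    This is a simple approximation rather than a band-limited resampler.
    """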
if int(orig_sr) == int(target_sr):
return wav
new_num_samples = int(round(wav.shape[-1] * float(target_sr) / float(orig_sr)))
if new_num_samples <= 0:
raise ValueError(f"Invalid resample length from {orig_sr}Hz to {target_sr}Hz.")
return torch.nn.functional.interpolate(
wav.unsqueeze(0),
size=new_num_samples,
mode="linear",
align_corners=False,
).squeeze(0)
def _load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
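    """Load reference audio, downmix to mono, and return a (1, samples) float tensor plus its sample rate."""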
path = Path(audio_path).expanduser()
if not path.exists():
raise FileNotFoundError(f"Reference audio not found: {path}")
wav_np, sr = sf.read(path, dtype="float32", always_2d=True)
if wav_np.size == 0:
raise ValueError(f"Reference audio is empty: {path}")
if wav_np.shape[1] > 1:
wav_np = wav_np.mean(axis=1, keepdims=True)
wav = torch.from_numpy(wav_np.T)
return wav, int(sr)
def normalize_text(text: str) -> str:
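    """Normalize dialogue text before synthesis.

    Rewrites bare [N] tags to [SN], strips decorative brackets/quotes/dashes,
    maps laughter ("哈哈", "haha") to [笑]/[laugh], turns ellipses and dashes
    into commas, collapses repeated punctuation, fixes the final punctuation
    of each segment, and merges consecutive segments with the same tag.
    """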
text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
remove_chars = "【】《》()『』「」" '"-_“”~~‘’'
segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
processed_parts = []
for seg in segments:
seg = seg.strip()
if not seg:
continue
matched = re.match(r"^(\[S\d+\])\s*(.*)", seg)
tag, content = matched.groups() if matched else ("", seg)
content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
content = re.sub(r"哈{2,}", "[笑]", content)
content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE)
content = content.replace("——", ",")
content = content.replace("……", ",")
content = content.replace("...", ",")
content = content.replace("⸺", ",")
content = content.replace("―", ",")
content = content.replace("—", ",")
content = content.replace("…", ",")
internal_punct_map = str.maketrans(
{";": ",", ";": ",", ":": ",", ":": ",", "、": ","}
)
content = content.translate(internal_punct_map)
content = content.strip()
content = re.sub(r"([,。?!,.?!])[,。?!,.?!]+", r"\1", content)
if len(content) > 1:
last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
body = content[:-1].replace("。", ",")
content = body + last_ch
processed_parts.append({"tag": tag, "content": content})
if not processed_parts:
return ""
merged_lines = []
current_tag = processed_parts[0]["tag"]
current_content = [processed_parts[0]["content"]]
for part in processed_parts[1:]:
if part["tag"] == current_tag and current_tag:
current_content.append(part["content"])
else:
merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
current_tag = part["tag"]
current_content = [part["content"]]
merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
return "".join(merged_lines).replace("‘", "'").replace("’", "'")
def _validate_dialogue_text(dialogue_text: str, speaker_count: int) -> str:
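    """Check that the dialogue has [S#] tags and that no tag exceeds the speaker count."""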
text = (dialogue_text or "").strip()
if not text:
raise ValueError("Please enter dialogue text.")
tags = re.findall(r"\[S(\d+)\]", text)
if not tags:
raise ValueError("Dialogue must include speaker tags like [S1], [S2], ...")
max_tag = max(int(t) for t in tags)
if max_tag > speaker_count:
raise ValueError(
f"Dialogue contains [S{max_tag}], but speaker count is set to {speaker_count}."
)
return text
def update_speaker_panels(speaker_count: int):
count = int(speaker_count)
count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, count))
return [gr.update(visible=(idx < count)) for idx in range(MAX_SPEAKERS)]
def _merge_consecutive_speaker_tags(text: str) -> str:
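    """Merge adjacent segments that share the same [S#] tag into one turn."""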
segments = re.split(r"(?=\[S\d+\])", text)
if not segments:
return text
merged_parts = []
current_tag = None
for seg in segments:
seg = seg.strip()
if not seg:
continue
matched = re.match(r"^(\[S\d+\])\s*(.*)", seg, re.DOTALL)
if not matched:
merged_parts.append(seg)
continue
tag, content = matched.groups()
if tag == current_tag:
merged_parts.append(content)
else:
current_tag = tag
merged_parts.append(f"{tag}{content}")
return "".join(merged_parts)
def _normalize_prompt_text(prompt_text: str, speaker_id: int) -> str:
text = (prompt_text or "").strip()
if not text:
raise ValueError(f"S{speaker_id} prompt text is empty.")
expected_tag = f"[S{speaker_id}]"
if not text.lstrip().startswith(expected_tag):
text = f"{expected_tag} {text}"
return text
def _build_prefixed_text(
dialogue_text: str,
prompt_text_map: dict[int, str],
cloned_speakers: list[int],
) -> str:
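    """Prefix each cloned speaker's prompt text to the dialogue and merge duplicate tags."""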
prompt_prefix = "".join([prompt_text_map[speaker_id] for speaker_id in cloned_speakers])
return _merge_consecutive_speaker_tags(prompt_prefix + dialogue_text)
def _encode_reference_audio_codes(
processor,
clone_wavs: list[torch.Tensor],
cloned_speakers: list[int],
speaker_count: int,
sample_rate: int,
) -> list[Optional[torch.Tensor]]:
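    """Encode reference waveforms into codec tokens, placed at each speaker's slot (None when missing)."""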
encoded_list = processor.encode_audios_from_wav(clone_wavs, sampling_rate=sample_rate)
reference_audio_codes: list[Optional[torch.Tensor]] = [None for _ in range(speaker_count)]
for speaker_id, audio_codes in zip(cloned_speakers, encoded_list):
reference_audio_codes[speaker_id - 1] = audio_codes
return reference_audio_codes
def build_conversation(
dialogue_text: str,
reference_audio_codes: list[Optional[torch.Tensor]],
prompt_audio: torch.Tensor | None,
processor,
):
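    """Build the processor conversation and select the run mode.

    Without prompt audio: a single user message in "generation" mode. With
    prompt audio: the user message carries per-speaker reference codes and is
    followed by an assistant message holding the concatenated prompt audio,
    in "continuation" (voice clone and continuation) mode.
    """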
if prompt_audio is None:
        return [[processor.build_user_message(text=dialogue_text)]], "generation", "generation"
user_message = processor.build_user_message(
text=dialogue_text,
reference=reference_audio_codes,
)
return (
[
[
user_message,
processor.build_assistant_message(audio_codes_list=[prompt_audio]),
],
],
"continuation",
"voice_clone_and_continuation",
)
@spaces.GPU(duration=180)
def run_inference(speaker_count: int, *all_inputs):
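    """Run one synthesis request from the Gradio inputs.

    Unpacks per-speaker reference audio and prompt text, normalizes and
    validates the dialogue, encodes any reference audio, generates audio,
    and returns ((sample_rate, waveform), status_string).
    """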
speaker_count = int(speaker_count)
speaker_count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, speaker_count))
reference_audio_values = all_inputs[:MAX_SPEAKERS]
prompt_text_values = all_inputs[MAX_SPEAKERS : 2 * MAX_SPEAKERS]
dialogue_text = all_inputs[2 * MAX_SPEAKERS]
    (text_normalize, sample_rate_normalize, temperature, top_p, top_k,
     repetition_penalty, max_new_tokens, model_path, codec_path, device,
     attn_implementation) = all_inputs[2 * MAX_SPEAKERS + 1 :]
started_at = time.monotonic()
model, processor, torch_device, sample_rate = load_backend(
model_path=str(model_path),
codec_path=str(codec_path),
device_str=str(device),
attn_implementation=str(attn_implementation),
)
text_normalize = bool(text_normalize)
sample_rate_normalize = bool(sample_rate_normalize)
normalized_dialogue = str(dialogue_text or "").strip()
if text_normalize:
normalized_dialogue = normalize_text(normalized_dialogue)
normalized_dialogue = _validate_dialogue_text(normalized_dialogue, speaker_count)
cloned_speakers: list[int] = []
loaded_clone_wavs: list[tuple[torch.Tensor, int]] = []
prompt_text_map: dict[int, str] = {}
for idx in range(speaker_count):
ref_audio = reference_audio_values[idx]
prompt_text = str(prompt_text_values[idx] or "").strip()
has_reference = bool(ref_audio)
has_prompt_text = bool(prompt_text)
if has_reference != has_prompt_text:
raise ValueError(
f"S{idx + 1} must provide both reference audio and prompt text together."
)
if has_reference:
speaker_id = idx + 1
ref_audio_path = str(ref_audio)
cloned_speakers.append(speaker_id)
loaded_clone_wavs.append(_load_audio(ref_audio_path))
prompt_text_map[speaker_id] = _normalize_prompt_text(prompt_text, speaker_id)
prompt_audio: Optional[torch.Tensor] = None
reference_audio_codes: list[Optional[torch.Tensor]] = []
conversation_text = normalized_dialogue
if cloned_speakers:
conversation_text = _build_prefixed_text(
dialogue_text=normalized_dialogue,
prompt_text_map=prompt_text_map,
cloned_speakers=cloned_speakers,
)
if text_normalize:
conversation_text = normalize_text(conversation_text)
conversation_text = _validate_dialogue_text(conversation_text, speaker_count)
if sample_rate_normalize:
min_sr = min(sr for _, sr in loaded_clone_wavs)
else:
min_sr = None
clone_wavs: list[torch.Tensor] = []
for wav, orig_sr in loaded_clone_wavs:
processed_wav = wav
current_sr = int(orig_sr)
if min_sr is not None:
processed_wav = _resample_wav(processed_wav, current_sr, int(min_sr))
current_sr = int(min_sr)
processed_wav = _resample_wav(processed_wav, current_sr, sample_rate)
clone_wavs.append(processed_wav)
reference_audio_codes = _encode_reference_audio_codes(
processor=processor,
clone_wavs=clone_wavs,
cloned_speakers=cloned_speakers,
speaker_count=speaker_count,
sample_rate=sample_rate,
)
concat_prompt_wav = torch.cat(clone_wavs, dim=-1)
prompt_audio = processor.encode_audios_from_wav([concat_prompt_wav], sampling_rate=sample_rate)[0]
conversations, mode, mode_name = build_conversation(
dialogue_text=conversation_text,
reference_audio_codes=reference_audio_codes,
prompt_audio=prompt_audio,
processor=processor,
)
batch = processor(conversations, mode=mode)
input_ids = batch["input_ids"].to(torch_device)
attention_mask = batch["attention_mask"].to(torch_device)
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=int(max_new_tokens),
audio_temperature=float(temperature),
audio_top_p=float(top_p),
audio_top_k=int(top_k),
audio_repetition_penalty=float(repetition_penalty),
)
messages = processor.decode(outputs)
if not messages or messages[0] is None:
raise RuntimeError("The model did not return a decodable audio result.")
audio = messages[0].audio_codes_list[0]
if isinstance(audio, torch.Tensor):
audio_np = audio.detach().float().cpu().numpy()
else:
audio_np = np.asarray(audio, dtype=np.float32)
if audio_np.ndim > 1:
audio_np = audio_np.reshape(-1)
audio_np = audio_np.astype(np.float32, copy=False)
clone_summary = "none" if not cloned_speakers else ",".join([f"S{i}" for i in cloned_speakers])
elapsed = time.monotonic() - started_at
status = (
f"Done | mode={mode_name} | speakers={speaker_count} | cloned={clone_summary} | elapsed={elapsed:.2f}s | "
f"text_normalize={text_normalize}, sample_rate_normalize={sample_rate_normalize} | "
f"max_new_tokens={int(max_new_tokens)}, "
f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, "
f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}"
)
return (sample_rate, audio_np), status
def build_demo(args: argparse.Namespace):
custom_css = """
:root {
--bg: #f6f7f8;
--panel: #ffffff;
--ink: #111418;
--muted: #4d5562;
--line: #e5e7eb;
--accent: #0f766e;
}
.gradio-container {
background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%);
color: var(--ink);
}
.app-card {
border: 1px solid var(--line);
border-radius: 16px;
background: var(--panel);
padding: 14px;
}
.app-title {
font-size: 22px;
font-weight: 700;
margin-bottom: 6px;
letter-spacing: 0.2px;
}
.app-subtitle {
color: var(--muted);
font-size: 14px;
margin-bottom: 8px;
}
#output_panel {
overflow: hidden !important;
}
#output_audio {
padding-bottom: 24px;
margin-bottom: 0;
overflow: hidden !important;
}
#output_audio > .wrap,
#output_audio .wrap,
#output_audio .audio-container,
#output_audio .block {
overflow: hidden !important;
}
#output_audio .audio-container {
padding-bottom: 10px;
min-height: 96px;
}
#output_audio_spacer {
height: 12px;
}
#output_status {
margin-top: 0;
}
#run-btn {
background: var(--accent);
border: none;
}
"""
with gr.Blocks(title="MOSS-TTSD Demo", css=custom_css) as demo:
gr.Markdown(
"""
<div class="app-card">
<div class="app-title">MOSS-TTSD</div>
<div class="app-subtitle">Multi-speaker dialogue synthesis with optional per-speaker voice cloning.</div>
</div>
"""
)
speaker_panels: list[gr.Group] = []
speaker_refs = []
speaker_prompts = []
with gr.Row(equal_height=False):
with gr.Column(scale=3):
speaker_count = gr.Slider(
minimum=MIN_SPEAKERS,
maximum=MAX_SPEAKERS,
step=1,
value=2,
label="Speaker Count",
info="Default 2 speakers. Minimum 1, maximum 5.",
)
gr.Markdown("### Voice Cloning (Optional, placed first)")
gr.Markdown(
"If you provide reference audio for a speaker, you must also provide that speaker's prompt text. "
"Prompt text may omit [Sx]; the app will auto-prepend it."
)
for idx in range(1, MAX_SPEAKERS + 1):
with gr.Group(visible=idx <= 2) as panel:
speaker_ref = gr.Audio(
label=f"S{idx} Reference Audio (Optional)",
type="filepath",
)
speaker_prompt = gr.Textbox(
label=f"S{idx} Prompt Text (Required with reference audio)",
lines=2,
placeholder=f"Example: [S{idx}] This is a prompt line for S{idx}.",
)
speaker_panels.append(panel)
speaker_refs.append(speaker_ref)
speaker_prompts.append(speaker_prompt)
gr.Markdown("### Multi-turn Dialogue")
dialogue_text = gr.Textbox(
label="Dialogue Text",
lines=12,
placeholder=(
"Use explicit tags in a single box, e.g.\n"
"[S1] Hello.\n"
"[S2] Hi, how are you?\n"
"[S1] Great, let's continue."
),
)
gr.Markdown(
"Without any reference audio, the model runs in generation mode. "
"Once any reference audio is provided, the model switches to voice-clone continuation mode."
)
with gr.Accordion("Sampling Parameters (Audio)", open=True):
gr.Markdown(
"- `text_normalize`: Normalize input text (**recommended to always enable**).\n"
"- `sample_rate_normalize`: Resample prompt audios to the lowest sample rate before encoding "
"(**recommended when using 2 or more speakers**)."
)
text_normalize = gr.Checkbox(
value=True,
label="text_normalize",
)
sample_rate_normalize = gr.Checkbox(
value=False,
label="sample_rate_normalize",
)
temperature = gr.Slider(
minimum=0.1,
maximum=3.0,
step=0.05,
value=1.1,
label="temperature",
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.01,
value=0.9,
label="top_p",
)
top_k = gr.Slider(
minimum=1,
maximum=200,
step=1,
value=50,
label="top_k",
)
repetition_penalty = gr.Slider(
minimum=0.8,
maximum=2.0,
step=0.05,
value=1.1,
label="repetition_penalty",
)
max_new_tokens = gr.Slider(
minimum=256,
maximum=8192,
step=128,
value=DEFAULT_MAX_NEW_TOKENS,
label="max_new_tokens",
)
run_btn = gr.Button("Generate Dialogue Audio", variant="primary", elem_id="run-btn")
with gr.Column(scale=2, elem_id="output_panel"):
output_audio = gr.Audio(label="Output Audio", type="numpy", elem_id="output_audio")
gr.HTML("", elem_id="output_audio_spacer")
status = gr.Textbox(label="Status", lines=4, interactive=False, elem_id="output_status")
speaker_count.change(
fn=update_speaker_panels,
inputs=[speaker_count],
outputs=speaker_panels,
)
run_btn.click(
fn=run_inference,
inputs=[
speaker_count,
*speaker_refs,
*speaker_prompts,
dialogue_text,
text_normalize,
sample_rate_normalize,
temperature,
top_p,
top_k,
repetition_penalty,
max_new_tokens,
gr.State(args.model_path),
gr.State(args.codec_path),
gr.State(args.device),
gr.State(args.attn_implementation),
],
outputs=[output_audio, status],
)
return demo
def main() -> None:
parser = argparse.ArgumentParser(description="MOSS-TTSD Gradio Demo")
parser.add_argument("--model_path", type=str, default=MODEL_PATH)
parser.add_argument("--codec_path", type=str, default=CODEC_MODEL_PATH)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION)
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))))
parser.add_argument("--share", action="store_true")
parser.add_argument("--preload_backend", action="store_true")
args = parser.parse_args()
runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu")
runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32
startup_attn_implementation = resolve_attn_implementation(
requested=args.attn_implementation,
device=runtime_device,
dtype=runtime_dtype,
) or "none"
print(
f"[INFO] Startup runtime probe: device={runtime_device}, attn={startup_attn_implementation}",
flush=True,
)
if args.preload_backend:
preload_started_at = time.monotonic()
print(
f"[Startup] Preloading backend: model={args.model_path}, codec={args.codec_path}, "
f"device={args.device}, attn={args.attn_implementation}",
flush=True,
)
load_backend(
model_path=args.model_path,
codec_path=args.codec_path,
device_str=args.device,
attn_implementation=args.attn_implementation,
)
print(
f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s",
flush=True,
)
else:
print("[Startup] Backend preload skipped; backend will load lazily on first inference.", flush=True)
demo = build_demo(args)
demo.queue(default_concurrency_limit=2).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
ssr_mode=False,
)
if __name__ == "__main__":
main()