import argparse
import functools
import importlib.util
import os
import re
import time
from pathlib import Path
from typing import Optional

try:
    import spaces
except ImportError:
    class _SpacesFallback:
        def GPU(*_args, **_kwargs):
            def _decorator(func):
                return func

            return _decorator

    spaces = _SpacesFallback()

import gradio as gr
import numpy as np
import torch
import soundfile as sf
from transformers import AutoModel, AutoProcessor

# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

MODEL_PATH = "OpenMOSS-Team/MOSS-TTSD-v1.0"
CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
DEFAULT_ATTN_IMPLEMENTATION = "auto"
DEFAULT_MAX_NEW_TOKENS = 2000
MIN_SPEAKERS = 1
MAX_SPEAKERS = 5

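# Accepted values for --attn_implementation: "auto" (pick the best available backend),
# "flash_attention_2", "sdpa", "eager", or "none" (omit the kwarg and let transformers
# choose its own default).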
def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None:
    requested_norm = (requested or "").strip().lower()
    if requested_norm == "none":
        return None
    if requested_norm not in {"", "auto"}:
        return requested
    # Prefer FlashAttention 2 when package + device conditions are met.
    if (
        device.type == "cuda"
        and importlib.util.find_spec("flash_attn") is not None
        and dtype in {torch.float16, torch.bfloat16}
    ):
        major, _ = torch.cuda.get_device_capability(device)
        if major >= 8:
            return "flash_attention_2"
    # CUDA fallback: use PyTorch SDPA kernels.
    if device.type == "cuda":
        return "sdpa"
    # CPU fallback.
    return "eager"

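# Note: the app preloads the backend at startup (--preload_backend) and run_inference
# loads it again on every click; caching here lets both paths share one loaded model.
# Assumption: the model/processor can safely be reused across requests. Remove the
# decorator to restore the original load-per-call behaviour.
@functools.lru_cache(maxsize=1)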
def load_backend(model_path: str, codec_path: str, device_str: str, attn_implementation: str):
    device = torch.device(device_str if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    resolved_attn_implementation = resolve_attn_implementation(
        requested=attn_implementation,
        device=device,
        dtype=dtype,
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        codec_path=codec_path,
    )
    if hasattr(processor, "audio_tokenizer"):
        processor.audio_tokenizer = processor.audio_tokenizer.to(device)
        processor.audio_tokenizer.eval()
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
    }
    if resolved_attn_implementation:
        model_kwargs["attn_implementation"] = resolved_attn_implementation
    model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device)
    model.eval()
    sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000))
    return model, processor, device, sample_rate

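# Lightweight resampler based on linear interpolation (torch.nn.functional.interpolate).
# Good enough for prompt audio in this demo; a dedicated resampler (e.g. torchaudio)
# would give higher fidelity if it is available.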
def _resample_wav(wav: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
    if int(orig_sr) == int(target_sr):
        return wav
    new_num_samples = int(round(wav.shape[-1] * float(target_sr) / float(orig_sr)))
    if new_num_samples <= 0:
        raise ValueError(f"Invalid resample length from {orig_sr}Hz to {target_sr}Hz.")
    return torch.nn.functional.interpolate(
        wav.unsqueeze(0),
        size=new_num_samples,
        mode="linear",
        align_corners=False,
    ).squeeze(0)

def _load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
    path = Path(audio_path).expanduser()
    if not path.exists():
        raise FileNotFoundError(f"Reference audio not found: {path}")
    wav_np, sr = sf.read(path, dtype="float32", always_2d=True)
    if wav_np.size == 0:
        raise ValueError(f"Reference audio is empty: {path}")
    if wav_np.shape[1] > 1:
        wav_np = wav_np.mean(axis=1, keepdims=True)
    wav = torch.from_numpy(wav_np.T)
    return wav, int(sr)

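# Text normalization: rewrite bare [1]/[2] tags to [S1]/[S2], strip decorative punctuation,
# map laughter ("哈哈...", "haha") to [笑]/[laugh] markers, collapse dashes/ellipses and
# repeated punctuation, and merge consecutive segments that share the same speaker tag.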
def normalize_text(text: str) -> str:
    text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
    remove_chars = "【】《》()『』「」" '"-_“”～~‘’'
    segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
    processed_parts = []
    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        matched = re.match(r"^(\[S\d+\])\s*(.*)", seg)
        tag, content = matched.groups() if matched else ("", seg)
        content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
        content = re.sub(r"哈{2,}", "[笑]", content)
        content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE)
        content = content.replace("——", "，")
        content = content.replace("……", "，")
        content = content.replace("...", ",")
        content = content.replace("⸺", "，")
        content = content.replace("―", "，")
        content = content.replace("—", "，")
        content = content.replace("…", "，")
        internal_punct_map = str.maketrans(
            {"；": "，", ";": ",", "：": "，", ":": ",", "、": "，"}
        )
        content = content.translate(internal_punct_map)
        content = content.strip()
        content = re.sub(r"([，。？！,.?!])[，。？！,.?!]+", r"\1", content)
        if len(content) > 1:
            last_ch = "。" if content[-1] == "，" else ("." if content[-1] == "," else content[-1])
            body = content[:-1].replace("。", "，")
            content = body + last_ch
        processed_parts.append({"tag": tag, "content": content})
    if not processed_parts:
        return ""
    merged_lines = []
    current_tag = processed_parts[0]["tag"]
    current_content = [processed_parts[0]["content"]]
    for part in processed_parts[1:]:
        if part["tag"] == current_tag and current_tag:
            current_content.append(part["content"])
        else:
            merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
            current_tag = part["tag"]
            current_content = [part["content"]]
    merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
    return "".join(merged_lines).replace("‘", "'").replace("’", "'")

def _validate_dialogue_text(dialogue_text: str, speaker_count: int) -> str:
    text = (dialogue_text or "").strip()
    if not text:
        raise ValueError("Please enter dialogue text.")
    tags = re.findall(r"\[S(\d+)\]", text)
    if not tags:
        raise ValueError("Dialogue must include speaker tags like [S1], [S2], ...")
    max_tag = max(int(t) for t in tags)
    if max_tag > speaker_count:
        raise ValueError(
            f"Dialogue contains [S{max_tag}], but speaker count is set to {speaker_count}."
        )
    return text

def update_speaker_panels(speaker_count: int):
    count = int(speaker_count)
    count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, count))
    return [gr.update(visible=(idx < count)) for idx in range(MAX_SPEAKERS)]

def _merge_consecutive_speaker_tags(text: str) -> str:
    segments = re.split(r"(?=\[S\d+\])", text)
    if not segments:
        return text
    merged_parts = []
    current_tag = None
    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        matched = re.match(r"^(\[S\d+\])\s*(.*)", seg, re.DOTALL)
        if not matched:
            merged_parts.append(seg)
            continue
        tag, content = matched.groups()
        if tag == current_tag:
            merged_parts.append(content)
        else:
            current_tag = tag
            merged_parts.append(f"{tag}{content}")
    return "".join(merged_parts)

def _normalize_prompt_text(prompt_text: str, speaker_id: int) -> str:
    text = (prompt_text or "").strip()
    if not text:
        raise ValueError(f"S{speaker_id} prompt text is empty.")
    expected_tag = f"[S{speaker_id}]"
    if not text.lstrip().startswith(expected_tag):
        text = f"{expected_tag} {text}"
    return text


def _build_prefixed_text(
    dialogue_text: str,
    prompt_text_map: dict[int, str],
    cloned_speakers: list[int],
) -> str:
    prompt_prefix = "".join([prompt_text_map[speaker_id] for speaker_id in cloned_speakers])
    return _merge_consecutive_speaker_tags(prompt_prefix + dialogue_text)

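# Encode each cloned speaker's reference waveform with the audio codec and place the
# resulting codes at that speaker's slot (index speaker_id - 1); slots without a
# reference stay None.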
def _encode_reference_audio_codes(
    processor,
    clone_wavs: list[torch.Tensor],
    cloned_speakers: list[int],
    speaker_count: int,
    sample_rate: int,
) -> list[Optional[torch.Tensor]]:
    encoded_list = processor.encode_audios_from_wav(clone_wavs, sampling_rate=sample_rate)
    reference_audio_codes: list[Optional[torch.Tensor]] = [None for _ in range(speaker_count)]
    for speaker_id, audio_codes in zip(cloned_speakers, encoded_list):
        reference_audio_codes[speaker_id - 1] = audio_codes
    return reference_audio_codes

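# Two modes: with no reference audio the model runs plain "generation"; with references,
# the concatenated prompt audio is attached as an assistant turn and the model continues
# it ("continuation", i.e. voice cloning).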
def build_conversation(
    dialogue_text: str,
    reference_audio_codes: list[Optional[torch.Tensor]],
    prompt_audio: torch.Tensor | None,
    processor,
):
    if prompt_audio is None:
        return [[processor.build_user_message(text=dialogue_text)]], "generation", "Generation"
    user_message = processor.build_user_message(
        text=dialogue_text,
        reference=reference_audio_codes,
    )
    return (
        [
            [
                user_message,
                processor.build_assistant_message(audio_codes_list=[prompt_audio]),
            ],
        ],
        "continuation",
        "voice_clone_and_continuation",
    )

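# On a ZeroGPU ("Running on Zero") Space the inference entry point needs the spaces.GPU
# decorator so a GPU is attached for the duration of the call; the fallback defined near
# the imports makes this a no-op when the spaces package is unavailable. Assumed from the
# Space runtime; drop the decorator when running on a dedicated GPU.
@spaces.GPU()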
def run_inference(speaker_count: int, *all_inputs):
    speaker_count = int(speaker_count)
    speaker_count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, speaker_count))
    reference_audio_values = all_inputs[:MAX_SPEAKERS]
    prompt_text_values = all_inputs[MAX_SPEAKERS : 2 * MAX_SPEAKERS]
    dialogue_text = all_inputs[2 * MAX_SPEAKERS]
    (
        text_normalize,
        sample_rate_normalize,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
        model_path,
        codec_path,
        device,
        attn_implementation,
    ) = all_inputs[2 * MAX_SPEAKERS + 1 :]
    started_at = time.monotonic()
    model, processor, torch_device, sample_rate = load_backend(
        model_path=str(model_path),
        codec_path=str(codec_path),
        device_str=str(device),
        attn_implementation=str(attn_implementation),
    )
    text_normalize = bool(text_normalize)
    sample_rate_normalize = bool(sample_rate_normalize)
    normalized_dialogue = str(dialogue_text or "").strip()
    if text_normalize:
        normalized_dialogue = normalize_text(normalized_dialogue)
    normalized_dialogue = _validate_dialogue_text(normalized_dialogue, speaker_count)
    cloned_speakers: list[int] = []
    loaded_clone_wavs: list[tuple[torch.Tensor, int]] = []
    prompt_text_map: dict[int, str] = {}
    for idx in range(speaker_count):
        ref_audio = reference_audio_values[idx]
        prompt_text = str(prompt_text_values[idx] or "").strip()
        has_reference = bool(ref_audio)
        has_prompt_text = bool(prompt_text)
        if has_reference != has_prompt_text:
            raise ValueError(
                f"S{idx + 1} must provide both reference audio and prompt text together."
            )
        if has_reference:
            speaker_id = idx + 1
            ref_audio_path = str(ref_audio)
            cloned_speakers.append(speaker_id)
            loaded_clone_wavs.append(_load_audio(ref_audio_path))
            prompt_text_map[speaker_id] = _normalize_prompt_text(prompt_text, speaker_id)
    prompt_audio: Optional[torch.Tensor] = None
    reference_audio_codes: list[Optional[torch.Tensor]] = []
    conversation_text = normalized_dialogue
    if cloned_speakers:
        conversation_text = _build_prefixed_text(
            dialogue_text=normalized_dialogue,
            prompt_text_map=prompt_text_map,
            cloned_speakers=cloned_speakers,
        )
        if text_normalize:
            conversation_text = normalize_text(conversation_text)
        conversation_text = _validate_dialogue_text(conversation_text, speaker_count)
        if sample_rate_normalize:
            min_sr = min(sr for _, sr in loaded_clone_wavs)
        else:
            min_sr = None
        clone_wavs: list[torch.Tensor] = []
        for wav, orig_sr in loaded_clone_wavs:
            processed_wav = wav
            current_sr = int(orig_sr)
            if min_sr is not None:
                processed_wav = _resample_wav(processed_wav, current_sr, int(min_sr))
                current_sr = int(min_sr)
            processed_wav = _resample_wav(processed_wav, current_sr, sample_rate)
            clone_wavs.append(processed_wav)
        reference_audio_codes = _encode_reference_audio_codes(
            processor=processor,
            clone_wavs=clone_wavs,
            cloned_speakers=cloned_speakers,
            speaker_count=speaker_count,
            sample_rate=sample_rate,
        )
        concat_prompt_wav = torch.cat(clone_wavs, dim=-1)
        prompt_audio = processor.encode_audios_from_wav([concat_prompt_wav], sampling_rate=sample_rate)[0]
    conversations, mode, mode_name = build_conversation(
        dialogue_text=conversation_text,
        reference_audio_codes=reference_audio_codes,
        prompt_audio=prompt_audio,
        processor=processor,
    )
    batch = processor(conversations, mode=mode)
    input_ids = batch["input_ids"].to(torch_device)
    attention_mask = batch["attention_mask"].to(torch_device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            audio_temperature=float(temperature),
            audio_top_p=float(top_p),
            audio_top_k=int(top_k),
            audio_repetition_penalty=float(repetition_penalty),
        )
    messages = processor.decode(outputs)
    if not messages or messages[0] is None:
        raise RuntimeError("The model did not return a decodable audio result.")
    audio = messages[0].audio_codes_list[0]
    if isinstance(audio, torch.Tensor):
        audio_np = audio.detach().float().cpu().numpy()
    else:
        audio_np = np.asarray(audio, dtype=np.float32)
    if audio_np.ndim > 1:
        audio_np = audio_np.reshape(-1)
    audio_np = audio_np.astype(np.float32, copy=False)
    clone_summary = "none" if not cloned_speakers else ",".join([f"S{i}" for i in cloned_speakers])
    elapsed = time.monotonic() - started_at
    status = (
        f"Done | mode={mode_name} | speakers={speaker_count} | cloned={clone_summary} | elapsed={elapsed:.2f}s | "
        f"text_normalize={text_normalize}, sample_rate_normalize={sample_rate_normalize} | "
        f"max_new_tokens={int(max_new_tokens)}, "
        f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, "
        f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}"
    )
    return (sample_rate, audio_np), status

def build_demo(args: argparse.Namespace):
    custom_css = """
    :root {
        --bg: #f6f7f8;
        --panel: #ffffff;
        --ink: #111418;
        --muted: #4d5562;
        --line: #e5e7eb;
        --accent: #0f766e;
    }
    .gradio-container {
        background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%);
        color: var(--ink);
    }
    .app-card {
        border: 1px solid var(--line);
        border-radius: 16px;
        background: var(--panel);
        padding: 14px;
    }
    .app-title {
        font-size: 22px;
        font-weight: 700;
        margin-bottom: 6px;
        letter-spacing: 0.2px;
    }
    .app-subtitle {
        color: var(--muted);
        font-size: 14px;
        margin-bottom: 8px;
    }
    #output_panel {
        overflow: hidden !important;
    }
    #output_audio {
        padding-bottom: 24px;
        margin-bottom: 0;
        overflow: hidden !important;
    }
    #output_audio > .wrap,
    #output_audio .wrap,
    #output_audio .audio-container,
    #output_audio .block {
        overflow: hidden !important;
    }
    #output_audio .audio-container {
        padding-bottom: 10px;
        min-height: 96px;
    }
    #output_audio_spacer {
        height: 12px;
    }
    #output_status {
        margin-top: 0;
    }
    #run-btn {
        background: var(--accent);
        border: none;
    }
    """
    with gr.Blocks(title="MOSS-TTSD Demo", css=custom_css) as demo:
        gr.Markdown(
            """
            <div class="app-card">
              <div class="app-title">MOSS-TTSD</div>
              <div class="app-subtitle">Multi-speaker dialogue synthesis with optional per-speaker voice cloning.</div>
            </div>
            """
        )
        speaker_panels: list[gr.Group] = []
        speaker_refs = []
        speaker_prompts = []
        with gr.Row(equal_height=False):
            with gr.Column(scale=3):
                speaker_count = gr.Slider(
                    minimum=MIN_SPEAKERS,
                    maximum=MAX_SPEAKERS,
                    step=1,
                    value=2,
                    label="Speaker Count",
                    info="Default 2 speakers. Minimum 1, maximum 5.",
                )
                gr.Markdown("### Voice Cloning (Optional, placed first)")
                gr.Markdown(
                    "If you provide reference audio for a speaker, you must also provide that speaker's prompt text. "
                    "Prompt text may omit [Sx]; the app will auto-prepend it."
                )
                for idx in range(1, MAX_SPEAKERS + 1):
                    with gr.Group(visible=idx <= 2) as panel:
                        speaker_ref = gr.Audio(
                            label=f"S{idx} Reference Audio (Optional)",
                            type="filepath",
                        )
                        speaker_prompt = gr.Textbox(
                            label=f"S{idx} Prompt Text (Required with reference audio)",
                            lines=2,
                            placeholder=f"Example: [S{idx}] This is a prompt line for S{idx}.",
                        )
                    speaker_panels.append(panel)
                    speaker_refs.append(speaker_ref)
                    speaker_prompts.append(speaker_prompt)
                gr.Markdown("### Multi-turn Dialogue")
                dialogue_text = gr.Textbox(
                    label="Dialogue Text",
                    lines=12,
                    placeholder=(
                        "Use explicit tags in a single box, e.g.\n"
                        "[S1] Hello.\n"
                        "[S2] Hi, how are you?\n"
                        "[S1] Great, let's continue."
                    ),
                )
                gr.Markdown(
                    "Without any reference audio, the model runs in generation mode. "
                    "Once any reference audio is provided, the model switches to voice-clone continuation mode."
                )
                with gr.Accordion("Sampling Parameters (Audio)", open=True):
                    gr.Markdown(
                        "- `text_normalize`: Normalize input text (**recommended to always enable**).\n"
                        "- `sample_rate_normalize`: Resample prompt audios to the lowest sample rate before encoding "
                        "(**recommended when using 2 or more speakers**)."
                    )
                    text_normalize = gr.Checkbox(
                        value=True,
                        label="text_normalize",
                    )
                    sample_rate_normalize = gr.Checkbox(
                        value=False,
                        label="sample_rate_normalize",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=3.0,
                        step=0.05,
                        value=1.1,
                        label="temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=0.9,
                        label="top_p",
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=200,
                        step=1,
                        value=50,
                        label="top_k",
                    )
                    repetition_penalty = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        step=0.05,
                        value=1.1,
                        label="repetition_penalty",
                    )
                    max_new_tokens = gr.Slider(
                        minimum=256,
                        maximum=8192,
                        step=128,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        label="max_new_tokens",
                    )
                run_btn = gr.Button("Generate Dialogue Audio", variant="primary", elem_id="run-btn")
            with gr.Column(scale=2, elem_id="output_panel"):
                output_audio = gr.Audio(label="Output Audio", type="numpy", elem_id="output_audio")
                gr.HTML("", elem_id="output_audio_spacer")
                status = gr.Textbox(label="Status", lines=4, interactive=False, elem_id="output_status")
        speaker_count.change(
            fn=update_speaker_panels,
            inputs=[speaker_count],
            outputs=speaker_panels,
        )
        run_btn.click(
            fn=run_inference,
            inputs=[
                speaker_count,
                *speaker_refs,
                *speaker_prompts,
                dialogue_text,
                text_normalize,
                sample_rate_normalize,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_new_tokens,
                gr.State(args.model_path),
                gr.State(args.codec_path),
                gr.State(args.device),
                gr.State(args.attn_implementation),
            ],
            outputs=[output_audio, status],
        )
    return demo

def main() -> None:
    parser = argparse.ArgumentParser(description="MOSS-TTSD Gradio Demo")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH)
    parser.add_argument("--codec_path", type=str, default=CODEC_MODEL_PATH)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))))
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--preload_backend", action="store_true")
    args = parser.parse_args()
    runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32
    startup_attn_implementation = resolve_attn_implementation(
        requested=args.attn_implementation,
        device=runtime_device,
        dtype=runtime_dtype,
    ) or "none"
    print(
        f"[INFO] Startup runtime probe: device={runtime_device}, attn={startup_attn_implementation}",
        flush=True,
    )
    if args.preload_backend:
        preload_started_at = time.monotonic()
        print(
            f"[Startup] Preloading backend: model={args.model_path}, codec={args.codec_path}, "
            f"device={args.device}, attn={args.attn_implementation}",
            flush=True,
        )
        load_backend(
            model_path=args.model_path,
            codec_path=args.codec_path,
            device_str=args.device,
            attn_implementation=args.attn_implementation,
        )
        print(
            f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s",
            flush=True,
        )
    else:
        print("[Startup] Backend preload skipped; backend will load lazily on first inference.", flush=True)
    demo = build_demo(args)
    demo.queue(default_concurrency_limit=2).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
    )


if __name__ == "__main__":
    main()