# Install dependencies in application code, as we don't have access to a GPU at build time
# Thanks to https://huggingface.co/Steveeeeeeen for their code to handle this!
import os
import shlex
import subprocess
import sys

def _pip_install(spec: str, extra_env=None) -> None:
    """
    Run ``pip install <spec>`` using the interpreter that is running this script.

    Invoking pip as ``sys.executable -m pip`` guarantees the packages are installed
    into the same environment that imports them below; a bare ``pip`` found on
    PATH may belong to a different Python installation.

    :param spec: Package specs / pip flags, split shell-style with ``shlex``.
    :param extra_env: Optional extra environment variables layered over
        ``os.environ`` for this install only; ``None`` inherits the environment.
    :raises subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    env = (os.environ | extra_env) if extra_env else None
    subprocess.run([sys.executable, "-m", "pip", "install", *shlex.split(spec)], env=env, check=True)


# Skip compiling flash-attn's CUDA kernels: there is no GPU available at build time.
_pip_install("flash-attn --no-build-isolation", extra_env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"})
# Prebuilt wheels pinned to CUDA 12 / torch 2.4 / CPython 3.10 on linux x86_64.
_pip_install("https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")
_pip_install("https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")

import spaces
import gradio as gr
import numpy as np

from typing import Tuple, Dict, Any, Optional
from taproot import Task

# Configuration
# True when the SYSTEM environment variable equals "spaces" (Hugging Face Spaces).
is_hf_spaces: bool = os.getenv("SYSTEM", "") == "spaces"
# Character limit for the speech text when running on Spaces (unlimited otherwise).
max_characters: int = 2000
# Markdown rendered at the top of the interface.
header_markdown: str = """
# Zonos v0.1
State of the art text-to-speech model [[model]](https://huggingface.co/collections/Zyphra/zonos-v01-67ac661c85e1898670823b4f). [[blog]](https://www.zyphra.com/post/beta-release-of-zonos-v0-1), [[Zyphra Audio (hosted service)]](https://maia.zyphra.com/sign-in?redirect_url=https%3A%2F%2Fmaia.zyphra.com%2Faudio)
## Unleashed
Use this space to generate long-form speech up to around ~2 minutes in length. To generate an unlimited length, clone this space and run it locally.
### Tips

- When providing prefix audio, include the text of the prefix audio in your speech text to ensure a smooth transition.
- The appropriate range of Speaking Rate and Pitch STD are highly dependent on the speaker audio. Start with the defaults and adjust as needed.
- Emotion sliders do not completely function intuitively, and require some experimentation to get the desired effect.
""".strip()

# Create pipelines, downloading required files as necessary
# DeepFilterNet v3: only its files are downloaded here; this task is not
# instantiated or loaded below.
# NOTE(review): presumably the synthesis pipelines' enhance options rely on these files — confirm.
speech_enhancement = Task.get("speech-enhancement", model="deep-filter-net-v3", available_only=False)
speech_enhancement.download_required_files(text_callback=print)
# The hybrid pipeline is always loaded at startup, on Spaces or not.
hybrid_task = Task.get("speech-synthesis", model="zonos-hybrid", available_only=False)
hybrid_task.download_required_files(text_callback=print)
hybrid_pipe = hybrid_task()
hybrid_pipe.load(allow_optional=True)

# The transformer pipeline is instantiated here but only loaded eagerly on Spaces;
# otherwise update_ui() loads/unloads pipelines on demand.
transformer_task = Task.get(
    "speech-synthesis", model="zonos-transformer", available_only=False
)
transformer_task.download_required_files(text_callback=print)
transformer_pipe = transformer_task()

if is_hf_spaces:
    # Must load all models on GPU when using ZERO
    # (update_ui() skips load/unload entirely when is_hf_spaces is true).
    transformer_pipe.load(allow_optional=True)

# Global state
# Maps the display name shown in the model dropdown to its pipeline. Insertion
# order defines the dropdown order; the first entry is the default selection.
pipelines: Dict[str, Any] = {
    "Zonos Transformer v0.1": transformer_pipe,
    "Zonos Hybrid v0.1": hybrid_pipe,
}
pipeline_names = list(pipelines.keys())
supported_language_codes = hybrid_pipe.supported_languages  # Same for both pipes

# Model toggle
def update_ui(pipeline_choice: str) -> Tuple[Dict[str, Any], ...]:
    """
    Dynamically show/hide UI elements based on the model's conditioners.

    When not running on Spaces, this also swaps which pipeline is resident:
    every other pipeline is unloaded *before* the selected one is loaded, so two
    models are never held in memory at the same time.

    :param pipeline_choice: Key into ``pipelines`` naming the selected model.
    :return: One ``gr.update(visible=...)`` per conditioned UI element, in the
        order: VQScore, emotion, fmax, pitch, speaking rate, DNSMOS, speaker noised.
    :raises KeyError: If ``pipeline_choice`` is not a known pipeline name.
    """
    if not is_hf_spaces:
        # When not using ZERO, we can onload/offload pipes. Offload first so
        # peak memory never holds the outgoing and incoming model together.
        for pipeline_name, pipeline in pipelines.items():
            if pipeline_name != pipeline_choice:
                pipeline.unload()
        pipelines[pipeline_choice].load()

    pipe = pipelines[pipeline_choice]
    cond_names = {c.name for c in pipe.pretrained.model.prefix_conditioner.conditioners}

    # Conditioner names backing each toggleable element. This order must match
    # the `outputs=` lists wired to this function in the interface.
    toggled_conditioners = (
        "vqscore_8",
        "emotion",
        "fmax",
        "pitch_std",
        "speaking_rate",
        "dnsmos_ovrl",
        "speaker_noised",
    )
    return tuple(gr.update(visible=(name in cond_names)) for name in toggled_conditioners)

# Invocation method
@spaces.GPU(duration=180)
def generate_audio(
    pipeline_choice: str,
    text: str,
    language: str,
    speaker_audio: Optional[str],
    prefix_audio: Optional[str],
    e1: float,
    e2: float,
    e3: float,
    e4: float,
    e5: float,
    e6: float,
    e7: float,
    e8: float,
    vq_single: float,
    fmax: float,
    pitch_std: float,
    speaking_rate: float,
    dnsmos_ovrl: float,
    speaker_noised: bool,
    cfg_scale: float,
    min_p: float,
    seed: int,
    max_chunk_length: int,
    cross_fade_duration: float,
    punctuation_pause_duration: float,
    target_rms: float,
    randomize_seed: bool,
    skip_dnsmos: bool,
    skip_vqscore: bool,
    skip_fmax: bool,
    skip_pitch: bool,
    skip_speaking_rate: bool,
    skip_emotion: bool,
    skip_speaker: bool,
    speaker_pitch_shift: float,
    speaker_equalize: bool,
    speaker_enhance: bool,
    prefix_equalize: bool,
    prefix_enhance: bool,
    enhance: bool,
    progress=gr.Progress(),
) -> Tuple[Tuple[int, np.ndarray[Any, Any]], int]:
    """
    Generates audio based on the provided UI parameters.

    Runs the selected pipeline on ``text`` with the given reference/prefix audio,
    conditioning values and long-form chunking settings, forwarding per-step
    progress to the Gradio progress bar.

    :param pipeline_choice: Key into ``pipelines`` naming the model to run.
    :param text: Speech text. On Spaces it must be at most ``max_characters``
        characters. This is re-checked here because the textbox ``max_length``
        is only enforced in the browser, not for direct API calls.
    :param e1..e8: Emotion weights passed as happiness, sadness, disgust, fear,
        surprise, anger, other and neutral respectively.
    :param seed: Seed passed to the pipeline; replaced by a random value in
        ``[0, 2**32)`` when ``randomize_seed`` is true.
    :param enhance: Whether to enhance the output; also selects the reported
        sample rate.
    :return: ``((sample_rate, waveform), seed)`` for the numpy ``gr.Audio``
        output and the seed field, so the UI shows the seed actually used.
    :raises gr.Error: If the text exceeds the character limit on Spaces.
    """
    if is_hf_spaces and len(text or "") > max_characters:
        raise gr.Error(
            f"Speech text is limited to {max_characters} characters on this space "
            f"(received {len(text)})."
        )

    selected_pipeline = pipelines[pipeline_choice]
    if randomize_seed:
        # Explicit int64: the default integer dtype is 32-bit on some platforms
        # (e.g. Windows with NumPy < 2), where a 2**32 bound raises ValueError.
        # int() returns a plain Python int for the UI instead of a numpy scalar.
        seed = int(np.random.randint(0, 2**32, dtype=np.int64))

    def on_progress(step: int, total: int) -> None:
        progress((step, total))

    print(f"{speaker_audio=}")
    selected_pipeline.on_progress(on_progress)
    try:
        wav_out = selected_pipeline(
            text=text,
            enhance=enhance,
            language=language,
            reference_audio=speaker_audio,
            reference_audio_pitch_shift=speaker_pitch_shift,
            equalize_reference_audio=speaker_equalize,
            enhance_reference_audio=speaker_enhance,
            prefix_audio=prefix_audio,
            equalize_prefix_audio=prefix_equalize,
            enhance_prefix_audio=prefix_enhance,
            seed=seed,
            max_chunk_length=max_chunk_length,
            cross_fade_duration=cross_fade_duration,
            punctuation_pause_duration=punctuation_pause_duration,
            target_rms=target_rms,
            cfg_scale=cfg_scale,
            min_p=min_p,
            fmax=fmax,
            pitch_std=pitch_std,
            emotion_happiness=e1,
            emotion_sadness=e2,
            emotion_disgust=e3,
            emotion_fear=e4,
            emotion_surprise=e5,
            emotion_anger=e6,
            emotion_other=e7,
            emotion_neutral=e8,
            speaking_rate=speaking_rate,
            vq_score=vq_single,
            speaker_noised=speaker_noised,
            dnsmos=dnsmos_ovrl,
            skip_speaker=skip_speaker,
            skip_dnsmos=skip_dnsmos,
            skip_vq_score=skip_vqscore,
            skip_fmax=skip_fmax,
            skip_pitch=skip_pitch,
            skip_speaking_rate=skip_speaking_rate,
            skip_emotion=skip_emotion,
            output_format="float",
        )

        # Enhanced output is reported at 48 kHz, raw output at 44.1 kHz.
        sample_rate = 48000 if enhance else 44100
        # NOTE(review): assumes wav_out is a CPU torch tensor — confirm.
        return (sample_rate, wav_out.squeeze().numpy()), seed
    finally:
        # Always detach the callback so it does not leak into later calls.
        selected_pipeline.off_progress()

# Interface
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Header: description on the left, Zonos banner on the right.
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown(header_markdown)

            gr.Image(
                value="https://raw.githubusercontent.com/Zyphra/Zonos/refs/heads/main/assets/ZonosHeader.png",
                container=False,
                interactive=False,
                show_label=False,
                show_share_button=False,
                show_fullscreen_button=False,
                show_download_button=False,
            )

        # Model, language and output-enhancement selection.
        with gr.Row(equal_height=True):
            pipeline_choice = gr.Dropdown(
                choices=pipeline_names,
                value=pipeline_names[0],
                label="Zonos Model Variant",
            )
            language = gr.Dropdown(
                choices=supported_language_codes,
                value="en-us",
                label="Language",
            )
            enhanced_checkbox = gr.Checkbox(
                value=True,
                label="Enhance Output with DeepFilterNet"
            )

        with gr.Row():
            # The character cap only applies on Spaces (matches max_length below).
            limit_text = f"Up to {max_characters}" if is_hf_spaces else "Unlimited"

            text = gr.Textbox(
                label=f"Speech Text ({limit_text} Characters)",
                value="Zonos is a state-of-the-art text-to-speech model that generates expressive and natural-sounding audio with robust customization options.",
                lines=4,
                max_lines=20,
                max_length=max_characters if is_hf_spaces else None,
            )

        with gr.Row():
            generate_button = gr.Button("Generate Audio")

        with gr.Row():
            output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)

        with gr.Row():
            gr.Markdown("## Long-Form Parameters")

        with gr.Column(variant="panel"):
            with gr.Row(equal_height=True):
                max_chunk_length = gr.Slider(
                    1, 300, 150, 1, label="Max Chunk Length (Characters)",
                    info="The maximum number of characters to generate in a single chunk. Zonos itself has a much higher limit than this, but consistency breaks down as you go past ~200 characters or so."
                )
                target_rms = gr.Slider(
                    0.0, 1.0, 0.10, 0.01, label="Target RMS",
                    info="The target RMS (root-mean-square) amplitude for the generated audio. Each chunk will have its loudness normalized to this value to ensure consistent volume levels."
                )
            with gr.Row(equal_height=True):
                punctuation_pause_duration = gr.Slider(
                    0, 1, 0.10, 0.01, label="Punctuation Pause Duration (Seconds)",
                    info="Pause duration to add after a chunk that ends with punctuation. Full-stop punctuation (periods) will have the entire length, while shorter pauses will use half of this duration."
                )
                cross_fade_duration = gr.Slider(
                    0, 1, 0.15, 0.01, label="Chunk Cross-Fade Duration (Seconds)",
                    info="The duration of the cross-fade between chunks. This helps to smooth out transitions between chunks. In general, this should be set to a value greater than the pause duration."
                )

        with gr.Row():
            gr.Markdown("## Generation Parameters")

        with gr.Row(variant="panel", equal_height=True):
            with gr.Column():
                prefix_audio = gr.Audio(
                    label="Optional Prefix Audio (continue from this audio)",
                    type="filepath",
                )
                prefix_equalize_checkbox = gr.Checkbox(label="Equalize Prefix Audio", value=True)
                prefix_enhance_checkbox = gr.Checkbox(label="Enhance Prefix Audio with DeepFilterNet", value=True)

            with gr.Column(scale=3):
                cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
                min_p_slider = gr.Slider(0.0, 1.0, 0.15, 0.01, label="Min P")
                seed_number = gr.Number(label="Seed", value=6475309, precision=0)
                randomize_seed_toggle = gr.Checkbox(label="Randomize Seed", value=True)

        with gr.Row():
            gr.Markdown(
                "## Conditioning Parameters\nAll of these types of conditioning are optional and can be disabled."
            )

        with gr.Row(variant="panel", equal_height=True) as speaker_row:
            with gr.Column():
                speaker_uncond = gr.Checkbox(label="Skip Speaker")
                speaker_noised_checkbox = gr.Checkbox(
                    label="Speaker Noised",
                    value=False,
                    interactive=False,
                    info="'Speaker Noised' is a conditioning value that the model understands, not a processing step. Check this box if your input audio is noisy."
                )
                speaker_equalize_checkbox = gr.Checkbox(label="Equalize Speaker Audio", value=True)
                speaker_enhance_checkbox = gr.Checkbox(label="Enhance Speaker Audio with DeepFilterNet", value=True)

                def on_enhanced_change(use_enhance: bool) -> Dict[str, Any]:
                    """
                    Lock 'Speaker Noised' while speaker enhancement is on.

                    With enhancement enabled the checkbox is made non-interactive
                    and forced to False; with it disabled the checkbox becomes
                    editable again and keeps its current value.
                    """
                    update_dict = {"interactive": not use_enhance}
                    if use_enhance:
                        update_dict["value"] = False
                    return gr.update(**update_dict)

                speaker_enhance_checkbox.change(
                    fn=on_enhanced_change,
                    inputs=[speaker_enhance_checkbox],
                    outputs=[speaker_noised_checkbox]
                )
                speaker_pitch_shift = gr.Slider(
                    -1200, 1200, -44.99, 0.01, label="Speaker Pitch Shift (Cents)",
                    info="A pitch shift to apply to speaker audio before extracting embeddings. A slight down-shift of ~45 cents tends to produce a more accurate voice cloning."
                )

            speaker_audio = gr.Audio(
                label="Optional Speaker Audio (for cloning)",
                type="filepath",
                scale=3,
            )

        with gr.Row(variant="panel", equal_height=True) as emotion_row:
            emotion_uncond = gr.Checkbox(label="Skip Emotion")
            with gr.Column(scale=3):
                with gr.Row():
                    emotion1 = gr.Slider(0.0, 1.0, 0.307, 0.001, label="Happiness")
                    emotion2 = gr.Slider(0.0, 1.0, 0.025, 0.001, label="Sadness")
                    emotion3 = gr.Slider(0.0, 1.0, 0.025, 0.001, label="Disgust")
                    emotion4 = gr.Slider(0.0, 1.0, 0.025, 0.001, label="Fear")
                with gr.Row():
                    emotion5 = gr.Slider(0.0, 1.0, 0.025, 0.001, label="Surprise")
                    emotion6 = gr.Slider(0.0, 1.0, 0.025, 0.001, label="Anger")
                    emotion7 = gr.Slider(0.0, 1.0, 0.025, 0.001, label="Other")
                    emotion8 = gr.Slider(0.0, 1.0, 0.307, 0.001, label="Neutral")

        with gr.Row(variant="panel", equal_height=True) as dnsmos_row:
            dnsmos_uncond = gr.Checkbox(label="Skip DNSMOS")
            dnsmos_slider = gr.Slider(
                1.0,
                5.0,
                value=4.0,
                step=0.1,
                label="Deep Noise Suppression Mean Opinion Score [arXiv 2010.15258]",
                scale=3,
            )

        with gr.Row(variant="panel", equal_height=True) as vq_score_row:
            vq_uncond = gr.Checkbox(label="Skip VQScore")
            vq_single_slider = gr.Slider(
                0.5, 0.8, 0.78, 0.01, label="VQScore [arXiv 2402.16321]", scale=3
            )

        with gr.Row(variant="panel", equal_height=True) as fmax_row:
            fmax_uncond = gr.Checkbox(label="Skip Fmax")
            fmax_slider = gr.Slider(
                0, 22050, value=22050, step=1, label="Fmax (Hz)", scale=3
            )

        with gr.Row(variant="panel", equal_height=True) as pitch_row:
            pitch_uncond = gr.Checkbox(label="Skip Pitch")
            pitch_std_slider = gr.Slider(
                0.0, 300.0, value=20.0, step=1, label="Pitch Standard Deviation", scale=3
            )

        with gr.Row(variant="panel", equal_height=True) as speaking_rate_row:
            speaking_rate_uncond = gr.Checkbox(label="Skip Speaking Rate")
            speaking_rate_slider = gr.Slider(
                5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate", scale=3
            )

        # Components whose visibility depends on the selected model's
        # conditioners. Shared by both wirings below so they cannot drift;
        # the order must match the tuple returned by update_ui.
        conditioned_outputs = [
            vq_score_row,
            emotion_row,
            fmax_row,
            pitch_row,
            speaking_rate_row,
            dnsmos_row,
            speaker_noised_checkbox,
        ]

        pipeline_choice.change(
            fn=update_ui,
            inputs=[pipeline_choice],
            outputs=conditioned_outputs,
        )

        # Trigger UI update on load
        demo.load(
            fn=update_ui,
            inputs=[pipeline_choice],
            outputs=conditioned_outputs,
        )

        # Generate audio on button click. Input order must match the
        # positional parameters of generate_audio exactly.
        generate_button.click(
            fn=generate_audio,
            inputs=[
                pipeline_choice,
                text,
                language,
                speaker_audio,
                prefix_audio,
                emotion1,
                emotion2,
                emotion3,
                emotion4,
                emotion5,
                emotion6,
                emotion7,
                emotion8,
                vq_single_slider,
                fmax_slider,
                pitch_std_slider,
                speaking_rate_slider,
                dnsmos_slider,
                speaker_noised_checkbox,
                cfg_scale_slider,
                min_p_slider,
                seed_number,
                max_chunk_length,
                cross_fade_duration,
                punctuation_pause_duration,
                target_rms,
                randomize_seed_toggle,
                dnsmos_uncond,
                vq_uncond,
                fmax_uncond,
                pitch_uncond,
                speaking_rate_uncond,
                emotion_uncond,
                speaker_uncond,
                speaker_pitch_shift,
                speaker_equalize_checkbox,
                speaker_enhance_checkbox,
                prefix_equalize_checkbox,
                prefix_enhance_checkbox,
                enhanced_checkbox,
            ],
            outputs=[output_audio, seed_number],
        )

    # Launch only after the Blocks context has exited and finalized the app.
    demo.launch()