"""
ACE-Step: A Step Towards Music Generation Foundation Model
https://github.com/ace-step/ACE-Step
Apache 2.0 License
"""

import gradio as gr


TAG_DEFAULT = "pop, piano, rap, dark, atmospheric"
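# Default sample lyrics, intentionally kept in Chinese: the
# ACE-Step-v1-chinese-rap-LoRA offered in the LoRA dropdown below targets
# Chinese rap, so a Chinese lyric sheet is the matching default input.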
LYRIC_DEFAULT = """[verse]
月光爬上窗 染白冷的床
心跳的方向 带我入迷惘
黑夜吞噬光 命运的纸张
爱是血色霜 邪恶又芬芳

[chorus]
你是猎人的欲望 我是迷途的小羊
深陷你眼眸的荒 唐突献出心脏
我在夜里回荡 是谁给我希望
黑暗风中飘荡 假装不再受伤

[verse]
心锁在门外 谁会解开关怀
温柔的手拍 藏着冷酷杀害
思绪如尘埃 撞击爱的霹雳
灵魂的独白 为你沾满血迹

[bridge]
你是噩梦的歌唱 是灵魂的捆绑
绝望中带着光 悬崖边的渴望
心跳被你鼓掌 恶魔也痴痴想
渐渐没了抵抗 古老诡计流淌

[chorus]
你是猎人的欲望 我是迷途的小羊
深陷你眼眸的荒 唐突献出心脏
我在夜里回荡 是谁给我希望
黑暗风中飘荡 假装不再受伤

[outro]
爱如月黑无光 渗进梦的战场
逃入无声的场 放手或心嚷嚷
隐秘的极端 爱是极致风浪
灵魂彻底交偿 你是终极虚妄
"""

# Genre presets: display name -> comma-separated tag string used as the prompt.
GENRE_PRESETS = {
    "Modern Pop": "pop, synth, drums, guitar, 120 bpm, upbeat, catchy, vibrant, female vocals, polished vocals",
    "Rock": "rock, electric guitar, drums, bass, 130 bpm, energetic, rebellious, gritty, male vocals, raw vocals",
    "Hip Hop": "hip hop, 808 bass, hi-hats, synth, 90 bpm, bold, urban, intense, male vocals, rhythmic vocals",
    "Country": "country, acoustic guitar, steel guitar, fiddle, 100 bpm, heartfelt, rustic, warm, male vocals, twangy vocals",
    "EDM": "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
    "Reggae": "reggae, guitar, bass, drums, 80 bpm, chill, soulful, positive, male vocals, smooth vocals",
    "Classical": "classical, orchestral, strings, piano, 60 bpm, elegant, emotive, timeless, instrumental",
    "Jazz": "jazz, saxophone, piano, double bass, 110 bpm, smooth, improvisational, soulful, male vocals, crooning vocals",
    "Metal": "metal, electric guitar, double kick drum, bass, 160 bpm, aggressive, intense, heavy, male vocals, screamed vocals",
    "R&B": "r&b, synth, bass, drums, 85 bpm, sultry, groovy, romantic, female vocals, silky vocals"
}

# Handle preset selection: return the preset's tag string, or clear the field for "Custom".
def update_tags_from_preset(preset_name):
    if preset_name == "Custom":
        return ""
    return GENRE_PRESETS.get(preset_name, "")
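# Illustrative usage (not wired into the UI):
#   update_tags_from_preset("Modern Pop")  # -> "pop, synth, drums, ..."
#   update_tags_from_preset("Custom")      # -> "" (user types their own tags)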


def create_output_ui(task_name="Generated"):
    # Many consumer-grade GPUs can only run a batch size of one.
    output_audio1 = gr.Audio(type="filepath", label=f"{task_name} song")
    # output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2")
    with gr.Accordion(f"{task_name} parameters (record)", open=False):
        input_params_json = gr.JSON(label=f"{task_name} parameters")
    # outputs = [output_audio1, output_audio2]
    outputs = [output_audio1]
    return outputs, input_params_json


def dump_func(*args):
    """Placeholder callback: print the received arguments and return no outputs."""
    print(args)
    return []


def create_text2music_ui(
    gr,  # the gradio module, passed in explicitly (shadows the top-level import)
    text2music_process_func,
    sample_data_func=None,
    load_data_func=None,  # reserved for loading saved parameters; currently unused
):

    with gr.Row():
        with gr.Column():
            with gr.Row(equal_height=True):
                # Tag and lyric examples come from the AI music generation community.
                audio_duration = gr.Slider(
                    -1,
                    240.0,
                    step=1,
                    value=-1,
                    label="Audio duration",
                    interactive=True,
                    info="-1 means a random duration (30 ~ 240 s).",
                    scale=9,
                )
                sample_btn = gr.Button("Sample", variant="secondary", scale=1)

            # audio2audio
            with gr.Row(equal_height=True):
                audio2audio_enable = gr.Checkbox(label="Enable audio-to-audio generation", value=False, info="Check to condition generation on a reference audio clip.", elem_id="audio2audio_checkbox")
                lora_name_or_path = gr.Dropdown(
                    label="LoRA name or path (e.g. Chinese rap)",
                    choices=["ACE-Step/ACE-Step-v1-chinese-rap-LoRA", "none"],
                    value="none",
                    allow_custom_value=True,
                )

            ref_audio_input = gr.Audio(type="filepath", label="Reference audio (for audio-to-audio generation)", visible=False, elem_id="ref_audio_input", show_download_button=True)
            ref_audio_strength = gr.Slider(
                label="Reference audio strength",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                elem_id="ref_audio_strength",
                visible=False,
                interactive=True,
            )

            # Show the reference-audio widgets only while audio2audio is enabled.
            def toggle_ref_audio_visibility(is_checked):
                return (
                    gr.update(visible=is_checked, elem_id="ref_audio_input"),
                    gr.update(visible=is_checked, elem_id="ref_audio_strength"),
                )

            audio2audio_enable.change(
                fn=toggle_ref_audio_visibility,
                inputs=[audio2audio_enable],
                outputs=[ref_audio_input, ref_audio_strength],
            )

            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("""<center>Supports style, description, and scene tags. Separate tags with commas.</center>""")
                    genre_preset = gr.Dropdown(
                        choices=["Custom"] + list(GENRE_PRESETS.keys()),
                        value="Custom",
                        label="Preset",
                        scale=1,
                    )
                    prompt = gr.Textbox(
                        lines=1,
                        label="Music style",
                        max_lines=10,
                        value=TAG_DEFAULT,
                        scale=9,
                    )

            # Update the tag prompt whenever the preset selection changes.
            genre_preset.change(
                fn=update_tags_from_preset,
                inputs=[genre_preset],
                outputs=[prompt]
            )
            with gr.Group():
                gr.Markdown("""<center>Use [verse], [chorus], and [bridge] to separate sections of the lyrics.<br>Use [instrumental] or [inst] to generate instrumental music. Genre or structure tags other than these are not supported inside the lyrics.</center>""")
                lyrics = gr.Textbox(
                    lines=9,
                    label="Lyrics",
                    max_lines=500,
                    value=LYRIC_DEFAULT,
                )

            with gr.Accordion("基本设置", open=False, visible=False):
                infer_step = gr.Slider(
                    minimum=1,
                    maximum=200,
                    step=1,
                    value=60,
                    label="推理步数",
                    interactive=True,
                )
                guidance_scale = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=15.0,
                    label="引导尺度",
                    interactive=True,
                    info="当 guidance_scale_lyric > 1 且 guidance_scale_text > 1 时,不应用引导尺度。",
                )
                guidance_scale_text = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,
                    label="文本引导尺度",
                    interactive=True,
                    info="文本条件的引导尺度。仅适用于 cfg。建议设置 guidance_scale_text=5.0, guidance_scale_lyric=1.5 作为开始。",
                )
                guidance_scale_lyric = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,
                    label="歌词引导尺度",
                    interactive=True,
                )

                manual_seeds = gr.Textbox(
                    label="手动种子 (默认为无)",
                    placeholder="1,2,3,4",
                    value=None,
                    info="生成种子",
                )

            with gr.Accordion("高级设置", open=False, visible=False):
                scheduler_type = gr.Radio(
                    ["euler", "heun"],
                    value="euler",
                    label="调度器类型",
                    elem_id="scheduler_type",
                    info="生成调度器类型。推荐使用 euler。heun 将花费更多时间。",
                )
                cfg_type = gr.Radio(
                    ["cfg", "apg", "cfg_star"],
                    value="apg",
                    label="CFG 类型",
                    elem_id="cfg_type",
                    info="生成 CFG 类型。推荐使用 apg。cfg 和 cfg_star 几乎相同。",
                )
                use_erg_tag = gr.Checkbox(
                    label="对标签使用 ERG",
                    value=True,
                    info="对标签使用熵校正引导。它将注意力乘以一个温度,以减弱标签条件并提高多样性。",
                )
                use_erg_lyric = gr.Checkbox(
                    label="对歌词使用 ERG",
                    value=False,
                    info="同上,但应用于歌词编码器的注意力。",
                )
                use_erg_diffusion = gr.Checkbox(
                    label="对扩散模型使用 ERG",
                    value=True,
                    info="同上,但应用于扩散模型的注意力。",
                )
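                # Conceptually, temperature-scaled attention flattens the softmax
                # so a condition is weighted less sharply. A minimal sketch (this
                # is NOT ACE-Step's actual ERG code; tau > 1 is illustrative):
                #   logits = q @ k.transpose(-2, -1) / math.sqrt(d)
                #   attn = torch.softmax(logits / tau, dim=-1)  # larger tau softens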

                omega_scale = gr.Slider(
                    minimum=-100.0,
                    maximum=100.0,
                    step=0.1,
                    value=10.0,
                    label="Granularity scale",
                    interactive=True,
                    info="Granularity scale for generation. Higher values can reduce artifacts.",
                )

                guidance_interval = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                    label="Guidance interval",
                    interactive=True,
                    info="Guidance interval for generation. 0.5 means guidance is applied only during the middle steps (0.25 * infer_step to 0.75 * infer_step).",
                )
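                # Worked example: with infer_step=60 and guidance_interval=0.5,
                # guidance is applied only between step 15 and step 45
                # (0.25 * 60 to 0.75 * 60).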
                guidance_interval_decay = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                    label="Guidance interval decay",
                    interactive=True,
                    info="Guidance scale decays from guidance_scale to min_guidance_scale within the guidance interval. 0.0 means no decay.",
                )
                min_guidance_scale = gr.Slider(
                    minimum=0.0,
                    maximum=200.0,
                    step=0.1,
                    value=3.0,
                    label="Minimum guidance scale",
                    interactive=True,
                    info="Minimum guidance scale at the end of the guidance-interval decay.",
                )
                oss_steps = gr.Textbox(
                    label="OSS steps",
                    placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200",
                    value=None,
                    info="Optimal steps for generation; not yet thoroughly tested.",
                )

            text2music_btn = gr.Button("Generate", variant="primary")

            outputs, input_params_json = create_output_ui()

        def json2output(json_data):
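            # NOTE: the tuple order below must match the `outputs` list passed
            # to sample_btn.click, component for component.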
            return (
                json_data["audio_duration"],
                json_data["prompt"],
                json_data["lyrics"],
                json_data["infer_step"],
                json_data["guidance_scale"],
                json_data["scheduler_type"],
                json_data["cfg_type"],
                json_data["omega_scale"],
                ", ".join(map(str, json_data["actual_seeds"])),
                json_data["guidance_interval"],
                json_data["guidance_interval_decay"],
                json_data["min_guidance_scale"],
                json_data["use_erg_tag"],
                json_data["use_erg_lyric"],
                json_data["use_erg_diffusion"],
                ", ".join(map(str, json_data["oss_steps"])),
                json_data.get("guidance_scale_text", 0.0),
                json_data.get("guidance_scale_lyric", 0.0),
                json_data.get("audio2audio_enable", False),
                json_data.get("ref_audio_strength", 0.5),
                json_data.get("ref_audio_input", None),
            )

        def sample_data(lora_name_or_path_):
            json_data = sample_data_func(lora_name_or_path_)
            return json2output(json_data)

        sample_btn.click(
            sample_data,
            inputs=[lora_name_or_path],
            outputs=[
                audio_duration,
                prompt,
                lyrics,
                infer_step,
                guidance_scale,
                scheduler_type,
                cfg_type,
                omega_scale,
                manual_seeds,
                guidance_interval,
                guidance_interval_decay,
                min_guidance_scale,
                use_erg_tag,
                use_erg_lyric,
                use_erg_diffusion,
                oss_steps,
                guidance_scale_text,
                guidance_scale_lyric,
                audio2audio_enable,
                ref_audio_strength,
                ref_audio_input,
            ],
        )

    text2music_btn.click(
        fn=text2music_process_func,
        inputs=[
            audio_duration,
            prompt,
            lyrics,
            infer_step,
            guidance_scale,
            scheduler_type,
            cfg_type,
            omega_scale,
            manual_seeds,
            guidance_interval,
            guidance_interval_decay,
            min_guidance_scale,
            use_erg_tag,
            use_erg_lyric,
            use_erg_diffusion,
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
            audio2audio_enable,
            ref_audio_strength,
            ref_audio_input,
            lora_name_or_path,
        ],
        outputs=outputs + [input_params_json],
    )
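    # The `inputs` order above is the positional-argument order in which values
    # are passed to text2music_process_func; keep the two in sync.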


def create_main_demo_ui(
    text2music_process_func=dump_func,
    sample_data_func=dump_func,
    load_data_func=dump_func,
):
    with gr.Blocks(
        title="ACE-Step",
    ) as demo:
        with gr.Tab("Text to Music"):
            create_text2music_ui(
                gr=gr,
                text2music_process_func=text2music_process_func,
                sample_data_func=sample_data_func,
                load_data_func=load_data_func,
            )
    return demo
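
# Example of wiring a real pipeline in place of the dump_func placeholders.
# This is a minimal sketch: the import path, constructor arguments, and the
# sample_data helper below are assumptions; check the ACE-Step repository for
# the actual entry points.
#
#   from acestep.pipeline_ace_step import ACEStepPipeline  # assumed module path
#
#   pipeline = ACEStepPipeline(checkpoint_dir="./checkpoints")  # assumed args
#   demo = create_main_demo_ui(
#       text2music_process_func=pipeline.__call__,
#       sample_data_func=pipeline.sample_data,  # hypothetical helper
#   )
#   demo.launch(server_name="0.0.0.0", server_port=7860)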


if __name__ == "__main__":
    demo = create_main_demo_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )