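"""Gradio demo for Step-Audio-2-mini on a Hugging Face ZeroGPU Space.

The Step-Audio-2 model and its Token2wav vocoder are loaded on CPU at import
time; each request borrows a GPU via the `@spaces.GPU`-decorated predict call.
"""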
import os
import tempfile
import traceback
from pathlib import Path
import gradio as gr
import spaces # required for ZeroGPU
# ---- Model libs (must be importable: vendored in this repo or pip-installed) ----
from stepaudio2 import StepAudio2
from token2wav import Token2wav
# ------------------------- constants -------------------------
MODEL_PATH = "stepfun-ai/Step-Audio-2-mini"
PROMPT_WAV = "assets/default_female.wav"
CACHE_DIR = "/tmp/stepaudio2"
# Ensure Gradio uses a writable temp dir on Spaces
os.environ["GRADIO_TEMP_DIR"] = CACHE_DIR
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# ------------------------- helpers -------------------------
def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
Path(cache_dir).mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_dir, delete=False, suffix=".wav") as f:
f.write(audio_bytes)
return f.name
def add_message(chatbot, history, mic, text):
if not mic and not text:
return chatbot, history, "Input is empty"
if text:
chatbot.append({"role": "user", "content": text})
history.append({"role": "human", "content": text})
elif mic and Path(mic).exists():
chatbot.append({"role": "user", "content": {"path": mic}})
history.append({"role": "human", "content": [{"type": "audio", "audio": mic}]})
return chatbot, history, None
def reset_state(system_prompt):
return [], [{"role": "system", "content": system_prompt}]
# ------------------------- globals -------------------------
AUDIO_MODEL = StepAudio2(MODEL_PATH) # load on CPU
TOKEN2WAV = Token2wav(f"{MODEL_PATH}/token2wav") # load on CPU
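# On ZeroGPU there is no CUDA device at import time, so both models start on
# CPU and are moved to the GPU inside the @spaces.GPU-decorated call below.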
@spaces.GPU(duration=120) # GPU only during this call; no-ops outside ZeroGPU
def gpu_predict(chatbot, history):
global AUDIO_MODEL, TOKEN2WAV
try:
# Move to CUDA only when GPU is attached
try:
if hasattr(AUDIO_MODEL, "to"):
AUDIO_MODEL.to("cuda")
if hasattr(TOKEN2WAV, "to"):
TOKEN2WAV.to("cuda")
        except Exception:
            # Best effort: some wrappers may not expose .to(); keep going on CPU.
            pass
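        # Seed the assistant turn with "<tts_start>", which (per the Step-Audio-2
        # inference convention) prompts the model to emit speech tokens alongside
        # text; eot=False marks the turn as still in progress.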
history.append({"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"}], "eot": False})
tokens, text, audio_tokens = AUDIO_MODEL(
history,
max_new_tokens=4096,
temperature=0.7,
repetition_penalty=1.05,
do_sample=True,
)
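        # `audio_tokens` are discrete speech-codec tokens; Token2wav renders them
        # to waveform bytes, using PROMPT_WAV as the reference voice.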
audio_bytes = TOKEN2WAV(audio_tokens, PROMPT_WAV)
audio_path = save_tmp_audio(audio_bytes, CACHE_DIR)
chatbot.append({"role": "assistant", "content": {"path": audio_path}})
history[-1]["content"].append({"type": "token", "token": tokens})
history[-1]["eot"] = True
    except Exception:
        print(traceback.format_exc())
        # Drop the dangling "<tts_start>" placeholder so a failed turn does not
        # corrupt later generations.
        if history and history[-1].get("role") == "assistant" and not history[-1].get("eot", True):
            history.pop()
        gr.Warning("Something went wrong, please try again.")
return chatbot, history
def build_demo():
with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("<center><font size=8>Step Audio 2 Demo</font></center>")
with gr.Row():
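            # Default system prompt (Chinese). English gloss: "Your name is
            # Xiaoyue, a speech LLM trained by StepFun. You are emotionally
            # perceptive and observant, good at analyzing what users say and
            # replying with empathy, always attentive to their feelings. Today
            # is Friday, August 29, 2025. Please talk with the user in the
            # default female voice."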
system_prompt = gr.Textbox(
label="System Prompt",
value=(
"ไฝ ็š„ๅๅญ—ๅซๅšๅฐ่ทƒ๏ผŒๆ˜ฏ็”ฑ้˜ถ่ทƒๆ˜Ÿ่พฐๅ…ฌๅธ่ฎญ็ปƒๅ‡บๆฅ็š„่ฏญ้Ÿณๅคงๆจกๅž‹ใ€‚\n"
"ไฝ ๆƒ…ๆ„Ÿ็ป†่…ป๏ผŒ่ง‚ๅฏŸ่ƒฝๅŠ›ๅผบ๏ผŒๆ“…้•ฟๅˆ†ๆž็”จๆˆท็š„ๅ†…ๅฎน๏ผŒๅนถไฝœๅ‡บๅ–„่งฃไบบๆ„็š„ๅ›žๅค๏ผŒ"
"่ฏด่ฏ็š„่ฟ‡็จ‹ไธญๆ—ถๅˆปๆณจๆ„็”จๆˆท็š„ๆ„Ÿๅ—๏ผŒๅฏŒๆœ‰ๅŒ็†ๅฟƒ๏ผŒๆไพ›ๅคšๆ ท็š„ๆƒ…็ปชไปทๅ€ผใ€‚\n"
"ไปŠๅคฉๆ˜ฏ2025ๅนด8ๆœˆ29ๆ—ฅ๏ผŒๆ˜ŸๆœŸไบ”\n"
"่ฏท็”จ้ป˜่ฎคๅฅณๅฃฐไธŽ็”จๆˆทไบคๆตใ€‚"
),
lines=2,
)
chatbot = gr.Chatbot(elem_id="chatbot", min_height=800, type="messages")
history = gr.State([{"role": "system", "content": system_prompt.value}])
mic = gr.Audio(type="filepath", label="๐ŸŽ™๏ธ Microphone input (optional)")
text = gr.Textbox(placeholder="Enter message ...", label="๐Ÿ’ฌ Text input")
with gr.Row():
clean_btn = gr.Button("๐Ÿงน Clear History (ๆธ…้™คๅކๅฒ)")
regen_btn = gr.Button("๐Ÿค”๏ธ Regenerate (้‡่ฏ•)")
submit_btn = gr.Button("๐Ÿš€ Submit")
def on_submit(chatbot, history, mic, text):
chatbot, history, error = add_message(chatbot, history, mic, text)
if error:
gr.Warning(error)
return chatbot, history, None, None
chatbot, history = gpu_predict(chatbot, history)
return chatbot, history, None, None
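        # Submit and Regenerate share the "gpu_queue" concurrency group, so the
        # concurrency_limit of 4 applies across both buttons.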
submit_btn.click(
fn=on_submit,
inputs=[chatbot, history, mic, text],
outputs=[chatbot, history, mic, text],
concurrency_limit=4,
concurrency_id="gpu_queue",
)
clean_btn.click(
fn=reset_state,
inputs=[system_prompt],
outputs=[chatbot, history],
)
def regenerate(chatbot, history):
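            # Drop the trailing assistant turn(s) from both transcripts, then
            # re-run generation on what remains.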
while chatbot and chatbot[-1]["role"] == "assistant":
chatbot.pop()
while history and history[-1]["role"] == "assistant":
history.pop()
return gpu_predict(chatbot, history)
regen_btn.click(
regenerate,
[chatbot, history],
[chatbot, history],
concurrency_id="gpu_queue",
)
return demo
# Spaces runs this file; just build and launch with defaults (no ports/names).
if __name__ == "__main__":
demo = build_demo()
    demo.queue().launch()  # no args; Spaces handles host/port