# app.py: Gradio demo for the Dia-1.6B text-to-dialogue model (nari-labs/Dia-1.6B)
import tempfile
import time
from pathlib import Path
from typing import Optional, Tuple
import spaces
import gradio as gr
from gradio_dialogue import Dialogue
import numpy as np
import soundfile as sf
import torch
from dia.model import Dia
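
# Dialogue, imported above from the gradio_dialogue package, is a custom Gradio
# component for editing structured multi-speaker scripts.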
# Load Nari model and config
print("Loading Nari model...")
try:
    # Load the pretrained weights once at import time via the Dia class
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
except Exception as e:
print(f"Error loading Nari model: {e}")
raise
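
# A minimal sketch (not in the original app) of picking the compute dtype by
# hardware instead of the unconditional float32 above:
#   dtype = "float16" if torch.cuda.is_available() else "float32"
#   model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype=dtype)

# @spaces.GPU allocates a GPU for the duration of each call on ZeroGPU Spaces.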
@spaces.GPU
def run_inference(
text_input: str,
audio_prompt_input: Optional[Tuple[int, np.ndarray]],
max_new_tokens: int,
cfg_scale: float,
temperature: float,
top_p: float,
cfg_filter_top_k: int,
speed_factor: float,
):
"""
Runs Nari inference using the globally loaded model and provided inputs.
Uses temporary files for text and audio prompt compatibility with inference.generate.
"""
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")
    temp_audio_prompt_path = None
    # Default output: one sample of silence at 44.1 kHz, returned if generation
    # produces nothing
    output_audio = (44100, np.zeros(1, dtype=np.float32))
try:
prompt_path_for_generate = None
if audio_prompt_input is not None:
sr, audio_data = audio_prompt_input
            # Ignore prompts that are empty or entirely silent
            if audio_data is None or audio_data.size == 0 or np.max(np.abs(audio_data)) == 0:
                gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
else:
# Save prompt audio to a temporary WAV file
with tempfile.NamedTemporaryFile(
mode="wb", suffix=".wav", delete=False
) as f_audio:
temp_audio_prompt_path = f_audio.name # Store path for cleanup
# Basic audio preprocessing for consistency
# Convert to float32 in [-1, 1] range if integer type
if np.issubdtype(audio_data.dtype, np.integer):
max_val = np.iinfo(audio_data.dtype).max
audio_data = audio_data.astype(np.float32) / max_val
elif not np.issubdtype(audio_data.dtype, np.floating):
gr.Warning(
f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
)
# Attempt conversion, might fail for complex types
try:
audio_data = audio_data.astype(np.float32)
except Exception as conv_e:
raise gr.Error(
f"Failed to convert audio prompt to float32: {conv_e}"
)
# Ensure mono (average channels if stereo)
if audio_data.ndim > 1:
if audio_data.shape[0] == 2: # Assume (2, N)
audio_data = np.mean(audio_data, axis=0)
elif audio_data.shape[1] == 2: # Assume (N, 2)
audio_data = np.mean(audio_data, axis=1)
else:
gr.Warning(
f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
)
audio_data = (
audio_data[0]
if audio_data.shape[0] < audio_data.shape[1]
else audio_data[:, 0]
)
audio_data = np.ascontiguousarray(
audio_data
) # Ensure contiguous after slicing/mean
# Write using soundfile
try:
sf.write(
temp_audio_prompt_path, audio_data, sr, subtype="FLOAT"
) # Explicitly use FLOAT subtype
prompt_path_for_generate = temp_audio_prompt_path
print(
f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
)
except Exception as write_e:
print(f"Error writing temporary audio file: {write_e}")
raise gr.Error(f"Failed to save audio prompt: {write_e}")
        # Run generation
start_time = time.time()
# Use torch.inference_mode() context manager for the generation call
with torch.inference_mode():
output_audio_np = model.generate(
text_input,
max_tokens=max_new_tokens,
cfg_scale=cfg_scale,
temperature=temperature,
top_p=top_p,
cfg_filter_top_k=cfg_filter_top_k, # Pass the value here
use_torch_compile=False, # Keep False for Gradio stability
audio_prompt=prompt_path_for_generate,
)
end_time = time.time()
print(f"Generation finished in {end_time - start_time:.2f} seconds.")
        # Convert the generated audio array for output
        if output_audio_np is not None:
            # Dia's DAC decoder produces 44.1 kHz audio
            output_sr = 44100
# --- Slow down audio ---
original_len = len(output_audio_np)
# Ensure speed_factor is positive and not excessively small/large to avoid issues
speed_factor = max(0.1, min(speed_factor, 5.0))
target_len = int(
original_len / speed_factor
) # Target length based on speed_factor
if (
target_len != original_len and target_len > 0
): # Only interpolate if length changes and is valid
x_original = np.arange(original_len)
x_resampled = np.linspace(0, original_len - 1, target_len)
resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np)
output_audio = (
output_sr,
resampled_audio_np.astype(np.float32),
) # Use resampled audio
print(
f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
)
else:
output_audio = (
output_sr,
output_audio_np,
) # Keep original if calculation fails or no change
print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
# --- End slowdown ---
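            # Toy illustration of the stretch above (assumed numbers, not app output):
            #   np.interp(np.linspace(0, 3, 6), np.arange(4), wave)
            # maps a 4-sample wave onto 6 samples covering the same span, so at a
            # fixed sample rate it plays ~1.5x longer (~0.67x speed). Note that
            # linear interpolation changes pitch along with speed.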
print(
f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
)
# Explicitly convert to int16 to prevent Gradio warning
if (
output_audio[1].dtype == np.float32
or output_audio[1].dtype == np.float64
):
audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
output_audio = (output_sr, audio_for_gradio)
print("Converted audio to int16 for Gradio output.")
else:
print("\nGeneration finished, but no valid tokens were produced.")
# Return default silence
gr.Warning("Generation produced no output.")
except Exception as e:
print(f"Error during inference: {e}")
import traceback
traceback.print_exc()
# Re-raise as Gradio error to display nicely in the UI
raise gr.Error(f"Inference failed: {e}")
    finally:
        # Clean up the temporary audio prompt file defensively
        if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
                )
return output_audio
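

# Hypothetical direct call outside the UI (assumes a GPU runtime for @spaces.GPU;
# on success the returned audio is an int16 numpy array at 44.1 kHz):
#   sr, audio = run_inference("[S1] Hello there.", None, 1024, 3.0, 1.3, 0.95, 30, 0.94)
#   sf.write("direct_out.wav", audio, sr)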
# --- Create Gradio Interface ---
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
"""
# Default dialogue shown in the UI; the Dialogue component expects a list of
# {"speaker", "text"} dicts.
default_text = [
    {"speaker": "Speaker 1", "text": "Dia is an open weights text to dialogue model."},
    {"speaker": "Speaker 2", "text": "You get full control over scripts and voices."},
    {"speaker": "Speaker 1", "text": "Wow. Amazing. (laughs)"},
    {"speaker": "Speaker 2", "text": "Try it now on Git hub or Hugging Face."},
]
# Optionally override the default with text from example.txt, if present
example_txt_path = Path("./example.txt")
if example_txt_path.exists():
    try:
        example_text = example_txt_path.read_text(encoding="utf-8").strip()
        if example_text:  # Ignore an empty example file
            default_text = [{"speaker": "Speaker 1", "text": example_text}]
    except Exception as e:
        print(f"Warning: Could not read example.txt: {e}")
def formatter(speaker, text):
speaker = speaker.split(" ")[1]
return f"[S{speaker}] {text}"
emotions = [
"(laughs)",
"(clears throat)",
"(sighs)",
"(gasps)",
"(coughs)",
"(singing)",
"(sings)",
"(mumbles)",
"(beep)",
"(groans)",
"(sniffs)",
"(claps)",
"(screams)",
"(inhales)",
"(exhales)",
"(applause)",
"(burps)",
"(humming)",
"(sneezes)",
"(chuckle)",
"(whistles)",
]
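
# These tags are inserted verbatim into the script text, where Dia renders them
# as non-verbal sounds (e.g. "[S1] Wow. Amazing. (laughs)").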
# Build Gradio UI
with gr.Blocks(css=css) as demo:
gr.Markdown("# Nari Text-to-Speech Synthesis")
with gr.Row(equal_height=False):
with gr.Column(scale=1):
text_input = Dialogue(
speakers=["Speaker 1", "Speaker 2"],
emotions=emotions,
formatter=formatter,
value=default_text,
)
audio_prompt_input = gr.Audio(
label="Audio Prompt (Optional)",
show_label=True,
sources=["upload", "microphone"],
type="numpy",
)
with gr.Accordion("Generation Parameters", open=False):
max_new_tokens = gr.Slider(
label="Max New Tokens (Audio Length)",
minimum=860,
maximum=3072,
                    value=model.config.data.audio_length,  # Default audio length from the model config
step=50,
info="Controls the maximum length of the generated audio (more tokens = longer audio).",
)
cfg_scale = gr.Slider(
label="CFG Scale (Guidance Strength)",
minimum=1.0,
maximum=5.0,
value=3.0, # Default from inference.py
step=0.1,
info="Higher values increase adherence to the text prompt.",
)
temperature = gr.Slider(
label="Temperature (Randomness)",
minimum=1.0,
maximum=1.5,
value=1.3, # Default from inference.py
step=0.05,
info="Lower values make the output more deterministic, higher values increase randomness.",
)
top_p = gr.Slider(
label="Top P (Nucleus Sampling)",
minimum=0.80,
maximum=1.0,
value=0.95, # Default from inference.py
step=0.01,
info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
)
cfg_filter_top_k = gr.Slider(
label="CFG Filter Top K",
minimum=15,
maximum=50,
value=30,
step=1,
info="Top k filter for CFG guidance.",
)
speed_factor_slider = gr.Slider(
label="Speed Factor",
minimum=0.8,
maximum=1.0,
value=0.94,
step=0.02,
info="Adjusts the speed of the generated audio (1.0 = original speed).",
)
run_button = gr.Button("Generate Audio", variant="primary")
with gr.Column(scale=1):
audio_output = gr.Audio(
label="Generated Audio",
type="numpy",
autoplay=False,
)
gr.Deeplink()
# Link button click to function
run_button.click(
fn=run_inference,
inputs=[
text_input,
audio_prompt_input,
max_new_tokens,
cfg_scale,
temperature,
top_p,
cfg_filter_top_k,
speed_factor_slider,
],
        outputs=[audio_output],
api_name="generate_audio",
)
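
    # A minimal sketch of calling this endpoint remotely (assumed Space id; the
    # Dialogue input is a list of {"speaker", "text"} dicts and the audio output
    # comes back as a local file path):
    #   from gradio_client import Client
    #   client = Client("nari-labs/Dia-1.6B")  # hypothetical Space id
    #   wav_path = client.predict(
    #       [{"speaker": "Speaker 1", "text": "Hello!"}],
    #       None, 1024, 3.0, 1.3, 0.95, 30, 0.94,
    #       api_name="/generate_audio",
    #   )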
    # Examples: the audio prompt path is used only if the file exists
    example_prompt_path = "./example_prompt.mp3"  # Adjust if needed
examples_list = [
[
[{"speaker": "Speaker 1", "text": "Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct!"},
{"speaker": "Speaker 2", "text": "Oh my god! Okay.. it's happening. Everybody stay calm!"},
{"speaker": "Speaker 1", "text": "What's the procedure..."},
{"speaker": "Speaker 2", "text": "Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway!"},
],
None,
3072,
3.0,
1.3,
0.95,
35,
0.94,
],
[
[{"speaker": "Speaker 1", "text": "Open weights text to dialogue model."},
{"speaker": "Speaker 2", "text": "You get full control over scripts and voices."},
{"speaker": "Speaker 1", "text": "I'm biased, but I think we clearly won."},
{"speaker": "Speaker 2", "text": "Hard to disagree. (laughs)"},
{"speaker": "Speaker 1", "text": "Thanks for listening to this demo."},
{"speaker": "Speaker 2", "text": "Try it now on Git hub and Hugging Face."},
{"speaker": "Speaker 1", "text": "If you liked our model, please give us a star and share to your friends."},
{"speaker": "Speaker 2", "text": "This was Nari Labs."},
],
example_prompt_path if Path(example_prompt_path).exists() else None,
3072,
3.0,
1.3,
0.95,
35,
0.94,
],
]
if examples_list:
gr.Examples(
examples=examples_list,
inputs=[
text_input,
audio_prompt_input,
max_new_tokens,
cfg_scale,
temperature,
top_p,
cfg_filter_top_k,
speed_factor_slider,
],
outputs=[audio_output],
fn=run_inference,
cache_examples=False,
label="Examples (Click to Run)",
)
else:
gr.Markdown("_(No examples configured or example prompt file missing)_")
# --- Launch the App ---
if __name__ == "__main__":
print("Launching Gradio interface...")
# set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
# use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
demo.launch(ssr_mode=False)