import tempfile import time from pathlib import Path from typing import Optional, Tuple import spaces import gradio as gr from gradio_dialogue import Dialogue import numpy as np import soundfile as sf import torch from dia.model import Dia # Load Nari model and config print("Loading Nari model...") try: # Use the function from inference.py model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32") except Exception as e: print(f"Error loading Nari model: {e}") raise @spaces.GPU def run_inference( text_input: str, audio_prompt_input: Optional[Tuple[int, np.ndarray]], max_new_tokens: int, cfg_scale: float, temperature: float, top_p: float, cfg_filter_top_k: int, speed_factor: float, ): """ Runs Nari inference using the globally loaded model and provided inputs. Uses temporary files for text and audio prompt compatibility with inference.generate. """ # global model, device # Access global model, config, device if not text_input or text_input.isspace(): raise gr.Error("Text input cannot be empty.") temp_txt_file_path = None temp_audio_prompt_path = None output_audio = (44100, np.zeros(1, dtype=np.float32)) try: prompt_path_for_generate = None if audio_prompt_input is not None: sr, audio_data = audio_prompt_input # Check if audio_data is valid if ( audio_data is None or audio_data.size == 0 or audio_data.max() == 0 ): # Check for silence/empty gr.Warning("Audio prompt seems empty or silent, ignoring prompt.") else: # Save prompt audio to a temporary WAV file with tempfile.NamedTemporaryFile( mode="wb", suffix=".wav", delete=False ) as f_audio: temp_audio_prompt_path = f_audio.name # Store path for cleanup # Basic audio preprocessing for consistency # Convert to float32 in [-1, 1] range if integer type if np.issubdtype(audio_data.dtype, np.integer): max_val = np.iinfo(audio_data.dtype).max audio_data = audio_data.astype(np.float32) / max_val elif not np.issubdtype(audio_data.dtype, np.floating): gr.Warning( f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion." ) # Attempt conversion, might fail for complex types try: audio_data = audio_data.astype(np.float32) except Exception as conv_e: raise gr.Error( f"Failed to convert audio prompt to float32: {conv_e}" ) # Ensure mono (average channels if stereo) if audio_data.ndim > 1: if audio_data.shape[0] == 2: # Assume (2, N) audio_data = np.mean(audio_data, axis=0) elif audio_data.shape[1] == 2: # Assume (N, 2) audio_data = np.mean(audio_data, axis=1) else: gr.Warning( f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis." ) audio_data = ( audio_data[0] if audio_data.shape[0] < audio_data.shape[1] else audio_data[:, 0] ) audio_data = np.ascontiguousarray( audio_data ) # Ensure contiguous after slicing/mean # Write using soundfile try: sf.write( temp_audio_prompt_path, audio_data, sr, subtype="FLOAT" ) # Explicitly use FLOAT subtype prompt_path_for_generate = temp_audio_prompt_path print( f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})" ) except Exception as write_e: print(f"Error writing temporary audio file: {write_e}") raise gr.Error(f"Failed to save audio prompt: {write_e}") # 3. Run Generation start_time = time.time() # Use torch.inference_mode() context manager for the generation call with torch.inference_mode(): output_audio_np = model.generate( text_input, max_tokens=max_new_tokens, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, cfg_filter_top_k=cfg_filter_top_k, # Pass the value here use_torch_compile=False, # Keep False for Gradio stability audio_prompt=prompt_path_for_generate, ) end_time = time.time() print(f"Generation finished in {end_time - start_time:.2f} seconds.") # 4. Convert Codes to Audio if output_audio_np is not None: # Get sample rate from the loaded DAC model output_sr = 44100 # --- Slow down audio --- original_len = len(output_audio_np) # Ensure speed_factor is positive and not excessively small/large to avoid issues speed_factor = max(0.1, min(speed_factor, 5.0)) target_len = int( original_len / speed_factor ) # Target length based on speed_factor if ( target_len != original_len and target_len > 0 ): # Only interpolate if length changes and is valid x_original = np.arange(original_len) x_resampled = np.linspace(0, original_len - 1, target_len) resampled_audio_np = np.interp(x_resampled, x_original, output_audio_np) output_audio = ( output_sr, resampled_audio_np.astype(np.float32), ) # Use resampled audio print( f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed." ) else: output_audio = ( output_sr, output_audio_np, ) # Keep original if calculation fails or no change print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).") # --- End slowdown --- print( f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}" ) # Explicitly convert to int16 to prevent Gradio warning if ( output_audio[1].dtype == np.float32 or output_audio[1].dtype == np.float64 ): audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0) audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16) output_audio = (output_sr, audio_for_gradio) print("Converted audio to int16 for Gradio output.") else: print("\nGeneration finished, but no valid tokens were produced.") # Return default silence gr.Warning("Generation produced no output.") except Exception as e: print(f"Error during inference: {e}") import traceback traceback.print_exc() # Re-raise as Gradio error to display nicely in the UI raise gr.Error(f"Inference failed: {e}") finally: # 5. Cleanup Temporary Files defensively if temp_txt_file_path and Path(temp_txt_file_path).exists(): try: Path(temp_txt_file_path).unlink() print(f"Deleted temporary text file: {temp_txt_file_path}") except OSError as e: print( f"Warning: Error deleting temporary text file {temp_txt_file_path}: {e}" ) if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists(): try: Path(temp_audio_prompt_path).unlink() print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}") except OSError as e: print( f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}" ) return output_audio # --- Create Gradio Interface --- css = """ #col-container {max-width: 90%; margin-left: auto; margin-right: auto;} """ # Attempt to load default text from example.txt default_text = [{"speaker": "Speaker 1", "text": "Dia is an open weights text to dialogue model."}, {"speaker": "Speaker 2", "text": "You get full control over scripts and voices."}, {"speaker": "Speaker 1", "text": "Wow. Amazing. (laughs)"}, {"speaker": "Speaker 2", "text": "Try it now on Git hub or Hugging Face."}, ] example_txt_path = Path("./example.txt") if example_txt_path.exists(): try: default_text = example_txt_path.read_text(encoding="utf-8").strip() if not default_text: # Handle empty example file default_text = "Example text file was empty." except Exception as e: print(f"Warning: Could not read example.txt: {e}") def formatter(speaker, text): speaker = speaker.split(" ")[1] return f"[S{speaker}] {text}" emotions = [ "(laughs)", "(clears throat)", "(sighs)", "(gasps)", "(coughs)", "(singing)", "(sings)", "(mumbles)", "(beep)", "(groans)", "(sniffs)", "(claps)", "(screams)", "(inhales)", "(exhales)", "(applause)", "(burps)", "(humming)", "(sneezes)", "(chuckle)", "(whistles)", ] # Build Gradio UI with gr.Blocks(css=css) as demo: gr.Markdown("# Nari Text-to-Speech Synthesis") with gr.Row(equal_height=False): with gr.Column(scale=1): text_input = Dialogue( speakers=["Speaker 1", "Speaker 2"], emotions=emotions, formatter=formatter, value=default_text, ) audio_prompt_input = gr.Audio( label="Audio Prompt (Optional)", show_label=True, sources=["upload", "microphone"], type="numpy", ) with gr.Accordion("Generation Parameters", open=False): max_new_tokens = gr.Slider( label="Max New Tokens (Audio Length)", minimum=860, maximum=3072, value=model.config.data.audio_length, # Use config default if available, else fallback step=50, info="Controls the maximum length of the generated audio (more tokens = longer audio).", ) cfg_scale = gr.Slider( label="CFG Scale (Guidance Strength)", minimum=1.0, maximum=5.0, value=3.0, # Default from inference.py step=0.1, info="Higher values increase adherence to the text prompt.", ) temperature = gr.Slider( label="Temperature (Randomness)", minimum=1.0, maximum=1.5, value=1.3, # Default from inference.py step=0.05, info="Lower values make the output more deterministic, higher values increase randomness.", ) top_p = gr.Slider( label="Top P (Nucleus Sampling)", minimum=0.80, maximum=1.0, value=0.95, # Default from inference.py step=0.01, info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.", ) cfg_filter_top_k = gr.Slider( label="CFG Filter Top K", minimum=15, maximum=50, value=30, step=1, info="Top k filter for CFG guidance.", ) speed_factor_slider = gr.Slider( label="Speed Factor", minimum=0.8, maximum=1.0, value=0.94, step=0.02, info="Adjusts the speed of the generated audio (1.0 = original speed).", ) run_button = gr.Button("Generate Audio", variant="primary") with gr.Column(scale=1): audio_output = gr.Audio( label="Generated Audio", type="numpy", autoplay=False, ) gr.Deeplink() # Link button click to function run_button.click( fn=run_inference, inputs=[ text_input, audio_prompt_input, max_new_tokens, cfg_scale, temperature, top_p, cfg_filter_top_k, speed_factor_slider, ], outputs=[audio_output], # Add status_output here if using it api_name="generate_audio", ) # Add examples (ensure the prompt path is correct or remove it if example file doesn't exist) example_prompt_path = "./example_prompt.mp3" # Adjust if needed examples_list = [ [ [{"speaker": "Speaker 1", "text": "Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct!"}, {"speaker": "Speaker 2", "text": "Oh my god! Okay.. it's happening. Everybody stay calm!"}, {"speaker": "Speaker 1", "text": "What's the procedure..."}, {"speaker": "Speaker 2", "text": "Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway!"}, ], None, 3072, 3.0, 1.3, 0.95, 35, 0.94, ], [ [{"speaker": "Speaker 1", "text": "Open weights text to dialogue model."}, {"speaker": "Speaker 2", "text": "You get full control over scripts and voices."}, {"speaker": "Speaker 1", "text": "I'm biased, but I think we clearly won."}, {"speaker": "Speaker 2", "text": "Hard to disagree. (laughs)"}, {"speaker": "Speaker 1", "text": "Thanks for listening to this demo."}, {"speaker": "Speaker 2", "text": "Try it now on Git hub and Hugging Face."}, {"speaker": "Speaker 1", "text": "If you liked our model, please give us a star and share to your friends."}, {"speaker": "Speaker 2", "text": "This was Nari Labs."}, ], example_prompt_path if Path(example_prompt_path).exists() else None, 3072, 3.0, 1.3, 0.95, 35, 0.94, ], ] if examples_list: gr.Examples( examples=examples_list, inputs=[ text_input, audio_prompt_input, max_new_tokens, cfg_scale, temperature, top_p, cfg_filter_top_k, speed_factor_slider, ], outputs=[audio_output], fn=run_inference, cache_examples=False, label="Examples (Click to Run)", ) else: gr.Markdown("_(No examples configured or example prompt file missing)_") # --- Launch the App --- if __name__ == "__main__": print("Launching Gradio interface...") # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker demo.launch(ssr_mode=False)