import os
import sys
import gradio as gr
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg') # Set backend before importing pyplot
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download
import pretty_midi
import librosa
import soundfile as sf
from midi2audio import FluidSynth
import spaces
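# `spaces` provides the @spaces.GPU decorator (used on generate_music below) to run on ZeroGPU hardware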
# Remove CPU forcing since we'll use ZeroGPU
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# torch.set_num_threads(4)
from aria.image_encoder import ImageEncoder
from aria.aria import ARIA
print("Checking model files...")
# Pre-download all model files at startup
MODEL_FILES = {
"image_encoder": "image_encoder.pt",
"continuous_concat": ["continuous_concat/model.pt", "continuous_concat/mappings.pt", "continuous_concat/model_config.pt"],
"continuous_token": ["continuous_token/model.pt", "continuous_token/mappings.pt", "continuous_token/model_config.pt"],
"discrete_token": ["discrete_token/model.pt", "discrete_token/mappings.pt", "discrete_token/model_config.pt"]
}
# Create cache directory
CACHE_DIR = os.path.join(os.path.dirname(__file__), "model_cache")
os.makedirs(CACHE_DIR, exist_ok=True)
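# hf_hub_download caches files under this directory using its own repo/snapshot layout and reuses them on later runs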
# Download and cache all files
cached_files = {}
for model_type, files in MODEL_FILES.items():
if isinstance(files, str):
files = [files]
cached_files[model_type] = []
for file in files:
try:
# hf_hub_download reuses files already present in cache_dir, so this
# only hits the network when the file has not been downloaded before.
print(f"Fetching file: {file}")
cached_path = hf_hub_download(
repo_id="vincentamato/aria",
filename=file,
cache_dir=CACHE_DIR
)
cached_files[model_type].append(cached_path)
except Exception as e:
print(f"Error with file {file}: {str(e)}")
print("Model files ready.")
# Global model cache
models = {}
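# Populated lazily by get_model(); keyed by conditioning type so each ARIA model is loaded at most once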
def create_emotion_plot(valence, arousal):
"""Create a valence-arousal plot with the predicted emotion point"""
# Create figure in a process-safe way
fig = plt.figure(figsize=(8, 8), dpi=100)
ax = fig.add_subplot(111)
# Set background color and style
plt.style.use('default') # Use default style instead of seaborn
fig.patch.set_facecolor('#ffffff')
ax.set_facecolor('#ffffff')
# Create the coordinate system with a light grid
ax.grid(True, linestyle='--', alpha=0.2)
ax.axhline(y=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)
ax.axvline(x=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)
# Plot region
circle = plt.Circle((0, 0), 1, fill=False, color='#666666', alpha=0.3, linewidth=1.5)
ax.add_artist(circle)
# Add labels with nice fonts
font = {'family': 'sans-serif', 'weight': 'medium', 'size': 12}
label_dist = 1.35 # Increased distance for labels
ax.text(label_dist, 0, 'Positive', ha='left', va='center', **font)
ax.text(-label_dist, 0, 'Negative', ha='right', va='center', **font)
ax.text(0, label_dist, 'Excited', ha='center', va='bottom', **font)
ax.text(0, -label_dist, 'Calm', ha='center', va='top', **font)
# Plot the point with a nice style
ax.scatter([valence], [arousal], c='#4f46e5', s=150, zorder=5, alpha=0.8)
# Set limits and labels with more padding
ax.set_xlim(-1.6, 1.6)
ax.set_ylim(-1.6, 1.6)
# Format ticks
ax.set_xticks([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
ax.set_yticks([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
ax.tick_params(axis='both', which='major', labelsize=10)
# Add axis labels with padding
ax.set_xlabel('Valence', **font, labelpad=15)
ax.set_ylabel('Arousal', **font, labelpad=15)
# Remove spines
for spine in ax.spines.values():
spine.set_visible(False)
# Adjust layout with more padding
plt.tight_layout(pad=1.5)
# Save to a temporary file and return the path
temp_path = os.path.join(os.path.dirname(__file__), "output", "emotion_plot.png")
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
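# The same filename is reused for every request, so each new generation overwrites the previous chart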
plt.savefig(temp_path, bbox_inches='tight', dpi=100)
plt.close(fig) # Close the figure to free memory
return temp_path
def get_model(conditioning_type):
"""Get or initialize model with specified conditioning"""
if conditioning_type not in models:
try:
# Use cached files
image_model_path = cached_files["image_encoder"][0]
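# model.pt, mappings.pt and model_config.pt for a conditioning type share one directory,
# so the parent of the first cached file is the directory ARIA expects as midi_model_dir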
midi_model_dir = os.path.dirname(cached_files[conditioning_type][0])
models[conditioning_type] = ARIA(
image_model_checkpoint=image_model_path,
midi_model_dir=midi_model_dir,
conditioning=conditioning_type
)
except Exception as e:
print(f"Error initializing {conditioning_type} model: {str(e)}")
return None
return models[conditioning_type]
def convert_midi_to_wav(midi_path):
"""Convert MIDI file to WAV using FluidSynth"""
wav_path = midi_path.replace('.mid', '.wav')
# If WAV file already exists and is newer than MIDI file, use cached version
if os.path.exists(wav_path) and os.path.getmtime(wav_path) > os.path.getmtime(midi_path):
return wav_path
try:
# Check common soundfont locations
soundfont_paths = [
'/usr/share/sounds/sf2/FluidR3_GM.sf2', # Linux
'/usr/share/soundfonts/default.sf2', # Linux alternative
'/usr/local/share/fluidsynth/generaluser.sf2', # macOS
'C:\\soundfonts\\generaluser.sf2' # Windows
]
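# On Debian-based Spaces images, the apt package fluid-soundfont-gm installs the first path above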
soundfont = None
for sf_path in soundfont_paths:
if os.path.exists(sf_path):
soundfont = sf_path
break
if soundfont is None:
raise RuntimeError("No SoundFont file found. Please install the fluid-soundfont-gm package.")
# Convert MIDI to WAV using FluidSynth with explicit soundfont
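# midi2audio shells out to the fluidsynth binary, so FluidSynth itself must be installed (e.g., via packages.txt on Spaces)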
fs = FluidSynth(sound_font=soundfont)
fs.midi_to_audio(midi_path, wav_path)
return wav_path
except Exception as e:
print(f"Error converting MIDI to WAV: {str(e)}")
return None
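# ZeroGPU: the decorator below requests a GPU for each call, with a 120-second budget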
@spaces.GPU(duration=120)
def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
"""Generate music from input image"""
model = get_model(conditioning_type)
if model is None:
# Return a 3-element tuple matching the Gradio outputs: (emotion_chart, midi_output, results)
return (
None, # For emotion_chart
None, # For midi_output
f"⚠️ Error: Failed to initialize {conditioning_type} model. Please check the logs."
)
try:
# Create output directory
output_dir = os.path.join(os.path.dirname(__file__), "output")
os.makedirs(output_dir, exist_ok=True)
# Generate music
valence, arousal, midi_path = model.generate(
image_path=image,
out_dir=output_dir,
gen_len=gen_len,
temperature=temperature,
top_k=-1,
top_p=float(top_p),
min_instruments=int(min_instruments)
)
# Convert MIDI to WAV
wav_path = convert_midi_to_wav(midi_path)
if wav_path is None:
return (
None,
None,
"⚠️ Error: Failed to convert MIDI to WAV for playback"
)
# Create emotion plot
plot_path = create_emotion_plot(valence, arousal)
# Build a nice Markdown result string
result_text = f"""
**Model Type:** {conditioning_type}
**Predicted Emotions:**
- Valence: {valence:.3f} (negative → positive)
- Arousal: {arousal:.3f} (calm → excited)
**Generation Parameters:**
- Temperature: {temperature}
- Top-p: {top_p}
- Min Instruments: {min_instruments}
Your music has been generated! Click the play button above to listen.
"""
# Outputs: (emotion_chart, midi_output, results)
return (plot_path, wav_path, result_text)
except Exception as e:
return (
None,
None,
f"⚠️ Error generating music: {str(e)}"
)
def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):
"""Wrapper for generate_music that handles separate temperatures"""
return generate_music(
image=image,
conditioning_type=conditioning_type,
gen_len=gen_len,
temperature=[float(note_temp), float(rest_temp)],
top_p=top_p,
min_instruments=min_instruments
)
# Create Gradio interface
with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
primary_hue="indigo",
secondary_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
)) as demo:
gr.Markdown("""
# 🎨 ARIA: Artistic Rendering of Images into Audio
Upload an image and ARIA will analyze its emotional content to generate matching music!
### How it works:
1. ARIA first analyzes the emotional content of your image along two dimensions:
- **Valence**: How positive or negative the emotion is (-1 to 1)
- **Arousal**: How calm or excited the emotion is (-1 to 1)
2. These emotions are then used to generate music that matches the mood
""")
with gr.Row():
with gr.Column(scale=3):
image_input = gr.Image(
type="filepath",
label="Upload Image"
)
with gr.Group():
gr.Markdown("### Generation Settings")
with gr.Row():
with gr.Column():
conditioning_type = gr.Radio(
choices=["continuous_concat", "continuous_token", "discrete_token"],
value="continuous_concat",
label="Conditioning Type",
info="How the emotional information is incorporated into the music generation"
)
with gr.Column():
gen_len = gr.Slider(
minimum=256,
maximum=4096,
value=1024,
step=256,
label="Generation Length",
info="Number of tokens to generate (longer = more music)"
)
with gr.Row():
with gr.Column():
note_temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.2,
step=0.1,
label="Note Temperature",
info="Controls randomness of note generation"
)
with gr.Column():
rest_temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.2,
step=0.1,
label="Rest Temperature",
info="Controls randomness of rest/timing generation"
)
with gr.Row():
with gr.Column():
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.6,
step=0.1,
label="Top-p Sampling",
info="Nucleus sampling threshold - lower = more focused"
)
with gr.Column():
min_instruments = gr.Slider(
minimum=1,
maximum=5,
value=2,
step=1,
label="Minimum Instruments",
info="Minimum number of instruments in the generated music"
)
generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
# Add examples
gr.Examples(
examples=[
["examples/happy.jpg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
["examples/sad.jpeg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
],
inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
label="Try these examples"
)
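# These example paths assume happy.jpg and sad.jpeg are committed in an examples/ folder alongside app.py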
with gr.Column(scale=2):
emotion_chart = gr.Image(
label="Predicted Emotions",
type="filepath"
)
midi_output = gr.Audio(
type="filepath",
label="Generated Music"
)
results = gr.Markdown()
gr.Markdown("""
### About ARIA
ARIA is a deep learning system that generates music from artwork by:
1. Using an image emotion model to extract emotional content from images
2. Generating matching music using an emotion-conditioned music generation model
The emotion-conditioned MIDI generation model is based on the work by Serkan Sulun et al. in their paper
["Symbolic music generation conditioned on continuous-valued emotions"](https://ieeexplore.ieee.org/document/9762257).
Original implementation: [github.com/serkansulun/midi-emotion](https://github.com/serkansulun/midi-emotion)
### Conditioning Types
- **continuous_concat**: Emotions are concatenated with music features (recommended)
- **continuous_token**: Emotions are added as special tokens
- **discrete_token**: Emotions are discretized into tokens
""")
generate_btn.click(
fn=generate_music_wrapper,
inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
outputs=[emotion_chart, midi_output, results]
)
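# The outputs order must match the (plot_path, wav_path, result_text) tuple returned by generate_music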
# Launch app
demo.launch()  # share=True is unnecessary on Hugging Face Spaces