import os
import sys
import gradio as gr
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Set backend before importing pyplot
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download
import pretty_midi
import librosa
import soundfile as sf
from midi2audio import FluidSynth
import spaces

# Remove CPU forcing since we'll use ZeroGPU
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# torch.set_num_threads(4)

from aria.image_encoder import ImageEncoder
from aria.aria import ARIA

print("Checking model files...") | |
# Pre-download all model files at startup | |
MODEL_FILES = { | |
"image_encoder": "image_encoder.pt", | |
"continuous_concat": ["continuous_concat/model.pt", "continuous_concat/mappings.pt", "continuous_concat/model_config.pt"], | |
"continuous_token": ["continuous_token/model.pt", "continuous_token/mappings.pt", "continuous_token/model_config.pt"], | |
"discrete_token": ["discrete_token/model.pt", "discrete_token/mappings.pt", "discrete_token/model_config.pt"] | |
} | |

# Create cache directory
CACHE_DIR = os.path.join(os.path.dirname(__file__), "model_cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Download and cache all files
cached_files = {}
for model_type, files in MODEL_FILES.items():
    if isinstance(files, str):
        files = [files]
    cached_files[model_type] = []
    for file in files:
        try:
            # Check if file already exists in cache
            repo_id = "vincentamato/aria"
            cached_path = os.path.join(CACHE_DIR, repo_id, file)
            if os.path.exists(cached_path):
                print(f"Using cached file: {file}")
                cached_files[model_type].append(cached_path)
            else:
                print(f"Downloading file: {file}")
                cached_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=file,
                    cache_dir=CACHE_DIR
                )
                cached_files[model_type].append(cached_path)
        except Exception as e:
            print(f"Error with file {file}: {str(e)}")

print("Model files ready.")

# Global model cache
models = {}


def create_emotion_plot(valence, arousal):
    """Create a valence-arousal plot with the predicted emotion point"""
    # Create figure in a process-safe way
    fig = plt.figure(figsize=(8, 8), dpi=100)
    ax = fig.add_subplot(111)

    # Set background color and style
    plt.style.use('default')  # Use default style instead of seaborn
    fig.patch.set_facecolor('#ffffff')
    ax.set_facecolor('#ffffff')

    # Create the coordinate system with a light grid
    ax.grid(True, linestyle='--', alpha=0.2)
    ax.axhline(y=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)
    ax.axvline(x=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)

    # Plot region
    circle = plt.Circle((0, 0), 1, fill=False, color='#666666', alpha=0.3, linewidth=1.5)
    ax.add_artist(circle)

    # Add labels with nice fonts
    font = {'family': 'sans-serif', 'weight': 'medium', 'size': 12}
    label_dist = 1.35  # Increased distance for labels
    ax.text(label_dist, 0, 'Positive', ha='left', va='center', **font)
    ax.text(-label_dist, 0, 'Negative', ha='right', va='center', **font)
    ax.text(0, label_dist, 'Excited', ha='center', va='bottom', **font)
    ax.text(0, -label_dist, 'Calm', ha='center', va='top', **font)

    # Plot the point with a nice style
    ax.scatter([valence], [arousal], c='#4f46e5', s=150, zorder=5, alpha=0.8)

    # Set limits with more padding
    ax.set_xlim(-1.6, 1.6)
    ax.set_ylim(-1.6, 1.6)

    # Format ticks
    ax.set_xticks([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
    ax.set_yticks([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
    ax.tick_params(axis='both', which='major', labelsize=10)

    # Add axis labels with padding
    ax.set_xlabel('Valence', **font, labelpad=15)
    ax.set_ylabel('Arousal', **font, labelpad=15)

    # Remove spines
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Adjust layout with more padding
    plt.tight_layout(pad=1.5)

    # Save to a temporary file and return the path
    temp_path = os.path.join(os.path.dirname(__file__), "output", "emotion_plot.png")
    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
    plt.savefig(temp_path, bbox_inches='tight', dpi=100)
    plt.close(fig)  # Close the figure to free memory

    return temp_path


def get_model(conditioning_type):
    """Get or initialize model with specified conditioning"""
    if conditioning_type not in models:
        try:
            # Use cached files
            image_model_path = cached_files["image_encoder"][0]
            midi_model_dir = os.path.dirname(cached_files[conditioning_type][0])

            models[conditioning_type] = ARIA(
                image_model_checkpoint=image_model_path,
                midi_model_dir=midi_model_dir,
                conditioning=conditioning_type
            )
        except Exception as e:
            print(f"Error initializing {conditioning_type} model: {str(e)}")
            return None
    return models[conditioning_type]


def convert_midi_to_wav(midi_path):
    """Convert MIDI file to WAV using FluidSynth"""
    wav_path = midi_path.replace('.mid', '.wav')

    # If the WAV file already exists and is newer than the MIDI file, use the cached version
    if os.path.exists(wav_path) and os.path.getmtime(wav_path) > os.path.getmtime(midi_path):
        return wav_path

    try:
        # Check common soundfont locations
        soundfont_paths = [
            '/usr/share/sounds/sf2/FluidR3_GM.sf2',         # Linux
            '/usr/share/soundfonts/default.sf2',            # Linux alternative
            '/usr/local/share/fluidsynth/generaluser.sf2',  # macOS
            'C:\\soundfonts\\generaluser.sf2'               # Windows
        ]

        soundfont = None
        for sf_path in soundfont_paths:
            if os.path.exists(sf_path):
                soundfont = sf_path
                break

        if soundfont is None:
            raise RuntimeError("No SoundFont file found. Please install the fluid-soundfont-gm package.")

        # Convert MIDI to WAV using FluidSynth with an explicit soundfont
        fs = FluidSynth(sound_font=soundfont)
        fs.midi_to_audio(midi_path, wav_path)
        return wav_path
    except Exception as e:
        print(f"Error converting MIDI to WAV: {str(e)}")
        return None
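

# Assumption: the `import spaces` and the ZeroGPU comment above suggest this Space runs on
# ZeroGPU, where the GPU-bound entry point is typically decorated with @spaces.GPU so a GPU
# is allocated for the duration of the call. Outside ZeroGPU the decorator is a no-op.
@spaces.GPU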
def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
    """Generate music from input image"""
    model = get_model(conditioning_type)
    if model is None:
        # IMPORTANT: Return a 3-element tuple, not a dictionary
        return (
            None,  # For emotion_chart
            None,  # For midi_output
            f"⚠️ Error: Failed to initialize {conditioning_type} model. Please check the logs."
        )

    try:
        # Create output directory
        output_dir = os.path.join(os.path.dirname(__file__), "output")
        os.makedirs(output_dir, exist_ok=True)

        # Generate music
        valence, arousal, midi_path = model.generate(
            image_path=image,
            out_dir=output_dir,
            gen_len=gen_len,
            temperature=temperature,
            top_k=-1,
            top_p=float(top_p),
            min_instruments=int(min_instruments)
        )

        # Convert MIDI to WAV for playback
        wav_path = convert_midi_to_wav(midi_path)
        if wav_path is None:
            return (
                None,
                None,
                "⚠️ Error: Failed to convert MIDI to WAV for playback"
            )

        # Create emotion plot
        plot_path = create_emotion_plot(valence, arousal)

        # Build a Markdown summary of the result
        result_text = f"""
**Model Type:** {conditioning_type}

**Predicted Emotions:**
- Valence: {valence:.3f} (negative → positive)
- Arousal: {arousal:.3f} (calm → excited)

**Generation Parameters:**
- Temperature: {temperature}
- Top-p: {top_p}
- Min Instruments: {min_instruments}

Your music has been generated! Click the play button above to listen.
"""

        # Return a (plot, audio, text) tuple matching [emotion_chart, midi_output, results]
        return (plot_path, wav_path, result_text)
    except Exception as e:
        return (
            None,
            None,
            f"⚠️ Error generating music: {str(e)}"
        )


def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):
    """Wrapper for generate_music that handles separate temperatures"""
    return generate_music(
        image=image,
        conditioning_type=conditioning_type,
        gen_len=gen_len,
        temperature=[float(note_temp), float(rest_temp)],
        top_p=top_p,
        min_instruments=min_instruments
    )


# Create Gradio interface
with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
)) as demo:
    gr.Markdown("""
    # 🎨 ARIA: Artistic Rendering of Images into Audio

    Upload an image and ARIA will analyze its emotional content to generate matching music!

    ### How it works:
    1. ARIA first analyzes the emotional content of your image along two dimensions:
       - **Valence**: How positive or negative the emotion is (-1 to 1)
       - **Arousal**: How calm or excited the emotion is (-1 to 1)
    2. These emotions are then used to generate music that matches the mood
    """)

    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(
                type="filepath",
                label="Upload Image"
            )

            with gr.Group():
                gr.Markdown("### Generation Settings")

                with gr.Row():
                    with gr.Column():
                        conditioning_type = gr.Radio(
                            choices=["continuous_concat", "continuous_token", "discrete_token"],
                            value="continuous_concat",
                            label="Conditioning Type",
                            info="How the emotional information is incorporated into the music generation"
                        )
                    with gr.Column():
                        gen_len = gr.Slider(
                            minimum=256,
                            maximum=4096,
                            value=1024,
                            step=256,
                            label="Generation Length",
                            info="Number of tokens to generate (longer = more music)"
                        )

                with gr.Row():
                    with gr.Column():
                        note_temperature = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.2,
                            step=0.1,
                            label="Note Temperature",
                            info="Controls randomness of note generation"
                        )
                    with gr.Column():
                        rest_temperature = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.2,
                            step=0.1,
                            label="Rest Temperature",
                            info="Controls randomness of rest/timing generation"
                        )

                with gr.Row():
                    with gr.Column():
                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.6,
                            step=0.1,
                            label="Top-p Sampling",
                            info="Nucleus sampling threshold - lower = more focused"
                        )
                    with gr.Column():
                        min_instruments = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=2,
                            step=1,
                            label="Minimum Instruments",
                            info="Minimum number of instruments in the generated music"
                        )

            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

            # Add examples
            gr.Examples(
                examples=[
                    ["examples/happy.jpg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
                    ["examples/sad.jpeg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
                ],
                inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
                label="Try these examples"
            )

        with gr.Column(scale=2):
            emotion_chart = gr.Image(
                label="Predicted Emotions",
                type="filepath"
            )
            midi_output = gr.Audio(
                type="filepath",
                label="Generated Music"
            )
            results = gr.Markdown()
gr.Markdown(""" | |
### About ARIA | |
ARIA is a deep learning system that generates music from artwork by: | |
1. Using a image emotion model to extract emotional content from images | |
2. Generating matching music using an emotion-conditioned music generation model | |
The emotion-conditioned MIDI generation model is based on the work by Serkan Sulun et al. in their paper | |
["Symbolic music generation conditioned on continuous-valued emotions"](https://ieeexplore.ieee.org/document/9762257). | |
Original implementation: [github.com/serkansulun/midi-emotion](https://github.com/serkansulun/midi-emotion) | |
### Conditioning Types | |
- **continuous_concat**: Emotions are concatenated with music features (recommended) | |
- **continuous_token**: Emotions are added as special tokens | |
- **discrete_token**: Emotions are discretized into tokens | |
""") | |

    generate_btn.click(
        fn=generate_music_wrapper,
        inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
        outputs=[emotion_chart, midi_output, results]
    )

# Launch app
demo.launch(share=True)