Commit 8629f1c
Parent(s): 8429ccf

Fixed chart rendering

Files changed:
- app.py (+17, -7)
- aria/aria.py (+2, -2)
- aria/generate.py (+0, -61)
app.py (CHANGED)

@@ -65,6 +65,9 @@ models = {}
 
 def create_emotion_plot(valence, arousal):
     """Create a valence-arousal plot with the predicted emotion point"""
+    # Create figure in a process-safe way
+    plt.switch_backend('Agg')
+
     fig = plt.figure(figsize=(8, 8), dpi=100)
     ax = fig.add_subplot(111)
 
@@ -113,7 +116,13 @@ def create_emotion_plot(valence, arousal):
     # Adjust layout with more padding
     plt.tight_layout(pad=1.5)
 
-    return
+    # Save to a temporary file and return the path
+    temp_path = os.path.join(os.path.dirname(__file__), "output", "emotion_plot.png")
+    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
+    plt.savefig(temp_path, bbox_inches='tight', dpi=100)
+    plt.close(fig)  # Close the figure to free memory
+
+    return temp_path
 
 def get_model(conditioning_type):
     """Get or initialize model with specified conditioning"""
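The hunks above replace the in-memory figure return with a render-to-file flow: switch matplotlib to the Agg backend, write the chart to a PNG under output/, and return the file path. A minimal standalone sketch of that pattern (the function name, plot contents, and output directory are illustrative, not the actual app.py code):

```python
import os

import matplotlib
matplotlib.use("Agg")  # headless backend; safe without a display attached
import matplotlib.pyplot as plt


def render_emotion_plot(valence: float, arousal: float, out_dir: str = "output") -> str:
    """Render a simple valence-arousal scatter and return the saved PNG path."""
    fig, ax = plt.subplots(figsize=(8, 8), dpi=100)
    ax.scatter([valence], [arousal])
    ax.set_xlabel("Valence (negative -> positive)")
    ax.set_ylabel("Arousal (calm -> excited)")
    fig.tight_layout(pad=1.5)

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "emotion_plot.png")
    fig.savefig(path, bbox_inches="tight", dpi=100)
    plt.close(fig)  # free the figure instead of handing it to the UI layer
    return path
```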
@@ -168,7 +177,7 @@ def convert_midi_to_wav(midi_path):
         print(f"Error converting MIDI to WAV: {str(e)}")
         return None
 
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=120)  # Set duration to 120 seconds for music generation
 def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
     """Generate music from input image"""
     model = get_model(conditioning_type)
@@ -208,11 +217,11 @@ def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
             results: "⚠️ Error: Failed to convert MIDI to WAV for playback"
         }
 
-    # Create emotion plot
-    …
+    # Create emotion plot and get its path
+    plot_path = create_emotion_plot(valence, arousal)
 
     return {
-        emotion_chart: …
+        emotion_chart: plot_path,
         midi_output: wav_path,
         results: f"""
 **Model Type:** {conditioning_type}
@@ -335,8 +344,9 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
             )
 
             with gr.Column(scale=2):
-                emotion_chart = gr.…
-                    label="Predicted Emotions"
+                emotion_chart = gr.Image(
+                    label="Predicted Emotions",
+                    type="filepath"
                 )
                 midi_output = gr.Audio(
                     type="filepath",
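With create_emotion_plot now returning a path, the output component switches to gr.Image(type="filepath"), and generate_music keeps returning a dict keyed by output components. A minimal wiring sketch under assumed names (demo_generate, the placeholder paths, and the button are hypothetical; only the gr.Image and gr.Audio settings mirror the diff):

```python
import gradio as gr


def demo_generate(image_path):
    # Stand-ins for create_emotion_plot(...) and convert_midi_to_wav(...).
    plot_path = "output/emotion_plot.png"
    wav_path = "output/generated.wav"
    # Gradio accepts a dict keyed by output components, the style used above.
    return {emotion_chart: plot_path, midi_output: wav_path}


with gr.Blocks() as demo:
    image_input = gr.Image(type="filepath", label="Input image")
    emotion_chart = gr.Image(label="Predicted Emotions", type="filepath")
    midi_output = gr.Audio(type="filepath", label="Generated music")
    gr.Button("Generate").click(
        demo_generate, inputs=image_input, outputs=[emotion_chart, midi_output]
    )

if __name__ == "__main__":
    demo.launch()
```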
aria/aria.py (CHANGED)

@@ -15,7 +15,7 @@ sys.path.append(MIDI_EMOTION_PATH)
 class ARIA:
     """ARIA model that generates music from images based on emotional content."""
 
-    @spaces.GPU(duration=…)
+    @spaces.GPU(duration=10)  # Model loading should be quick
     def __init__(
         self,
         image_model_checkpoint: str,
@@ -60,7 +60,7 @@ class ARIA:
         self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
         self.midi_model.eval()
 
-    @spaces.GPU(duration=…)
+    @spaces.GPU(duration=60)
     @torch.inference_mode()  # More efficient than no_grad for inference
     def generate(
         self,
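Both edits attach ZeroGPU duration hints: a short window for loading weights in __init__ and a longer one for generation. A sketch of the same decorator stacking (the class and its body are illustrative stand-ins, not ARIA itself):

```python
import spaces
import torch


class ExampleModel:
    @spaces.GPU(duration=10)  # short budget: just move weights onto the GPU
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = torch.nn.Linear(8, 8).to(self.device)  # stand-in for the real checkpoint
        self.model.eval()

    @spaces.GPU(duration=60)   # longer budget for the actual generation call
    @torch.inference_mode()    # skips autograd bookkeeping during inference
    def generate(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x.to(self.device))
```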
aria/generate.py (DELETED)

@@ -1,61 +0,0 @@
-import argparse
-from src.models.aria.aria import ARIA
-
-def main():
-    parser = argparse.ArgumentParser(description="Generate music from images based on emotional content")
-
-    parser.add_argument("--image", type=str, required=True,
-                        help="Path to input image")
-    parser.add_argument("--image_model_checkpoint", type=str, required=True,
-                        help="Path to image emotion model checkpoint")
-    parser.add_argument("--midi_model_dir", type=str, required=True,
-                        help="Path to midi emotion model directory")
-    parser.add_argument("--out_dir", type=str, default="output",
-                        help="Directory to save generated MIDI")
-    parser.add_argument("--gen_len", type=int, default=512,
-                        help="Length of generation in tokens")
-    parser.add_argument("--temperature", type=float, nargs=2, default=[1.2, 1.2],
-                        help="Temperature for sampling [note_temp, rest_temp]")
-    parser.add_argument("--top_k", type=int, default=-1,
-                        help="Top-k sampling (-1 to disable)")
-    parser.add_argument("--top_p", type=float, default=0.7,
-                        help="Top-p sampling threshold")
-    parser.add_argument("--min_instruments", type=int, default=1,
-                        help="Minimum number of instruments required")
-    parser.add_argument("--cpu", action="store_true",
-                        help="Force CPU inference")
-    parser.add_argument("--conditioning", type=str, required=True,
-                        choices=["none", "discrete_token", "continuous_token", "continuous_concat"],
-                        help="Type of conditioning to use")
-    parser.add_argument("--batch_size", type=int, default=1,
-                        help="Number of samples to generate (not used for image input)")
-
-    args = parser.parse_args()
-
-    # Initialize model
-    model = ARIA(
-        image_model_checkpoint=args.image_model_checkpoint,
-        midi_model_dir=args.midi_model_dir,
-        conditioning=args.conditioning,
-        device="cpu" if args.cpu else None
-    )
-
-    # Generate music
-    valence, arousal, midi_path = model.generate(
-        image_path=args.image,
-        out_dir=args.out_dir,
-        gen_len=args.gen_len,
-        temperature=args.temperature,
-        top_k=args.top_k,
-        top_p=args.top_p,
-        min_instruments=args.min_instruments
-    )
-
-    # Print results
-    print(f"\nPredicted emotions:")
-    print(f"Valence: {valence:.3f} (negative -> positive)")
-    print(f"Arousal: {arousal:.3f} (calm -> excited)")
-    print(f"\nGenerated MIDI saved to: {midi_path}")
-
-if __name__ == "__main__":
-    main()
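With the CLI wrapper removed, generation presumably runs through app.py or the ARIA class directly. A sketch of the programmatic equivalent using the old CLI defaults; the import path follows the Space layout (aria/aria.py) rather than the src.models path the deleted script used, and the checkpoint and image paths are placeholders:

```python
from aria.aria import ARIA  # assumed import path for the Space layout

model = ARIA(
    image_model_checkpoint="checkpoints/image_emotion.pt",  # placeholder path
    midi_model_dir="checkpoints/midi_emotion",              # placeholder path
    conditioning="continuous_concat",  # one of the four supported modes
    device=None,  # None lets the model choose; "cpu" forces CPU inference
)

valence, arousal, midi_path = model.generate(
    image_path="examples/painting.jpg",  # placeholder image
    out_dir="output",
    gen_len=512,
    temperature=[1.2, 1.2],  # [note_temp, rest_temp]
    top_k=-1,                # disabled
    top_p=0.7,
    min_instruments=1,
)

print(f"Valence: {valence:.3f} (negative -> positive)")
print(f"Arousal: {arousal:.3f} (calm -> excited)")
print(f"Generated MIDI saved to: {midi_path}")
```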