vincentamato committed on
Commit
3cbd872
·
1 Parent(s): c083e90

Fixed chart rendering

Browse files
Files changed (1) hide show
  1. app.py +110 -66
app.py CHANGED
@@ -65,36 +65,64 @@ print("Model files ready.")
65
  # Global model cache
66
  models = {}
67
 
68
def create_emotion_text(valence, arousal):
    """Build a Markdown summary of the predicted valence/arousal coordinates.

    Args:
        valence: Predicted valence, roughly in [-1, 1] (negative → positive).
        arousal: Predicted arousal, roughly in [-1, 1] (calm → excited).

    Returns:
        A Markdown string containing a coarse emotion category plus the raw
        coordinates formatted to two decimal places.
    """
    def get_emotion_description(v, a):
        # Coarse 3x3 bucketing of the circumplex model: |value| > 0.5 counts
        # as strongly positive/negative on that axis, otherwise neutral.
        if v > 0.5:
            if a > 0.5:
                return "Joyful/Excited"
            elif a < -0.5:
                return "Content/Peaceful"
            else:
                return "Happy/Pleasant"
        elif v < -0.5:
            if a > 0.5:
                return "Angry/Tense"
            elif a < -0.5:
                return "Sad/Depressed"
            else:
                return "Unhappy/Unpleasant"
        else:
            if a > 0.5:
                return "Alert/Energetic"
            elif a < -0.5:
                return "Tired/Calm"
            else:
                return "Neutral"

    emotion = get_emotion_description(valence, arousal)

    return f"""
### Predicted Emotions

**Emotion Category:** {emotion}

**Coordinates:**
- **Valence:** {valence:.2f} (negative → positive)
- **Arousal:** {arousal:.2f} (calm → excited)

These values are used to generate music that matches the emotional tone of your image.
"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def get_model(conditioning_type):
100
  """Get or initialize model with specified conditioning"""
@@ -149,22 +177,23 @@ def convert_midi_to_wav(midi_path):
149
  print(f"Error converting MIDI to WAV: {str(e)}")
150
  return None
151
 
152
- @spaces.GPU(duration=120) # Set duration to 120 seconds for music generation
153
  def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
154
  """Generate music from input image"""
155
  model = get_model(conditioning_type)
156
  if model is None:
157
- return {
158
- emotion_display: None,
159
- midi_output: None,
160
- results: f"⚠️ Error: Failed to initialize {conditioning_type} model. Please check the logs."
161
- }
 
162
 
163
  try:
164
- # Create output directory with absolute path
165
  output_dir = os.path.join(os.path.dirname(__file__), "output")
166
  os.makedirs(output_dir, exist_ok=True)
167
-
168
  # Generate music
169
  valence, arousal, midi_path = model.generate(
170
  image_path=image,
@@ -176,42 +205,54 @@ def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_in
176
  min_instruments=int(min_instruments)
177
  )
178
 
179
- # Ensure we have the absolute path to the MIDI file
180
- if not os.path.isabs(midi_path):
181
- midi_path = os.path.join(output_dir, midi_path)
182
-
183
- # Convert MIDI to WAV for playback
184
  wav_path = convert_midi_to_wav(midi_path)
185
  if wav_path is None:
186
- return {
187
- emotion_display: None,
188
- midi_output: None,
189
- results: "⚠️ Error: Failed to convert MIDI to WAV for playback"
190
- }
191
 
192
- # Create emotion text display
193
- emotion_text = create_emotion_text(valence, arousal)
194
 
195
- return {
196
- emotion_display: emotion_text,
197
- midi_output: wav_path,
198
- results: f"""
199
- **Model Type:** {conditioning_type}
200
-
201
- **Generation Parameters:**
202
- - Temperature: {temperature}
203
- - Top-p: {top_p}
204
- - Min Instruments: {min_instruments}
205
-
206
- Your music has been generated! Click the play button above to listen.
207
- """
208
- }
 
 
 
 
 
209
  except Exception as e:
210
- return {
211
- emotion_display: None,
212
- midi_output: None,
213
- results: f"⚠️ Error generating music: {str(e)}"
214
- }
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  # Create Gradio interface
217
  with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
@@ -312,11 +353,14 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
312
  )
313
 
314
  with gr.Column(scale=2):
 
 
 
 
315
  midi_output = gr.Audio(
316
  type="filepath",
317
  label="Generated Music"
318
  )
319
- emotion_display = gr.Markdown()
320
  results = gr.Markdown()
321
 
322
  gr.Markdown("""
@@ -350,7 +394,7 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
350
  generate_btn.click(
351
  fn=generate_music_wrapper,
352
  inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
353
- outputs=[emotion_display, midi_output, results]
354
  )
355
 
356
  # Launch app
 
65
  # Global model cache
66
  models = {}
67
 
68
def create_emotion_plot(valence, arousal):
    """Render the valence-arousal circumplex with the predicted emotion point.

    Args:
        valence: Predicted valence, roughly in [-1, 1] (negative → positive).
        arousal: Predicted arousal, roughly in [-1, 1] (calm → excited).

    Returns:
        Path (str) to a PNG file containing the rendered chart.
    """
    import tempfile  # local import: only needed for the output filename below

    # Create the figure object directly so we can close exactly this figure
    # later instead of relying on pyplot's implicit "current figure" state.
    fig = plt.figure(figsize=(8, 8), dpi=100)
    ax = fig.add_subplot(111)

    # Plain white background; use the default style instead of seaborn.
    plt.style.use('default')
    fig.patch.set_facecolor('#ffffff')
    ax.set_facecolor('#ffffff')

    # Light dashed grid plus emphasized x=0 / y=0 axes as the coordinate system.
    ax.grid(True, linestyle='--', alpha=0.2)
    ax.axhline(y=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)
    ax.axvline(x=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)

    # Unit circle marking the nominal valence/arousal region.
    circle = plt.Circle((0, 0), 1, fill=False, color='#666666', alpha=0.3, linewidth=1.5)
    ax.add_artist(circle)

    # Axis-extreme labels placed just outside the circle.
    font = {'family': 'sans-serif', 'weight': 'medium', 'size': 12}
    label_dist = 1.35  # pushed out so the labels don't collide with the circle
    ax.text(label_dist, 0, 'Positive', ha='left', va='center', **font)
    ax.text(-label_dist, 0, 'Negative', ha='right', va='center', **font)
    ax.text(0, label_dist, 'Excited', ha='center', va='bottom', **font)
    ax.text(0, -label_dist, 'Calm', ha='center', va='top', **font)

    # The predicted emotion point.
    ax.scatter([valence], [arousal], c='#4f46e5', s=150, zorder=5, alpha=0.8)

    # Limits and ticks with padding so the outside labels stay visible.
    ax.set_xlim(-1.6, 1.6)
    ax.set_ylim(-1.6, 1.6)
    ticks = [-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.tick_params(axis='both', which='major', labelsize=10)

    # Axis labels with padding.
    ax.set_xlabel('Valence', **font, labelpad=15)
    ax.set_ylabel('Arousal', **font, labelpad=15)

    # Borderless look: hide all four spines.
    for spine in ax.spines.values():
        spine.set_visible(False)

    fig.tight_layout(pad=1.5)

    # Save to a unique file per call: a fixed "emotion_plot.png" name would
    # let concurrent requests overwrite each other's charts before Gradio
    # has served them.
    output_dir = os.path.join(os.path.dirname(__file__), "output")
    os.makedirs(output_dir, exist_ok=True)
    fd, temp_path = tempfile.mkstemp(prefix="emotion_plot_", suffix=".png", dir=output_dir)
    os.close(fd)  # matplotlib reopens the file by path
    fig.savefig(temp_path, bbox_inches='tight', dpi=100)
    plt.close(fig)  # free the figure's memory

    return temp_path
126
 
127
  def get_model(conditioning_type):
128
  """Get or initialize model with specified conditioning"""
 
177
  print(f"Error converting MIDI to WAV: {str(e)}")
178
  return None
179
 
180
+ @spaces.GPU(duration=120)
181
  def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
182
  """Generate music from input image"""
183
  model = get_model(conditioning_type)
184
  if model is None:
185
+ # IMPORTANT: Return a 3-element tuple, not a dictionary
186
+ return (
187
+ None, # For emotion_chart
188
+ None, # For midi_output
189
+ f"⚠️ Error: Failed to initialize {conditioning_type} model. Please check the logs."
190
+ )
191
 
192
  try:
193
+ # Create output directory
194
  output_dir = os.path.join(os.path.dirname(__file__), "output")
195
  os.makedirs(output_dir, exist_ok=True)
196
+
197
  # Generate music
198
  valence, arousal, midi_path = model.generate(
199
  image_path=image,
 
205
  min_instruments=int(min_instruments)
206
  )
207
 
208
+ # Convert MIDI to WAV
 
 
 
 
209
  wav_path = convert_midi_to_wav(midi_path)
210
  if wav_path is None:
211
+ return (
212
+ None,
213
+ None,
214
+ "⚠️ Error: Failed to convert MIDI to WAV for playback"
215
+ )
216
 
217
+ # Create emotion plot
218
+ plot_path = create_emotion_plot(valence, arousal)
219
 
220
+ # Build a nice Markdown result string
221
+ result_text = f"""
222
+ **Model Type:** {conditioning_type}
223
+
224
+ **Predicted Emotions:**
225
+ - Valence: {valence:.3f} (negative → positive)
226
+ - Arousal: {arousal:.3f} (calm → excited)
227
+
228
+ **Generation Parameters:**
229
+ - Temperature: {temperature}
230
+ - Top-p: {top_p}
231
+ - Min Instruments: {min_instruments}
232
+
233
+ Your music has been generated! Click the play button above to listen.
234
+ """
235
+
236
+ # RETURN AS A TUPLE
237
+ return (plot_path, wav_path, result_text)
238
+
239
  except Exception as e:
240
+ return (
241
+ None,
242
+ None,
243
+ f"⚠️ Error generating music: {str(e)}"
244
+ )
245
+
246
def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):
    """Adapt the UI's two temperature sliders to generate_music's single argument.

    The Gradio interface exposes separate note and rest temperature sliders,
    while generate_music expects them packed as a two-element list; this
    wrapper performs that conversion and forwards everything else unchanged.
    """
    return generate_music(
        image=image,
        conditioning_type=conditioning_type,
        gen_len=gen_len,
        temperature=[float(note_temp), float(rest_temp)],
        top_p=top_p,
        min_instruments=min_instruments,
    )
256
 
257
  # Create Gradio interface
258
  with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
 
353
  )
354
 
355
  with gr.Column(scale=2):
356
+ emotion_chart = gr.Image(
357
+ label="Predicted Emotions",
358
+ type="filepath"
359
+ )
360
  midi_output = gr.Audio(
361
  type="filepath",
362
  label="Generated Music"
363
  )
 
364
  results = gr.Markdown()
365
 
366
  gr.Markdown("""
 
394
  generate_btn.click(
395
  fn=generate_music_wrapper,
396
  inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
397
+ outputs=[emotion_chart, midi_output, results]
398
  )
399
 
400
  # Launch app