vincentamato committed
Commit 8629f1c · 1 Parent(s): 8429ccf

Fixed chart rendering

Files changed (3):
  1. app.py +17 -7
  2. aria/aria.py +2 -2
  3. aria/generate.py +0 -61
app.py CHANGED

@@ -65,6 +65,9 @@ models = {}
 
 def create_emotion_plot(valence, arousal):
     """Create a valence-arousal plot with the predicted emotion point"""
+    # Create figure in a process-safe way
+    plt.switch_backend('Agg')
+
     fig = plt.figure(figsize=(8, 8), dpi=100)
     ax = fig.add_subplot(111)
 
@@ -113,7 +116,13 @@ def create_emotion_plot(valence, arousal):
     # Adjust layout with more padding
     plt.tight_layout(pad=1.5)
 
-    return fig
+    # Save to a temporary file and return the path
+    temp_path = os.path.join(os.path.dirname(__file__), "output", "emotion_plot.png")
+    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
+    plt.savefig(temp_path, bbox_inches='tight', dpi=100)
+    plt.close(fig)  # Close the figure to free memory
+
+    return temp_path
 
 def get_model(conditioning_type):
     """Get or initialize model with specified conditioning"""
@@ -168,7 +177,7 @@ def convert_midi_to_wav(midi_path):
         print(f"Error converting MIDI to WAV: {str(e)}")
         return None
 
-@spaces.GPU(duration=120)  # Set duration to 60 seconds for music generation
+@spaces.GPU(duration=120)  # Set duration to 120 seconds for music generation
 def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
     """Generate music from input image"""
     model = get_model(conditioning_type)
@@ -208,11 +217,11 @@ def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
             results: "⚠️ Error: Failed to convert MIDI to WAV for playback"
         }
 
-    # Create emotion plot
-    emotion_fig = create_emotion_plot(valence, arousal)
+    # Create emotion plot and get its path
+    plot_path = create_emotion_plot(valence, arousal)
 
     return {
-        emotion_chart: emotion_fig,
+        emotion_chart: plot_path,
         midi_output: wav_path,
         results: f"""
             **Model Type:** {conditioning_type}
@@ -335,8 +344,9 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
     )
 
     with gr.Column(scale=2):
-        emotion_chart = gr.Plot(
-            label="Predicted Emotions"
+        emotion_chart = gr.Image(
+            label="Predicted Emotions",
+            type="filepath"
         )
         midi_output = gr.Audio(
             type="filepath",
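The move from gr.Plot to gr.Image(type="filepath") means no live Matplotlib figure has to be serialized out of the @spaces.GPU worker process: the chart is rendered to a PNG on disk and only the file path crosses the process boundary, which is what fixes the chart rendering. A minimal, self-contained sketch of the same pattern (the make_plot name, figure styling, and demo values are illustrative, not taken from the commit):

import os

import gradio as gr
import matplotlib
matplotlib.use("Agg")  # headless backend: render without a display server
import matplotlib.pyplot as plt

def make_plot(valence: float, arousal: float) -> str:
    """Render a bare-bones valence-arousal chart and return its PNG path."""
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.axhline(0, color="gray", linewidth=0.5)
    ax.axvline(0, color="gray", linewidth=0.5)
    ax.scatter([valence], [arousal], color="red")
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_xlabel("Valence")
    ax.set_ylabel("Arousal")

    out_path = os.path.join("output", "emotion_plot.png")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)  # free the figure; only the path is handed to Gradio
    return out_path

with gr.Blocks() as demo:
    chart = gr.Image(label="Predicted Emotions", type="filepath")
    gr.Button("Plot").click(lambda: make_plot(0.6, -0.2), outputs=chart)

demo.launch()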
aria/aria.py CHANGED

@@ -15,7 +15,7 @@ sys.path.append(MIDI_EMOTION_PATH)
 class ARIA:
     """ARIA model that generates music from images based on emotional content."""
 
-    @spaces.GPU(duration=20)  # Model loading should be quick
+    @spaces.GPU(duration=10)  # Model loading should be quick
     def __init__(
         self,
         image_model_checkpoint: str,
@@ -60,7 +60,7 @@ class ARIA:
         self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
         self.midi_model.eval()
 
-    @spaces.GPU(duration=120)
+    @spaces.GPU(duration=60)
     @torch.inference_mode()  # More efficient than no_grad for inference
     def generate(
         self,
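Both decorator tweaks shrink the GPU lease requested from ZeroGPU: on Spaces, duration tells the scheduler how long the hardware may be held per call, and shorter requests tend to queue faster, so the values here are being tuned down toward actual runtimes. A hedged sketch of the pattern (run_model and its arguments are invented for illustration, not part of this repo):

import spaces  # Hugging Face ZeroGPU helper, present on Spaces
import torch

@spaces.GPU(duration=60)  # hold the GPU for at most ~60 s per call
@torch.inference_mode()   # stronger than no_grad: also skips view/version-counter tracking
def run_model(model: torch.nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    # Move inputs to CUDA only inside the decorated function, where
    # ZeroGPU guarantees a device is actually attached.
    return model(tokens.to("cuda"))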
aria/generate.py DELETED

@@ -1,61 +0,0 @@
-import argparse
-from src.models.aria.aria import ARIA
-
-def main():
-    parser = argparse.ArgumentParser(description="Generate music from images based on emotional content")
-
-    parser.add_argument("--image", type=str, required=True,
-                        help="Path to input image")
-    parser.add_argument("--image_model_checkpoint", type=str, required=True,
-                        help="Path to image emotion model checkpoint")
-    parser.add_argument("--midi_model_dir", type=str, required=True,
-                        help="Path to midi emotion model directory")
-    parser.add_argument("--out_dir", type=str, default="output",
-                        help="Directory to save generated MIDI")
-    parser.add_argument("--gen_len", type=int, default=512,
-                        help="Length of generation in tokens")
-    parser.add_argument("--temperature", type=float, nargs=2, default=[1.2, 1.2],
-                        help="Temperature for sampling [note_temp, rest_temp]")
-    parser.add_argument("--top_k", type=int, default=-1,
-                        help="Top-k sampling (-1 to disable)")
-    parser.add_argument("--top_p", type=float, default=0.7,
-                        help="Top-p sampling threshold")
-    parser.add_argument("--min_instruments", type=int, default=1,
-                        help="Minimum number of instruments required")
-    parser.add_argument("--cpu", action="store_true",
-                        help="Force CPU inference")
-    parser.add_argument("--conditioning", type=str, required=True,
-                        choices=["none", "discrete_token", "continuous_token", "continuous_concat"],
-                        help="Type of conditioning to use")
-    parser.add_argument("--batch_size", type=int, default=1,
-                        help="Number of samples to generate (not used for image input)")
-
-    args = parser.parse_args()
-
-    # Initialize model
-    model = ARIA(
-        image_model_checkpoint=args.image_model_checkpoint,
-        midi_model_dir=args.midi_model_dir,
-        conditioning=args.conditioning,
-        device="cpu" if args.cpu else None
-    )
-
-    # Generate music
-    valence, arousal, midi_path = model.generate(
-        image_path=args.image,
-        out_dir=args.out_dir,
-        gen_len=args.gen_len,
-        temperature=args.temperature,
-        top_k=args.top_k,
-        top_p=args.top_p,
-        min_instruments=args.min_instruments
-    )
-
-    # Print results
-    print(f"\nPredicted emotions:")
-    print(f"Valence: {valence:.3f} (negative -> positive)")
-    print(f"Arousal: {arousal:.3f} (calm -> excited)")
-    print(f"\nGenerated MIDI saved to: {midi_path}")
-
-if __name__ == "__main__":
-    main()
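The deleted file was a thin argparse wrapper around the ARIA class; now that the Gradio app drives generation, the same flow is a direct call. A hedged equivalent for reference (the aria.aria import path and the checkpoint locations are assumptions; the deleted script imported from a src.models.aria layout that this repo does not ship):

from aria.aria import ARIA  # assumed import path within this repo

model = ARIA(
    image_model_checkpoint="checkpoints/image_model.ckpt",  # placeholder path
    midi_model_dir="checkpoints/midi_emotion",              # placeholder path
    conditioning="continuous_concat",  # one of the four supported modes
    device=None,                       # auto-select, as the deleted script did
)

valence, arousal, midi_path = model.generate(
    image_path="example.jpg",  # placeholder input image
    out_dir="output",
    gen_len=512,
    temperature=[1.2, 1.2],    # [note_temp, rest_temp]
    top_k=-1,                  # -1 disables top-k
    top_p=0.7,
    min_instruments=1,
)

print(f"Valence: {valence:.3f} (negative -> positive)")
print(f"Arousal: {arousal:.3f} (calm -> excited)")
print(f"Generated MIDI saved to: {midi_path}")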