vincentamato committed
Commit 69defc9 · 1 Parent(s): e15e4d5

Initial commit
.gitignore ADDED
@@ -0,0 +1,41 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ env/
26
+ ENV/
27
+
28
+ # IDE
29
+ .idea/
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+
34
+ # Generated files
35
+ output/
36
+ model_cache/
37
+ *.wav
38
+ *.mid
39
+
40
+ # Example files are tracked normally (no LFS needed)
41
+ !examples/
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: ARIA
- emoji: 🦀
+ title: Aria
+ emoji: 📉
  colorFrom: indigo
- colorTo: gray
+ colorTo: red
  sdk: gradio
  sdk_version: 5.13.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,376 @@
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
+ import pretty_midi
10
+ import librosa
11
+ import soundfile as sf
12
+ from midi2audio import FluidSynth
13
+
14
+ from aria.image_encoder import ImageEncoder
15
+ from aria.aria import ARIA
16
+
17
+ print("Checking model files...")
18
+ # Pre-download all model files at startup
19
+ MODEL_FILES = {
20
+ "image_encoder": "image_encoder.pt",
21
+ "continuous_concat": ["continuous_concat/model.pt", "continuous_concat/mappings.pt", "continuous_concat/model_config.pt"],
22
+ "continuous_token": ["continuous_token/model.pt", "continuous_token/mappings.pt", "continuous_token/model_config.pt"],
23
+ "discrete_token": ["discrete_token/model.pt", "discrete_token/mappings.pt", "discrete_token/model_config.pt"]
24
+ }
25
+
26
+ # Create cache directory
27
+ CACHE_DIR = os.path.join(os.path.dirname(__file__), "model_cache")
28
+ os.makedirs(CACHE_DIR, exist_ok=True)
29
+
30
+ # Download and cache all files
31
+ cached_files = {}
32
+ for model_type, files in MODEL_FILES.items():
33
+ if isinstance(files, str):
34
+ files = [files]
35
+
36
+ cached_files[model_type] = []
37
+ for file in files:
38
+ try:
39
+ # Check if file already exists in cache
40
+ repo_id = "vincentamato/aria"
41
+ cached_path = os.path.join(CACHE_DIR, repo_id, file)
42
+ if os.path.exists(cached_path):
43
+ print(f"Using cached file: {file}")
44
+ cached_files[model_type].append(cached_path)
45
+ else:
46
+ print(f"Downloading file: {file}")
47
+ cached_path = hf_hub_download(
48
+ repo_id=repo_id,
49
+ filename=file,
50
+ cache_dir=CACHE_DIR
51
+ )
52
+ cached_files[model_type].append(cached_path)
53
+ except Exception as e:
54
+ print(f"Error with file {file}: {str(e)}")
55
+
56
+ print("Model files ready.")
57
+
58
+ # Global model cache
59
+ models = {}
60
+
61
+ def create_emotion_plot(valence, arousal):
62
+ """Create a valence-arousal plot with the predicted emotion point"""
63
+ fig = plt.figure(figsize=(8, 8), dpi=100)
64
+ ax = fig.add_subplot(111)
65
+
66
+ # Set background color and style
67
+ plt.style.use('default') # Use default style instead of seaborn
68
+ fig.patch.set_facecolor('#ffffff')
69
+ ax.set_facecolor('#ffffff')
70
+
71
+ # Create the coordinate system with a light grid
72
+ ax.grid(True, linestyle='--', alpha=0.2)
73
+ ax.axhline(y=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)
74
+ ax.axvline(x=0, color='#666666', linestyle='-', alpha=0.3, linewidth=1)
75
+
76
+ # Plot region
77
+ circle = plt.Circle((0, 0), 1, fill=False, color='#666666', alpha=0.3, linewidth=1.5)
78
+ ax.add_artist(circle)
79
+
80
+ # Add labels with nice fonts
81
+ font = {'family': 'sans-serif', 'weight': 'medium', 'size': 12}
82
+ label_dist = 1.35 # Increased distance for labels
83
+ ax.text(label_dist, 0, 'Positive', ha='left', va='center', **font)
84
+ ax.text(-label_dist, 0, 'Negative', ha='right', va='center', **font)
85
+ ax.text(0, label_dist, 'Excited', ha='center', va='bottom', **font)
86
+ ax.text(0, -label_dist, 'Calm', ha='center', va='top', **font)
87
+
88
+ # Plot the point with a nice style
89
+ ax.scatter([valence], [arousal], c='#4f46e5', s=150, zorder=5, alpha=0.8)
90
+
91
+ # Set limits and labels with more padding
92
+ ax.set_xlim(-1.6, 1.6)
93
+ ax.set_ylim(-1.6, 1.6)
94
+
95
+ # Format ticks
96
+ ax.set_xticks([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
97
+ ax.set_yticks([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
98
+ ax.tick_params(axis='both', which='major', labelsize=10)
99
+
100
+ # Add axis labels with padding
101
+ ax.set_xlabel('Valence', **font, labelpad=15)
102
+ ax.set_ylabel('Arousal', **font, labelpad=15)
103
+
104
+ # Remove spines
105
+ for spine in ax.spines.values():
106
+ spine.set_visible(False)
107
+
108
+ # Adjust layout with more padding
109
+ plt.tight_layout(pad=1.5)
110
+
111
+ return fig
112
+
113
+ def get_model(conditioning_type):
114
+ """Get or initialize model with specified conditioning"""
115
+ if conditioning_type not in models:
116
+ try:
117
+ # Use cached files
118
+ image_model_path = cached_files["image_encoder"][0]
119
+ midi_model_dir = os.path.dirname(cached_files[conditioning_type][0])
120
+
121
+ models[conditioning_type] = ARIA(
122
+ image_model_checkpoint=image_model_path,
123
+ midi_model_dir=midi_model_dir,
124
+ conditioning=conditioning_type
125
+ )
126
+ except Exception as e:
127
+ print(f"Error initializing {conditioning_type} model: {str(e)}")
128
+ return None
129
+ return models[conditioning_type]
130
+
131
+ def convert_midi_to_wav(midi_path):
132
+ """Convert MIDI file to WAV using FluidSynth"""
133
+ wav_path = midi_path.replace('.mid', '.wav')
134
+
135
+ # If WAV file already exists and is newer than MIDI file, use cached version
136
+ if os.path.exists(wav_path) and os.path.getmtime(wav_path) > os.path.getmtime(midi_path):
137
+ return wav_path
138
+
139
+ try:
140
+ # Check common soundfont locations
141
+ soundfont_paths = [
142
+ '/usr/share/sounds/sf2/FluidR3_GM.sf2', # Linux
143
+ '/usr/share/soundfonts/default.sf2', # Linux alternative
144
+ '/usr/local/share/fluidsynth/generaluser.sf2', # macOS
145
+ 'C:\\soundfonts\\generaluser.sf2' # Windows
146
+ ]
147
+
148
+ soundfont = None
149
+ for sf_path in soundfont_paths:
150
+ if os.path.exists(sf_path):
151
+ soundfont = sf_path
152
+ break
153
+
154
+ if soundfont is None:
155
+ raise RuntimeError("No SoundFont file found. Please install fluid-soundfont-gm package.")
156
+
157
+ # Convert MIDI to WAV using FluidSynth with explicit soundfont
158
+ fs = FluidSynth(sound_font=soundfont)
159
+ fs.midi_to_audio(midi_path, wav_path)
160
+
161
+ return wav_path
162
+ except Exception as e:
163
+ print(f"Error converting MIDI to WAV: {str(e)}")
164
+ return None
165
+
166
+ def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
167
+ """Generate music from input image"""
168
+ model = get_model(conditioning_type)
169
+ if model is None:
170
+ return {
171
+ emotion_chart: None,
172
+ midi_output: None,
173
+ results: f"⚠️ Error: Failed to initialize {conditioning_type} model. Please check the logs."
174
+ }
175
+
176
+ try:
177
+ # Create output directory with absolute path
178
+ output_dir = os.path.join(os.path.dirname(__file__), "output")
179
+ os.makedirs(output_dir, exist_ok=True)
180
+
181
+ # Generate music
182
+ valence, arousal, midi_path = model.generate(
183
+ image_path=image,
184
+ out_dir=output_dir,
185
+ gen_len=gen_len,
186
+ temperature=temperature,
187
+ top_k=-1,
188
+ top_p=float(top_p),
189
+ min_instruments=int(min_instruments)
190
+ )
191
+
192
+ # Ensure we have the absolute path to the MIDI file
193
+ if not os.path.isabs(midi_path):
194
+ midi_path = os.path.join(output_dir, midi_path)
195
+
196
+ # Convert MIDI to WAV for playback
197
+ wav_path = convert_midi_to_wav(midi_path)
198
+ if wav_path is None:
199
+ return {
200
+ emotion_chart: None,
201
+ midi_output: None,
202
+ results: "⚠️ Error: Failed to convert MIDI to WAV for playback"
203
+ }
204
+
205
+ # Create emotion plot
206
+ emotion_fig = create_emotion_plot(valence, arousal)
207
+
208
+ return {
209
+ emotion_chart: emotion_fig,
210
+ midi_output: wav_path,
211
+ results: f"""
212
+ **Model Type:** {conditioning_type}
213
+
214
+ **Predicted Emotions:**
215
+ - Valence: {valence:.3f} (negative → positive)
216
+ - Arousal: {arousal:.3f} (calm → excited)
217
+
218
+ **Generation Parameters:**
219
+ - Temperature: {temperature}
220
+ - Top-p: {top_p}
221
+ - Min Instruments: {min_instruments}
222
+
223
+ Your music has been generated! Click the play button above to listen.
224
+ """
225
+ }
226
+ except Exception as e:
227
+ return {
228
+ emotion_chart: None,
229
+ midi_output: None,
230
+ results: f"⚠️ Error generating music: {str(e)}"
231
+ }
232
+
233
+ # Create Gradio interface
234
+ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
235
+ primary_hue="indigo",
236
+ secondary_hue="slate",
237
+ font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
238
+ )) as demo:
239
+ gr.Markdown("""
240
+ # 🎨 ARIA: Artistic Rendering of Images into Audio
241
+
242
+ Upload an image and ARIA will analyze its emotional content to generate matching music!
243
+
244
+ ### How it works:
245
+ 1. ARIA first analyzes the emotional content of your image along two dimensions:
246
+ - **Valence**: How positive or negative the emotion is (-1 to 1)
247
+ - **Arousal**: How calm or excited the emotion is (-1 to 1)
248
+ 2. These emotions are then used to generate music that matches the mood
249
+ """)
250
+
251
+ with gr.Row():
252
+ with gr.Column(scale=3):
253
+ image_input = gr.Image(
254
+ type="filepath",
255
+ label="Upload Image"
256
+ )
257
+
258
+ with gr.Group():
259
+ gr.Markdown("### Generation Settings")
260
+
261
+ with gr.Row():
262
+ with gr.Column():
263
+ conditioning_type = gr.Radio(
264
+ choices=["continuous_concat", "continuous_token", "discrete_token"],
265
+ value="continuous_concat",
266
+ label="Conditioning Type",
267
+ info="How the emotional information is incorporated into the music generation"
268
+ )
269
+ with gr.Column():
270
+ gen_len = gr.Slider(
271
+ minimum=256,
272
+ maximum=4096,
273
+ value=1024,
274
+ step=256,
275
+ label="Generation Length",
276
+ info="Number of tokens to generate (longer = more music)"
277
+ )
278
+
279
+ with gr.Row():
280
+ with gr.Column():
281
+ note_temperature = gr.Slider(
282
+ minimum=0.1,
283
+ maximum=2.0,
284
+ value=1.2,
285
+ step=0.1,
286
+ label="Note Temperature",
287
+ info="Controls randomness of note generation"
288
+ )
289
+ with gr.Column():
290
+ rest_temperature = gr.Slider(
291
+ minimum=0.1,
292
+ maximum=2.0,
293
+ value=1.2,
294
+ step=0.1,
295
+ label="Rest Temperature",
296
+ info="Controls randomness of rest/timing generation"
297
+ )
298
+
299
+ with gr.Row():
300
+ with gr.Column():
301
+ top_p = gr.Slider(
302
+ minimum=0.1,
303
+ maximum=1.0,
304
+ value=0.6,
305
+ step=0.1,
306
+ label="Top-p Sampling",
307
+ info="Nucleus sampling threshold - lower = more focused"
308
+ )
309
+ with gr.Column():
310
+ min_instruments = gr.Slider(
311
+ minimum=1,
312
+ maximum=5,
313
+ value=2,
314
+ step=1,
315
+ label="Minimum Instruments",
316
+ info="Minimum number of instruments in the generated music"
317
+ )
318
+
319
+ generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
320
+
321
+ # Add examples
322
+ gr.Examples(
323
+ examples=[
324
+ ["examples/happy.jpg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
325
+ ["examples/sad.jpeg", "continuous_token", 1024, 1.2, 1.2, 0.6, 2],
326
+ ],
327
+ inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
328
+ label="Try these examples"
329
+ )
330
+
331
+ with gr.Column(scale=2):
332
+ emotion_chart = gr.Plot(
333
+ label="Predicted Emotions"
334
+ )
335
+ midi_output = gr.Audio(
336
+ type="filepath",
337
+ label="Generated Music"
338
+ )
339
+ results = gr.Markdown()
340
+
341
+ gr.Markdown("""
342
+ ### About ARIA
343
+
344
+ ARIA is a deep learning system that generates music from artwork by:
345
+ 1. Using an image emotion model to extract emotional content from images
346
+ 2. Generating matching music using an emotion-conditioned music generation model
347
+
348
+ The emotion-conditioned MIDI generation model is based on the work by Serkan Sulun et al. in their paper
349
+ ["Symbolic music generation conditioned on continuous-valued emotions"](https://ieeexplore.ieee.org/document/9762257).
350
+ Original implementation: [github.com/serkansulun/midi-emotion](https://github.com/serkansulun/midi-emotion)
351
+
352
+ ### Conditioning Types
353
+ - **continuous_concat**: Emotions are concatenated with music features (recommended)
354
+ - **continuous_token**: Emotions are added as special tokens
355
+ - **discrete_token**: Emotions are discretized into tokens
356
+ """)
357
+
358
+ def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):
359
+ """Wrapper for generate_music that handles separate temperatures"""
360
+ return generate_music(
361
+ image=image,
362
+ conditioning_type=conditioning_type,
363
+ gen_len=gen_len,
364
+ temperature=[float(note_temp), float(rest_temp)],
365
+ top_p=top_p,
366
+ min_instruments=min_instruments
367
+ )
368
+
369
+ generate_btn.click(
370
+ fn=generate_music_wrapper,
371
+ inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
372
+ outputs=[emotion_chart, midi_output, results]
373
+ )
374
+
375
+ # Launch app
376
+ demo.launch(share=True)
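
Note that audio playback depends on FluidSynth and a General MIDI SoundFont being present at one of the paths checked in `convert_midi_to_wav` above. For reference, a minimal standalone sketch of the same conversion using the midi2audio API already imported in app.py; the SoundFont and file paths are placeholders, not files shipped with this repository:

```python
# Minimal sketch: convert a generated MIDI file to WAV, as app.py does.
# The .sf2 path is a placeholder; install a GM SoundFont (e.g. the
# fluid-soundfont-gm package) and point sound_font at its location.
from midi2audio import FluidSynth

fs = FluidSynth(sound_font="/usr/share/sounds/sf2/FluidR3_GM.sf2")
fs.midi_to_audio("output/example.mid", "output/example.wav")
```
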
aria/aria.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import os
3
+ from PIL import Image
4
+ import numpy as np
5
+ import datetime
6
+
7
+ from .image_encoder import ImageEncoder
8
+
9
+ # Add MIDI emotion model path to Python path
10
+ import sys
11
+ MIDI_EMOTION_PATH = os.path.join(os.path.dirname(__file__), "..", "midi_emotion", "src")
12
+ sys.path.append(MIDI_EMOTION_PATH)
13
+
14
+ class ARIA:
15
+ """ARIA model that generates music from images based on emotional content."""
16
+
17
+ def __init__(
18
+ self,
19
+ image_model_checkpoint: str,
20
+ midi_model_dir: str,
21
+ conditioning: str = "continuous_concat",
22
+ device: str = None
23
+ ):
24
+ """Initialize ARIA model.
25
+
26
+ Args:
27
+ image_model_checkpoint: Path to image emotion model checkpoint
28
+ midi_model_dir: Path to midi emotion model directory
29
+ conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
30
+ device: Device to run on (default: auto-detect)
31
+ """
32
+ self.device = torch.device("cuda" if torch.cuda.is_available() and not device == "cpu" else "cpu")
33
+ self.conditioning = conditioning
34
+
35
+ # Load image emotion model
36
+ self.image_model = ImageEncoder()
37
+ checkpoint = torch.load(image_model_checkpoint, map_location=self.device, weights_only=True)
38
+ self.image_model.load_state_dict(checkpoint["model_state_dict"])
39
+ self.image_model.eval()
40
+
41
+ # Import midi generation
42
+ from midi_emotion.src.generate import generate
43
+ from midi_emotion.src.models.build_model import build_model
44
+ self.generate_midi = generate
45
+
46
+ # Load midi model
47
+ model_fp = os.path.join(midi_model_dir, 'model.pt')
48
+ mappings_fp = os.path.join(midi_model_dir, 'mappings.pt')
49
+ config_fp = os.path.join(midi_model_dir, 'model_config.pt')
50
+
51
+ self.maps = torch.load(mappings_fp, weights_only=True)
52
+ config = torch.load(config_fp, weights_only=True)
53
+ self.midi_model, _ = build_model(None, load_config_dict=config)
54
+ self.midi_model = self.midi_model.to(self.device)
55
+ self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
56
+ self.midi_model.eval()
57
+
58
+ def generate(
59
+ self,
60
+ image_path: str,
61
+ out_dir: str = "output",
62
+ gen_len: int = 2048,
63
+ temperature: list = [1.2, 1.2],
64
+ top_k: int = -1,
65
+ top_p: float = 0.7,
66
+ min_instruments: int = 2
67
+ ) -> tuple[float, float, str]:
68
+ """Generate music from an image.
69
+
70
+ Args:
71
+ image_path: Path to input image
72
+ out_dir: Directory to save generated MIDI
73
+ gen_len: Length of generation in tokens
74
+ temperature: Temperature for sampling [note_temp, rest_temp]
75
+ top_k: Top-k sampling (-1 to disable)
76
+ top_p: Top-p sampling threshold
77
+ min_instruments: Minimum number of instruments required
78
+
79
+ Returns:
80
+ Tuple of (valence, arousal, midi_path)
81
+ """
82
+ # Get emotion from image
83
+ image = Image.open(image_path).convert("RGB")
84
+ with torch.no_grad():
85
+ valence, arousal = self.image_model(image)
86
+ valence = valence.squeeze().cpu().item()
87
+ arousal = arousal.squeeze().cpu().item()
88
+
89
+ # Create output directory
90
+ os.makedirs(out_dir, exist_ok=True)
91
+
92
+ # Generate MIDI
93
+ continuous_conditions = np.array([[valence, arousal]], dtype=np.float32)
94
+
95
+ # Generate timestamp for filename (for reference)
96
+ now = datetime.datetime.now()
97
+ timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
98
+
99
+ # Generate the MIDI
100
+ self.generate_midi(
101
+ model=self.midi_model,
102
+ maps=self.maps,
103
+ device=self.device,
104
+ out_dir=out_dir,
105
+ conditioning=self.conditioning,
106
+ continuous_conditions=continuous_conditions,
107
+ gen_len=gen_len,
108
+ temperatures=temperature,
109
+ top_k=top_k,
110
+ top_p=top_p,
111
+ min_n_instruments=min_instruments
112
+ )
113
+
114
+ # Find the most recently generated MIDI file
115
+ midi_files = [f for f in os.listdir(out_dir) if f.endswith('.mid')]
116
+ if midi_files:
117
+ # Sort by creation time and get most recent
118
+ midi_path = os.path.join(out_dir, max(midi_files, key=lambda f: os.path.getctime(os.path.join(out_dir, f))))
119
+ return valence, arousal, midi_path
120
+
121
+ raise RuntimeError("Failed to generate MIDI file")
aria/generate.py ADDED
@@ -0,0 +1,61 @@
1
+ import argparse
2
+ from aria.aria import ARIA
3
+
4
+ def main():
5
+ parser = argparse.ArgumentParser(description="Generate music from images based on emotional content")
6
+
7
+ parser.add_argument("--image", type=str, required=True,
8
+ help="Path to input image")
9
+ parser.add_argument("--image_model_checkpoint", type=str, required=True,
10
+ help="Path to image emotion model checkpoint")
11
+ parser.add_argument("--midi_model_dir", type=str, required=True,
12
+ help="Path to midi emotion model directory")
13
+ parser.add_argument("--out_dir", type=str, default="output",
14
+ help="Directory to save generated MIDI")
15
+ parser.add_argument("--gen_len", type=int, default=512,
16
+ help="Length of generation in tokens")
17
+ parser.add_argument("--temperature", type=float, nargs=2, default=[1.2, 1.2],
18
+ help="Temperature for sampling [note_temp, rest_temp]")
19
+ parser.add_argument("--top_k", type=int, default=-1,
20
+ help="Top-k sampling (-1 to disable)")
21
+ parser.add_argument("--top_p", type=float, default=0.7,
22
+ help="Top-p sampling threshold")
23
+ parser.add_argument("--min_instruments", type=int, default=1,
24
+ help="Minimum number of instruments required")
25
+ parser.add_argument("--cpu", action="store_true",
26
+ help="Force CPU inference")
27
+ parser.add_argument("--conditioning", type=str, required=True,
28
+ choices=["none", "discrete_token", "continuous_token", "continuous_concat"],
29
+ help="Type of conditioning to use")
30
+ parser.add_argument("--batch_size", type=int, default=1,
31
+ help="Number of samples to generate (not used for image input)")
32
+
33
+ args = parser.parse_args()
34
+
35
+ # Initialize model
36
+ model = ARIA(
37
+ image_model_checkpoint=args.image_model_checkpoint,
38
+ midi_model_dir=args.midi_model_dir,
39
+ conditioning=args.conditioning,
40
+ device="cpu" if args.cpu else None
41
+ )
42
+
43
+ # Generate music
44
+ valence, arousal, midi_path = model.generate(
45
+ image_path=args.image,
46
+ out_dir=args.out_dir,
47
+ gen_len=args.gen_len,
48
+ temperature=args.temperature,
49
+ top_k=args.top_k,
50
+ top_p=args.top_p,
51
+ min_instruments=args.min_instruments
52
+ )
53
+
54
+ # Print results
55
+ print(f"\nPredicted emotions:")
56
+ print(f"Valence: {valence:.3f} (negative -> positive)")
57
+ print(f"Arousal: {arousal:.3f} (calm -> excited)")
58
+ print(f"\nGenerated MIDI saved to: {midi_path}")
59
+
60
+ if __name__ == "__main__":
61
+ main()
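
For reference, an example invocation of this CLI with placeholder checkpoint paths, assuming it is run from the repository root as a module so the aria package resolves: `python -m aria.generate --image examples/happy.jpg --image_model_checkpoint model_cache/image_encoder.pt --midi_model_dir model_cache/continuous_concat --conditioning continuous_concat --gen_len 1024`.
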
aria/image_encoder.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from PIL import Image
5
+ from typing import Tuple, Union
6
+
7
+ class ImageEncoder(nn.Module):
8
+ def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
9
+ """Initialize the image encoder using CLIP.
10
+
11
+ Args:
12
+ clip_model_name: HuggingFace model name for CLIP
13
+ """
14
+ super().__init__()
15
+
16
+ # Load CLIP model and processor
17
+ self.clip_model = CLIPModel.from_pretrained(clip_model_name)
18
+ self.processor = CLIPProcessor.from_pretrained(clip_model_name)
19
+
20
+ # Freeze CLIP parameters
21
+ for param in self.clip_model.parameters():
22
+ param.requires_grad = False
23
+
24
+ # Add projection layers for valence and arousal
25
+ hidden_dim = self.clip_model.config.projection_dim
26
+ projection_dim = hidden_dim // 2
27
+
28
+ self.valence_head = nn.Sequential(
29
+ nn.Linear(hidden_dim, projection_dim),
30
+ nn.ReLU(),
31
+ nn.Dropout(0.1),
32
+ nn.Linear(projection_dim, projection_dim // 2),
33
+ nn.ReLU(),
34
+ nn.Dropout(0.1),
35
+ nn.Linear(projection_dim // 2, 1),
36
+ nn.Tanh() # Output between -1 and 1
37
+ )
38
+
39
+ self.arousal_head = nn.Sequential(
40
+ nn.Linear(hidden_dim, projection_dim),
41
+ nn.ReLU(),
42
+ nn.Dropout(0.1),
43
+ nn.Linear(projection_dim, projection_dim // 2),
44
+ nn.ReLU(),
45
+ nn.Dropout(0.1),
46
+ nn.Linear(projection_dim // 2, 1),
47
+ nn.Tanh() # Output between -1 and 1
48
+ )
49
+
50
+ # Move model to GPU if available
51
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+ self.to(self.device)
53
+
54
+ def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
55
+ """Forward pass to get valence and arousal predictions.
56
+
57
+ Args:
58
+ images: Either PIL images or tensors in CLIP format
59
+
60
+ Returns:
61
+ Tuple of predicted valence and arousal scores
62
+ """
63
+ # Process images if they're PIL images
64
+ if isinstance(images, Image.Image):
65
+ inputs = self.processor(images=images, return_tensors="pt")
66
+ pixel_values = inputs.pixel_values.to(self.device)
67
+ else:
68
+ pixel_values = images.to(self.device)
69
+
70
+ # Get CLIP image features
71
+ image_features = self.clip_model.get_image_features(pixel_values)
72
+
73
+ # Project to valence and arousal scores
74
+ valence = self.valence_head(image_features)
75
+ arousal = self.arousal_head(image_features)
76
+
77
+ return valence, arousal
78
+
79
+ def encode_image(self, image: Image.Image) -> torch.Tensor:
80
+ """Get the raw CLIP image embeddings.
81
+
82
+ Args:
83
+ image: PIL image to encode
84
+
85
+ Returns:
86
+ Image embedding tensor
87
+ """
88
+ inputs = self.processor(images=image, return_tensors="pt")
89
+ with torch.no_grad():
90
+ image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))
91
+ return image_features
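
A minimal sketch of running ImageEncoder on its own, assuming a trained checkpoint saved with a "model_state_dict" key (the same format aria/aria.py loads); the checkpoint and image paths are placeholders:

```python
# Sketch only: checkpoint and image paths are placeholders.
import torch
from PIL import Image
from aria.image_encoder import ImageEncoder

encoder = ImageEncoder()  # downloads the CLIP backbone on first use
state = torch.load("model_cache/image_encoder.pt",
                   map_location=encoder.device, weights_only=True)
encoder.load_state_dict(state["model_state_dict"])
encoder.eval()

image = Image.open("examples/happy.jpg").convert("RGB")
with torch.no_grad():
    valence, arousal = encoder(image)  # each in [-1, 1] via the tanh heads
print(valence.item(), arousal.item())
```
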
examples/happy.jpg ADDED
examples/sad.jpeg ADDED
midi_emotion/.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ __pycache__
2
+ .vscode
3
+ data_files/*
4
+ output/*
5
+ !.gitkeep
6
+ .cache
midi_emotion/LICENSE.md ADDED
@@ -0,0 +1,653 @@
1
+ Copyright © 2022 INESC TEC
2
+
3
+ Emotion-based MIDI generator: Uses deep neural networks to create symbolic music (MIDI) based on user-defined emotions from the valence-arousal plane.
4
+
5
+ This software is authored by:
6
+ Serkan Sulun
7
+
8
+
9
+ This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
10
+ This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
11
+ You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>.
12
+ A commercial license is also available for use in industrial projects and collaborations that do not wish to use the GPL v3 license.
13
+ To obtain the commercial license please contact the INESC TEC Technology Licensing Office (TLO) at [email protected], or
14
+ Campus da Faculdade de Engenharia da Universidade do Porto
15
+ Rua Dr. Roberto Frias
16
+ 4200-465 Porto
17
+ Portugal
18
+
19
+ If needed SAL (INESC TEC Technology Licensing Office - TLO) can assist with all the legal details regarding the licensing agreement
20
+
21
+ If you use Emotion-based MIDI generator in a work that leads to a scientific publication, we would appreciate it if you would kindly cite Emotion-based MIDI generator in your manuscript.
22
+
23
+ S. Sulun, M. E. P. Davies and P. Viana, "Symbolic Music Generation Conditioned on Continuous-Valued Emotions," in IEEE Access, vol. 10, pp. 44617-44626, 2022, doi: 10.1109/ACCESS.2022.3169744.
24
+
25
+ The paper can be found at https://ieeexplore.ieee.org/document/9762257
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+
34
+ GNU GENERAL PUBLIC LICENSE
35
+ Version 3, 29 June 2007
36
+
37
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
38
+ Everyone is permitted to copy and distribute verbatim copies
39
+ of this license document, but changing it is not allowed.
40
+
41
+ Preamble
42
+
43
+ The GNU General Public License is a free, copyleft license for
44
+ software and other kinds of works.
45
+
46
+ The licenses for most software and other practical works are designed
47
+ to take away your freedom to share and change the works. By contrast,
48
+ the GNU General Public License is intended to guarantee your freedom to
49
+ share and change all versions of a program--to make sure it remains free
50
+ software for all its users. We, the Free Software Foundation, use the
51
+ GNU General Public License for most of our software; it applies also to
52
+ any other work released this way by its authors. You can apply it to
53
+ your programs, too.
54
+
55
+ When we speak of free software, we are referring to freedom, not
56
+ price. Our General Public Licenses are designed to make sure that you
57
+ have the freedom to distribute copies of free software (and charge for
58
+ them if you wish), that you receive source code or can get it if you
59
+ want it, that you can change the software or use pieces of it in new
60
+ free programs, and that you know you can do these things.
61
+
62
+ To protect your rights, we need to prevent others from denying you
63
+ these rights or asking you to surrender the rights. Therefore, you have
64
+ certain responsibilities if you distribute copies of the software, or if
65
+ you modify it: responsibilities to respect the freedom of others.
66
+
67
+ For example, if you distribute copies of such a program, whether
68
+ gratis or for a fee, you must pass on to the recipients the same
69
+ freedoms that you received. You must make sure that they, too, receive
70
+ or can get the source code. And you must show them these terms so they
71
+ know their rights.
72
+
73
+ Developers that use the GNU GPL protect your rights with two steps:
74
+ (1) assert copyright on the software, and (2) offer you this License
75
+ giving you legal permission to copy, distribute and/or modify it.
76
+
77
+ For the developers' and authors' protection, the GPL clearly explains
78
+ that there is no warranty for this free software. For both users' and
79
+ authors' sake, the GPL requires that modified versions be marked as
80
+ changed, so that their problems will not be attributed erroneously to
81
+ authors of previous versions.
82
+
83
+ Some devices are designed to deny users access to install or run
84
+ modified versions of the software inside them, although the manufacturer
85
+ can do so. This is fundamentally incompatible with the aim of
86
+ protecting users' freedom to change the software. The systematic
87
+ pattern of such abuse occurs in the area of products for individuals to
88
+ use, which is precisely where it is most unacceptable. Therefore, we
89
+ have designed this version of the GPL to prohibit the practice for those
90
+ products. If such problems arise substantially in other domains, we
91
+ stand ready to extend this provision to those domains in future versions
92
+ of the GPL, as needed to protect the freedom of users.
93
+
94
+ Finally, every program is threatened constantly by software patents.
95
+ States should not allow patents to restrict development and use of
96
+ software on general-purpose computers, but in those that do, we wish to
97
+ avoid the special danger that patents applied to a free program could
98
+ make it effectively proprietary. To prevent this, the GPL assures that
99
+ patents cannot be used to render the program non-free.
100
+
101
+ The precise terms and conditions for copying, distribution and
102
+ modification follow.
103
+
104
+ TERMS AND CONDITIONS
105
+
106
+ 0. Definitions.
107
+
108
+ "This License" refers to version 3 of the GNU General Public License.
109
+
110
+ "Copyright" also means copyright-like laws that apply to other kinds of
111
+ works, such as semiconductor masks.
112
+
113
+ "The Program" refers to any copyrightable work licensed under this
114
+ License. Each licensee is addressed as "you". "Licensees" and
115
+ "recipients" may be individuals or organizations.
116
+
117
+ To "modify" a work means to copy from or adapt all or part of the work
118
+ in a fashion requiring copyright permission, other than the making of an
119
+ exact copy. The resulting work is called a "modified version" of the
120
+ earlier work or a work "based on" the earlier work.
121
+
122
+ A "covered work" means either the unmodified Program or a work based
123
+ on the Program.
124
+
125
+ To "propagate" a work means to do anything with it that, without
126
+ permission, would make you directly or secondarily liable for
127
+ infringement under applicable copyright law, except executing it on a
128
+ computer or modifying a private copy. Propagation includes copying,
129
+ distribution (with or without modification), making available to the
130
+ public, and in some countries other activities as well.
131
+
132
+ To "convey" a work means any kind of propagation that enables other
133
+ parties to make or receive copies. Mere interaction with a user through
134
+ a computer network, with no transfer of a copy, is not conveying.
135
+
136
+ An interactive user interface displays "Appropriate Legal Notices"
137
+ to the extent that it includes a convenient and prominently visible
138
+ feature that (1) displays an appropriate copyright notice, and (2)
139
+ tells the user that there is no warranty for the work (except to the
140
+ extent that warranties are provided), that licensees may convey the
141
+ work under this License, and how to view a copy of this License. If
142
+ the interface presents a list of user commands or options, such as a
143
+ menu, a prominent item in the list meets this criterion.
144
+
145
+ 1. Source Code.
146
+
147
+ The "source code" for a work means the preferred form of the work
148
+ for making modifications to it. "Object code" means any non-source
149
+ form of a work.
150
+
151
+ A "Standard Interface" means an interface that either is an official
152
+ standard defined by a recognized standards body, or, in the case of
153
+ interfaces specified for a particular programming language, one that
154
+ is widely used among developers working in that language.
155
+
156
+ The "System Libraries" of an executable work include anything, other
157
+ than the work as a whole, that (a) is included in the normal form of
158
+ packaging a Major Component, but which is not part of that Major
159
+ Component, and (b) serves only to enable use of the work with that
160
+ Major Component, or to implement a Standard Interface for which an
161
+ implementation is available to the public in source code form. A
162
+ "Major Component", in this context, means a major essential component
163
+ (kernel, window system, and so on) of the specific operating system
164
+ (if any) on which the executable work runs, or a compiler used to
165
+ produce the work, or an object code interpreter used to run it.
166
+
167
+ The "Corresponding Source" for a work in object code form means all
168
+ the source code needed to generate, install, and (for an executable
169
+ work) run the object code and to modify the work, including scripts to
170
+ control those activities. However, it does not include the work's
171
+ System Libraries, or general-purpose tools or generally available free
172
+ programs which are used unmodified in performing those activities but
173
+ which are not part of the work. For example, Corresponding Source
174
+ includes interface definition files associated with source files for
175
+ the work, and the source code for shared libraries and dynamically
176
+ linked subprograms that the work is specifically designed to require,
177
+ such as by intimate data communication or control flow between those
178
+ subprograms and other parts of the work.
179
+
180
+ The Corresponding Source need not include anything that users
181
+ can regenerate automatically from other parts of the Corresponding
182
+ Source.
183
+
184
+ The Corresponding Source for a work in source code form is that
185
+ same work.
186
+
187
+ 2. Basic Permissions.
188
+
189
+ All rights granted under this License are granted for the term of
190
+ copyright on the Program, and are irrevocable provided the stated
191
+ conditions are met. This License explicitly affirms your unlimited
192
+ permission to run the unmodified Program. The output from running a
193
+ covered work is covered by this License only if the output, given its
194
+ content, constitutes a covered work. This License acknowledges your
195
+ rights of fair use or other equivalent, as provided by copyright law.
196
+
197
+ You may make, run and propagate covered works that you do not
198
+ convey, without conditions so long as your license otherwise remains
199
+ in force. You may convey covered works to others for the sole purpose
200
+ of having them make modifications exclusively for you, or provide you
201
+ with facilities for running those works, provided that you comply with
202
+ the terms of this License in conveying all material for which you do
203
+ not control copyright. Those thus making or running the covered works
204
+ for you must do so exclusively on your behalf, under your direction
205
+ and control, on terms that prohibit them from making any copies of
206
+ your copyrighted material outside their relationship with you.
207
+
208
+ Conveying under any other circumstances is permitted solely under
209
+ the conditions stated below. Sublicensing is not allowed; section 10
210
+ makes it unnecessary.
211
+
212
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
213
+
214
+ No covered work shall be deemed part of an effective technological
215
+ measure under any applicable law fulfilling obligations under article
216
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
217
+ similar laws prohibiting or restricting circumvention of such
218
+ measures.
219
+
220
+ When you convey a covered work, you waive any legal power to forbid
221
+ circumvention of technological measures to the extent such circumvention
222
+ is effected by exercising rights under this License with respect to
223
+ the covered work, and you disclaim any intention to limit operation or
224
+ modification of the work as a means of enforcing, against the work's
225
+ users, your or third parties' legal rights to forbid circumvention of
226
+ technological measures.
227
+
228
+ 4. Conveying Verbatim Copies.
229
+
230
+ You may convey verbatim copies of the Program's source code as you
231
+ receive it, in any medium, provided that you conspicuously and
232
+ appropriately publish on each copy an appropriate copyright notice;
233
+ keep intact all notices stating that this License and any
234
+ non-permissive terms added in accord with section 7 apply to the code;
235
+ keep intact all notices of the absence of any warranty; and give all
236
+ recipients a copy of this License along with the Program.
237
+
238
+ You may charge any price or no price for each copy that you convey,
239
+ and you may offer support or warranty protection for a fee.
240
+
241
+ 5. Conveying Modified Source Versions.
242
+
243
+ You may convey a work based on the Program, or the modifications to
244
+ produce it from the Program, in the form of source code under the
245
+ terms of section 4, provided that you also meet all of these conditions:
246
+
247
+ a) The work must carry prominent notices stating that you modified
248
+ it, and giving a relevant date.
249
+
250
+ b) The work must carry prominent notices stating that it is
251
+ released under this License and any conditions added under section
252
+ 7. This requirement modifies the requirement in section 4 to
253
+ "keep intact all notices".
254
+
255
+ c) You must license the entire work, as a whole, under this
256
+ License to anyone who comes into possession of a copy. This
257
+ License will therefore apply, along with any applicable section 7
258
+ additional terms, to the whole of the work, and all its parts,
259
+ regardless of how they are packaged. This License gives no
260
+ permission to license the work in any other way, but it does not
261
+ invalidate such permission if you have separately received it.
262
+
263
+ d) If the work has interactive user interfaces, each must display
264
+ Appropriate Legal Notices; however, if the Program has interactive
265
+ interfaces that do not display Appropriate Legal Notices, your
266
+ work need not make them do so.
267
+
268
+ A compilation of a covered work with other separate and independent
269
+ works, which are not by their nature extensions of the covered work,
270
+ and which are not combined with it such as to form a larger program,
271
+ in or on a volume of a storage or distribution medium, is called an
272
+ "aggregate" if the compilation and its resulting copyright are not
273
+ used to limit the access or legal rights of the compilation's users
274
+ beyond what the individual works permit. Inclusion of a covered work
275
+ in an aggregate does not cause this License to apply to the other
276
+ parts of the aggregate.
277
+
278
+ 6. Conveying Non-Source Forms.
279
+
280
+ You may convey a covered work in object code form under the terms
281
+ of sections 4 and 5, provided that you also convey the
282
+ machine-readable Corresponding Source under the terms of this License,
283
+ in one of these ways:
284
+
285
+ a) Convey the object code in, or embodied in, a physical product
286
+ (including a physical distribution medium), accompanied by the
287
+ Corresponding Source fixed on a durable physical medium
288
+ customarily used for software interchange.
289
+
290
+ b) Convey the object code in, or embodied in, a physical product
291
+ (including a physical distribution medium), accompanied by a
292
+ written offer, valid for at least three years and valid for as
293
+ long as you offer spare parts or customer support for that product
294
+ model, to give anyone who possesses the object code either (1) a
295
+ copy of the Corresponding Source for all the software in the
296
+ product that is covered by this License, on a durable physical
297
+ medium customarily used for software interchange, for a price no
298
+ more than your reasonable cost of physically performing this
299
+ conveying of source, or (2) access to copy the
300
+ Corresponding Source from a network server at no charge.
301
+
302
+ c) Convey individual copies of the object code with a copy of the
303
+ written offer to provide the Corresponding Source. This
304
+ alternative is allowed only occasionally and noncommercially, and
305
+ only if you received the object code with such an offer, in accord
306
+ with subsection 6b.
307
+
308
+ d) Convey the object code by offering access from a designated
309
+ place (gratis or for a charge), and offer equivalent access to the
310
+ Corresponding Source in the same way through the same place at no
311
+ further charge. You need not require recipients to copy the
312
+ Corresponding Source along with the object code. If the place to
313
+ copy the object code is a network server, the Corresponding Source
314
+ may be on a different server (operated by you or a third party)
315
+ that supports equivalent copying facilities, provided you maintain
316
+ clear directions next to the object code saying where to find the
317
+ Corresponding Source. Regardless of what server hosts the
318
+ Corresponding Source, you remain obligated to ensure that it is
319
+ available for as long as needed to satisfy these requirements.
320
+
321
+ e) Convey the object code using peer-to-peer transmission, provided
322
+ you inform other peers where the object code and Corresponding
323
+ Source of the work are being offered to the general public at no
324
+ charge under subsection 6d.
325
+
326
+ A separable portion of the object code, whose source code is excluded
327
+ from the Corresponding Source as a System Library, need not be
328
+ included in conveying the object code work.
329
+
330
+ A "User Product" is either (1) a "consumer product", which means any
331
+ tangible personal property which is normally used for personal, family,
332
+ or household purposes, or (2) anything designed or sold for incorporation
333
+ into a dwelling. In determining whether a product is a consumer product,
334
+ doubtful cases shall be resolved in favor of coverage. For a particular
335
+ product received by a particular user, "normally used" refers to a
336
+ typical or common use of that class of product, regardless of the status
337
+ of the particular user or of the way in which the particular user
338
+ actually uses, or expects or is expected to use, the product. A product
339
+ is a consumer product regardless of whether the product has substantial
340
+ commercial, industrial or non-consumer uses, unless such uses represent
341
+ the only significant mode of use of the product.
342
+
343
+ "Installation Information" for a User Product means any methods,
344
+ procedures, authorization keys, or other information required to install
345
+ and execute modified versions of a covered work in that User Product from
346
+ a modified version of its Corresponding Source. The information must
347
+ suffice to ensure that the continued functioning of the modified object
348
+ code is in no case prevented or interfered with solely because
349
+ modification has been made.
350
+
351
+ If you convey an object code work under this section in, or with, or
352
+ specifically for use in, a User Product, and the conveying occurs as
353
+ part of a transaction in which the right of possession and use of the
354
+ User Product is transferred to the recipient in perpetuity or for a
355
+ fixed term (regardless of how the transaction is characterized), the
356
+ Corresponding Source conveyed under this section must be accompanied
357
+ by the Installation Information. But this requirement does not apply
358
+ if neither you nor any third party retains the ability to install
359
+ modified object code on the User Product (for example, the work has
360
+ been installed in ROM).
361
+
362
+ The requirement to provide Installation Information does not include a
363
+ requirement to continue to provide support service, warranty, or updates
364
+ for a work that has been modified or installed by the recipient, or for
365
+ the User Product in which it has been modified or installed. Access to a
366
+ network may be denied when the modification itself materially and
367
+ adversely affects the operation of the network or violates the rules and
368
+ protocols for communication across the network.
369
+
370
+ Corresponding Source conveyed, and Installation Information provided,
371
+ in accord with this section must be in a format that is publicly
372
+ documented (and with an implementation available to the public in
373
+ source code form), and must require no special password or key for
374
+ unpacking, reading or copying.
375
+
376
+ 7. Additional Terms.
377
+
378
+ "Additional permissions" are terms that supplement the terms of this
379
+ License by making exceptions from one or more of its conditions.
380
+ Additional permissions that are applicable to the entire Program shall
381
+ be treated as though they were included in this License, to the extent
382
+ that they are valid under applicable law. If additional permissions
383
+ apply only to part of the Program, that part may be used separately
384
+ under those permissions, but the entire Program remains governed by
385
+ this License without regard to the additional permissions.
386
+
387
+ When you convey a copy of a covered work, you may at your option
388
+ remove any additional permissions from that copy, or from any part of
389
+ it. (Additional permissions may be written to require their own
390
+ removal in certain cases when you modify the work.) You may place
391
+ additional permissions on material, added by you to a covered work,
392
+ for which you have or can give appropriate copyright permission.
393
+
394
+ Notwithstanding any other provision of this License, for material you
395
+ add to a covered work, you may (if authorized by the copyright holders of
396
+ that material) supplement the terms of this License with terms:
397
+
398
+ a) Disclaiming warranty or limiting liability differently from the
399
+ terms of sections 15 and 16 of this License; or
400
+
401
+ b) Requiring preservation of specified reasonable legal notices or
402
+ author attributions in that material or in the Appropriate Legal
403
+ Notices displayed by works containing it; or
404
+
405
+ c) Prohibiting misrepresentation of the origin of that material, or
406
+ requiring that modified versions of such material be marked in
407
+ reasonable ways as different from the original version; or
408
+
409
+ d) Limiting the use for publicity purposes of names of licensors or
410
+ authors of the material; or
411
+
412
+ e) Declining to grant rights under trademark law for use of some
413
+ trade names, trademarks, or service marks; or
414
+
415
+ f) Requiring indemnification of licensors and authors of that
416
+ material by anyone who conveys the material (or modified versions of
417
+ it) with contractual assumptions of liability to the recipient, for
418
+ any liability that these contractual assumptions directly impose on
419
+ those licensors and authors.
420
+
421
+ All other non-permissive additional terms are considered "further
422
+ restrictions" within the meaning of section 10. If the Program as you
423
+ received it, or any part of it, contains a notice stating that it is
424
+ governed by this License along with a term that is a further
425
+ restriction, you may remove that term. If a license document contains
426
+ a further restriction but permits relicensing or conveying under this
427
+ License, you may add to a covered work material governed by the terms
428
+ of that license document, provided that the further restriction does
429
+ not survive such relicensing or conveying.
430
+
431
+ If you add terms to a covered work in accord with this section, you
432
+ must place, in the relevant source files, a statement of the
433
+ additional terms that apply to those files, or a notice indicating
434
+ where to find the applicable terms.
435
+
436
+ Additional terms, permissive or non-permissive, may be stated in the
437
+ form of a separately written license, or stated as exceptions;
438
+ the above requirements apply either way.
439
+
440
+ 8. Termination.
441
+
442
+ You may not propagate or modify a covered work except as expressly
443
+ provided under this License. Any attempt otherwise to propagate or
444
+ modify it is void, and will automatically terminate your rights under
445
+ this License (including any patent licenses granted under the third
446
+ paragraph of section 11).
447
+
448
+ However, if you cease all violation of this License, then your
449
+ license from a particular copyright holder is reinstated (a)
450
+ provisionally, unless and until the copyright holder explicitly and
451
+ finally terminates your license, and (b) permanently, if the copyright
452
+ holder fails to notify you of the violation by some reasonable means
453
+ prior to 60 days after the cessation.
454
+
455
+ Moreover, your license from a particular copyright holder is
456
+ reinstated permanently if the copyright holder notifies you of the
457
+ violation by some reasonable means, this is the first time you have
458
+ received notice of violation of this License (for any work) from that
459
+ copyright holder, and you cure the violation prior to 30 days after
460
+ your receipt of the notice.
461
+
462
+ Termination of your rights under this section does not terminate the
463
+ licenses of parties who have received copies or rights from you under
464
+ this License. If your rights have been terminated and not permanently
465
+ reinstated, you do not qualify to receive new licenses for the same
466
+ material under section 10.
467
+
468
+ 9. Acceptance Not Required for Having Copies.
469
+
470
+ You are not required to accept this License in order to receive or
471
+ run a copy of the Program. Ancillary propagation of a covered work
472
+ occurring solely as a consequence of using peer-to-peer transmission
473
+ to receive a copy likewise does not require acceptance. However,
474
+ nothing other than this License grants you permission to propagate or
475
+ modify any covered work. These actions infringe copyright if you do
476
+ not accept this License. Therefore, by modifying or propagating a
477
+ covered work, you indicate your acceptance of this License to do so.
478
+
479
+ 10. Automatic Licensing of Downstream Recipients.
480
+
481
+ Each time you convey a covered work, the recipient automatically
482
+ receives a license from the original licensors, to run, modify and
483
+ propagate that work, subject to this License. You are not responsible
484
+ for enforcing compliance by third parties with this License.
485
+
486
+ An "entity transaction" is a transaction transferring control of an
487
+ organization, or substantially all assets of one, or subdividing an
488
+ organization, or merging organizations. If propagation of a covered
489
+ work results from an entity transaction, each party to that
490
+ transaction who receives a copy of the work also receives whatever
491
+ licenses to the work the party's predecessor in interest had or could
492
+ give under the previous paragraph, plus a right to possession of the
493
+ Corresponding Source of the work from the predecessor in interest, if
494
+ the predecessor has it or can get it with reasonable efforts.
495
+
496
+ You may not impose any further restrictions on the exercise of the
497
+ rights granted or affirmed under this License. For example, you may
498
+ not impose a license fee, royalty, or other charge for exercise of
499
+ rights granted under this License, and you may not initiate litigation
500
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
501
+ any patent claim is infringed by making, using, selling, offering for
502
+ sale, or importing the Program or any portion of it.
503
+
504
+ 11. Patents.
505
+
506
+ A "contributor" is a copyright holder who authorizes use under this
507
+ License of the Program or a work on which the Program is based. The
508
+ work thus licensed is called the contributor's "contributor version".
509
+
510
+ A contributor's "essential patent claims" are all patent claims
511
+ owned or controlled by the contributor, whether already acquired or
512
+ hereafter acquired, that would be infringed by some manner, permitted
513
+ by this License, of making, using, or selling its contributor version,
514
+ but do not include claims that would be infringed only as a
515
+ consequence of further modification of the contributor version. For
516
+ purposes of this definition, "control" includes the right to grant
517
+ patent sublicenses in a manner consistent with the requirements of
518
+ this License.
519
+
520
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
521
+ patent license under the contributor's essential patent claims, to
522
+ make, use, sell, offer for sale, import and otherwise run, modify and
523
+ propagate the contents of its contributor version.
524
+
525
+ In the following three paragraphs, a "patent license" is any express
526
+ agreement or commitment, however denominated, not to enforce a patent
527
+ (such as an express permission to practice a patent or covenant not to
528
+ sue for patent infringement). To "grant" such a patent license to a
529
+ party means to make such an agreement or commitment not to enforce a
530
+ patent against the party.
531
+
532
+ If you convey a covered work, knowingly relying on a patent license,
533
+ and the Corresponding Source of the work is not available for anyone
534
+ to copy, free of charge and under the terms of this License, through a
535
+ publicly available network server or other readily accessible means,
536
+ then you must either (1) cause the Corresponding Source to be so
537
+ available, or (2) arrange to deprive yourself of the benefit of the
538
+ patent license for this particular work, or (3) arrange, in a manner
539
+ consistent with the requirements of this License, to extend the patent
540
+ license to downstream recipients. "Knowingly relying" means you have
541
+ actual knowledge that, but for the patent license, your conveying the
542
+ covered work in a country, or your recipient's use of the covered work
543
+ in a country, would infringe one or more identifiable patents in that
544
+ country that you have reason to believe are valid.
545
+
546
+ If, pursuant to or in connection with a single transaction or
547
+ arrangement, you convey, or propagate by procuring conveyance of, a
548
+ covered work, and grant a patent license to some of the parties
549
+ receiving the covered work authorizing them to use, propagate, modify
550
+ or convey a specific copy of the covered work, then the patent license
551
+ you grant is automatically extended to all recipients of the covered
552
+ work and works based on it.
553
+
554
+ A patent license is "discriminatory" if it does not include within
555
+ the scope of its coverage, prohibits the exercise of, or is
556
+ conditioned on the non-exercise of one or more of the rights that are
557
+ specifically granted under this License. You may not convey a covered
558
+ work if you are a party to an arrangement with a third party that is
559
+ in the business of distributing software, under which you make payment
560
+ to the third party based on the extent of your activity of conveying
561
+ the work, and under which the third party grants, to any of the
562
+ parties who would receive the covered work from you, a discriminatory
563
+ patent license (a) in connection with copies of the covered work
564
+ conveyed by you (or copies made from those copies), or (b) primarily
565
+ for and in connection with specific products or compilations that
566
+ contain the covered work, unless you entered into that arrangement,
567
+ or that patent license was granted, prior to 28 March 2007.
568
+
569
+ Nothing in this License shall be construed as excluding or limiting
570
+ any implied license or other defenses to infringement that may
571
+ otherwise be available to you under applicable patent law.
572
+
573
+ 12. No Surrender of Others' Freedom.
574
+
575
+ If conditions are imposed on you (whether by court order, agreement or
576
+ otherwise) that contradict the conditions of this License, they do not
577
+ excuse you from the conditions of this License. If you cannot convey a
578
+ covered work so as to satisfy simultaneously your obligations under this
579
+ License and any other pertinent obligations, then as a consequence you may
580
+ not convey it at all. For example, if you agree to terms that obligate you
581
+ to collect a royalty for further conveying from those to whom you convey
582
+ the Program, the only way you could satisfy both those terms and this
583
+ License would be to refrain entirely from conveying the Program.
584
+
585
+ 13. Use with the GNU Affero General Public License.
586
+
587
+ Notwithstanding any other provision of this License, you have
588
+ permission to link or combine any covered work with a work licensed
589
+ under version 3 of the GNU Affero General Public License into a single
590
+ combined work, and to convey the resulting work. The terms of this
591
+ License will continue to apply to the part which is the covered work,
592
+ but the special requirements of the GNU Affero General Public License,
593
+ section 13, concerning interaction through a network will apply to the
594
+ combination as such.
595
+
596
+ 14. Revised Versions of this License.
597
+
598
+ The Free Software Foundation may publish revised and/or new versions of
599
+ the GNU General Public License from time to time. Such new versions will
600
+ be similar in spirit to the present version, but may differ in detail to
601
+ address new problems or concerns.
602
+
603
+ Each version is given a distinguishing version number. If the
604
+ Program specifies that a certain numbered version of the GNU General
605
+ Public License "or any later version" applies to it, you have the
606
+ option of following the terms and conditions either of that numbered
607
+ version or of any later version published by the Free Software
608
+ Foundation. If the Program does not specify a version number of the
609
+ GNU General Public License, you may choose any version ever published
610
+ by the Free Software Foundation.
611
+
612
+ If the Program specifies that a proxy can decide which future
613
+ versions of the GNU General Public License can be used, that proxy's
614
+ public statement of acceptance of a version permanently authorizes you
615
+ to choose that version for the Program.
616
+
617
+ Later license versions may give you additional or different
618
+ permissions. However, no additional obligations are imposed on any
619
+ author or copyright holder as a result of your choosing to follow a
620
+ later version.
621
+
622
+ 15. Disclaimer of Warranty.
623
+
624
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
625
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
626
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
627
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
628
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
629
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
630
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
631
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
632
+
633
+ 16. Limitation of Liability.
634
+
635
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
636
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
637
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
638
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
639
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
640
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
641
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
642
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
643
+ SUCH DAMAGES.
644
+
645
+ 17. Interpretation of Sections 15 and 16.
646
+
647
+ If the disclaimer of warranty and limitation of liability provided
648
+ above cannot be given local legal effect according to their terms,
649
+ reviewing courts shall apply local law that most closely approximates
650
+ an absolute waiver of all civil liability in connection with the
651
+ Program, unless a warranty or assumption of liability accompanies a
652
+ copy of the Program in return for a fee.
653
+
midi_emotion/readme.md ADDED
@@ -0,0 +1,66 @@
1
+ Generates multi-instrument symbolic music (MIDI) based on user-provided emotions from the valence-arousal plane. In simpler terms, it can generate happy (positive valence, positive arousal), calm (positive valence, negative arousal), angry (negative valence, positive arousal), or sad (negative valence, negative arousal) music.
2
+
3
+ Source code for our paper "Symbolic music generation conditioned on continuous-valued emotions",
4
+ Serkan Sulun, Matthew E. P. Davies, Paula Viana, 2022.
5
+ https://ieeexplore.ieee.org/document/9762257
6
+
7
+ To cite:
8
+ ```S. Sulun, M. E. P. Davies and P. Viana, "Symbolic music generation conditioned on continuous-valued emotions," in IEEE Access, doi: 10.1109/ACCESS.2022.3169744.```
9
+
10
+ Required Python libraries: NumPy, PyTorch, pandas, pretty_midi, Pypianoroll, tqdm, Spotipy, PyTables. Alternatively, run: ```pip install -r requirements.txt```
11
+
12
+ To create the Lakh-Spotify dataset:
13
+
14
+ - Go to the ```src/create_dataset``` folder
15
+
16
+ - Download the datasets:
17
+
18
+ [Lakh pianoroll 5 full dataset](https://ucsdcloud-my.sharepoint.com/personal/h3dong_ucsd_edu/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fh3dong%5Fucsd%5Fedu%2FDocuments%2Fdata%2Flpd%2Flpd%5F5%2Flpd%5F5%5Ffull%2Etar%2Egz&parent=%2Fpersonal%2Fh3dong%5Fucsd%5Fedu%2FDocuments%2Fdata%2Flpd%2Flpd%5F5&ga=1)
19
+
20
+ MSD summary file
21
+ http://labrosa.ee.columbia.edu/millionsong/sites/default/files/AdditionalFiles/msd_summary_file.h5
22
+
23
+ Echonest mapping dataset
24
+ ```ftp://ftp.acousticbrainz.org/pub/acousticbrainz/acousticbrainz-labs/download/msdrosetta/millionsongdataset_echonest.tar.bz2```
25
+ Alternatively: https://drive.google.com/file/d/17Exfxjtq7bI9EKtEZlOrBCkx8RBx7h77/view?usp=sharing
26
+
27
+
28
+ Lakh-MSD matching scores file
29
+ http://hog.ee.columbia.edu/craffel/lmd/match_scores.json
30
+
31
+ - Extract the archives where needed, and place everything inside the folder ```./data_files```
32
+
33
+ - Get Spotify client ID and client secret:
34
+ https://developer.spotify.com/dashboard/applications
35
+ Then, fill in the variables "client_id" and "client_secret" in ```src/create_dataset/utils.py``` (a Spotipy usage sketch follows this list)
36
+
37
+ - Run ```run.py```.
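The ```utils.py``` module is not part of this commit, so purely as orientation: the sketch below shows how such client credentials are typically consumed with Spotipy's client-credentials flow. The placeholder values and the lookup call are assumptions for illustration, not the repository's code.

```python
# Illustrative sketch only (not the repository's utils.py): how Spotify client
# credentials are typically used with Spotipy to fetch audio features.
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials

client_id = "YOUR_CLIENT_ID"          # placeholder
client_secret = "YOUR_CLIENT_SECRET"  # placeholder

sp = spotipy.Spotify(
    auth_manager=SpotifyClientCredentials(client_id=client_id,
                                          client_secret=client_secret)
)

# run.py later asks for per-track audio features (valence, energy, ...);
# a lookup along these lines is what a helper like utils.get_spotify_features would wrap.
features = sp.audio_features(["PLACEHOLDER_SPOTIFY_TRACK_ID"])[0]
print(features["valence"], features["energy"])
```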
38
+
39
+ To preprocess and create the training dataset:
40
+
41
+ - Go to the ```src/data``` folder and run ```preprocess_pianorolls.py```
42
+
43
+
44
+ To generate MIDI using pretrained models:
45
+
46
+ - Download model(s) from the following link:
47
+ https://drive.google.com/drive/folders/1R5-HaXmNzXBAhGq1idrDF-YEKkZm5C8C?usp=sharing
48
+
49
+ - Extract into the folder ```output```
50
+
51
+ - Go to the ```src``` folder and run ```generate.py``` with appropriate arguments, e.g.:
52
+ ```python generate.py --model_dir continuous_concat --conditioning continuous_concat --valence -0.8 -0.8 0.8 0.8 --arousal -0.8 -0.8 0.8 0.8```
53
+
54
+
55
+ To train:
56
+
57
+ - Go to the ```src``` folder and run ```train.py``` with appropriate arguments, e.g.:
58
+ ```python train.py --conditioning continuous_concat```
59
+
60
+ There are 4 different conditioning modes:
61
+ ```none```: No conditioning, vanilla model.
62
+ ```discrete_token```: Conditioning using discrete tokens, i.e. control tokens.
63
+ ```continuous_token```: Conditioning using continuous values embedded as vectors, then prepended to the other embedded tokens in sequence dimension.
64
+ ```continuous_concat```: Conditioning using continuous values embedded as vectors, then concatenated to all other embedded tokens in channel dimension.
65
+
66
+ See ```config.py``` for all options.
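As a rough illustration of the valence-arousal quadrants described at the top of this readme, the snippet below composes the four corresponding ```generate.py``` commands. It is not part of the repository; the ±0.8 magnitudes and the ```continuous_concat``` model directory are simply taken from the example command above.

```python
# Illustration only: build generate.py commands for the four emotion quadrants.
# Assumes the pretrained "continuous_concat" model was extracted into ./output
# and that the commands are run from the src folder, as described above.
quadrants = {
    "happy": (0.8, 0.8),    # +valence, +arousal
    "calm":  (0.8, -0.8),   # +valence, -arousal
    "angry": (-0.8, 0.8),   # -valence, +arousal
    "sad":   (-0.8, -0.8),  # -valence, -arousal
}

for name, (valence, arousal) in quadrants.items():
    cmd = (f"python generate.py --model_dir continuous_concat "
           f"--conditioning continuous_concat "
           f"--valence {valence} --arousal {arousal}")
    print(f"{name}: {cmd}")
```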
midi_emotion/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ numpy==1.21.0
2
+ pandas==1.2.5
3
+ pretty-midi==0.2.9
4
+ pypianoroll==1.0.4
5
+ spotipy==2.19.0
6
+ tables==3.6.1
7
+ torch==2.1.0
8
+ tqdm==4.61.1
midi_emotion/setup.py ADDED
@@ -0,0 +1,13 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="midi_emotion",
5
+ version="0.1.0",
6
+ packages=find_packages(),
7
+ install_requires=[
8
+ "torch",
9
+ "numpy",
10
+ "pretty_midi",
11
+ "tqdm"
12
+ ]
13
+ )
midi_emotion/src/config.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import time
3
+ import argparse
4
+
5
+ parser = argparse.ArgumentParser(description='Generates emotion-based symbolic music')
6
+
7
+ parser.add_argument("--conditioning", type=str, required=False, default="continuous_concat",
8
+ choices=["none", "discrete_token", "continuous_token",
9
+ "continuous_concat"], help='Conditioning type')
10
+ parser.add_argument("--data_folder", type=str, default="../data_files/lpd_5/lpd_5_full_transposable")
11
+ parser.add_argument('--full_dataset', action="store_true",
12
+ help='Use LPD-full dataset')
13
+ parser.add_argument('--n_layer', type=int, default=20,
14
+ help='number of total layers')
15
+ parser.add_argument('--n_head', type=int, default=16,
16
+ help='number of heads')
17
+ parser.add_argument('--d_model', type=int, default=768,
18
+ help='model dimension')
19
+ parser.add_argument('--d_condition', type=int, default=192,
20
+ help='condition dimension (if continuous_concat is used)')
21
+ parser.add_argument('--d_inner', type=int, default=768*4,
22
+ help='inner dimension in FF')
23
+ parser.add_argument('--tgt_len', type=int, default=1216,
24
+ help='number of tokens to predict')
25
+ parser.add_argument('--max_gen_input_len', type=int, default=-1,
26
+ help='maximum number of input tokens used during generation (-1 for no limit)')
27
+ parser.add_argument('--gen_len', type=int, default=2048,
28
+ help='Generation length')
29
+ parser.add_argument('--temp_note', type=float, default=1.2,
30
+ help='Temperature for generating notes')
31
+ parser.add_argument('--temp_rest', type=float, default=1.2,
32
+ help='Temperature for generating rests')
33
+ parser.add_argument('--n_bars', type=int, default=-1,
34
+ help='number of bars to use')
35
+ parser.add_argument('--no_pad', action='store_true',
36
+ help='dont pad sequences')
37
+ parser.add_argument('--eval_tgt_len', type=int, default=-1,
38
+ help='number of tokens to predict for evaluation')
39
+ parser.add_argument('--dropout', type=float, default=0.1,
40
+ help='global dropout rate')
41
+ parser.add_argument("--overwrite_dropout", action="store_true",
42
+ help="resets dropouts")
43
+ parser.add_argument('--lr', type=float, default=2e-5,
44
+ help='initial learning rate (0.00025|5 for adam|sgd)')
45
+ parser.add_argument("--overwrite_lr", action="store_true",
46
+ help="Overwrites learning rate if pretrained model is loaded.")
47
+ parser.add_argument('--arousal_feature', default='note_density', type=str,
48
+ choices=['tempo', 'note_density'],
49
+ help='Feature to use as arousal feature')
50
+ parser.add_argument('--scheduler', default='constant', type=str,
51
+ choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant', "cyclic"],
52
+ help='lr scheduler to use.')
53
+ parser.add_argument('--lr_min', type=float, default=5e-6,
54
+ help='minimum learning rate for cyclic scheduler')
55
+ parser.add_argument('--lr_max', type=float, default=5e-3,
56
+ help='maximum learning rate for cyclic scheduler')
57
+ parser.add_argument('--warmup_step', type=int, default=0,
58
+ help='number of learning rate warmup steps')
59
+ parser.add_argument('--decay_rate', type=float, default=0.5,
60
+ help='decay factor when ReduceLROnPlateau is used')
61
+ parser.add_argument('--clip', type=float, default=1.0,
62
+ help='gradient clipping')
63
+ parser.add_argument('--batch_size', type=int, default=4,
64
+ help='batch size')
65
+ parser.add_argument('--accumulate_step', type=int, default=1,
66
+ help='accumulate gradients (multiplies effective batch size)')
67
+ parser.add_argument('--seed', type=int, default=-1,
68
+ help='random seed')
69
+ parser.add_argument('--no_cuda', action='store_true',
70
+ help='use CPU')
71
+ parser.add_argument('--log_step', type=int, default=1000,
72
+ help='report interval')
73
+ parser.add_argument('--eval_step', type=int, default=8000,
74
+ help='evaluation interval')
75
+ parser.add_argument('--max_eval_step', type=int, default=1000,
76
+ help='maximum evaluation steps')
77
+ parser.add_argument('--gen_step', type=int, default=8000,
78
+ help='generation interval')
79
+ parser.add_argument('--work_dir', default='../output', type=str,
80
+ help='experiment directory.')
81
+ parser.add_argument('--restart_dir', type=str, default=None,
82
+ help='restart dir')
83
+ parser.add_argument('--debug', action='store_true',
84
+ help='run in debug mode (do not create exp dir)')
85
+ parser.add_argument('--max_step', type=int, default=1000000000,
86
+ help='maximum training steps')
87
+ parser.add_argument('--overfit', action='store_true',
88
+ help='Works on a single sample')
89
+ parser.add_argument('--find_lr', action='store_true',
90
+ help='Run learning rate finder')
91
+ parser.add_argument('--num_workers', default=8, type=int,
92
+ help='Number of cores for data loading')
93
+ parser.add_argument('--bar_start_prob', type=float, default=0.5,
94
+ help=('probability of training sample'
95
+ ' starting at a bar location'))
96
+ parser.add_argument("--n_samples", type=int, default=-1,
97
+ help="Limits number of training samples (for faster debugging)")
98
+ parser.add_argument('--n_emotion_bins', type=int, default=5,
99
+ help='Number of emotion bins in each dimension')
100
+ parser.add_argument('--max_transpose', type=int, default=3,
101
+ help='Maximum transpose amount')
102
+ parser.add_argument('--no_amp', action="store_true",
103
+ help='Disable automatic mixed precision')
104
+ parser.add_argument('--reset_scaler', action="store_true",
105
+ help="Reset scaler (can help avoiding nans)")
106
+ parser.add_argument('--exhaustive_eval', action="store_true",
107
+ help="Use data exhaustively (for final evaluation)")
108
+ parser.add_argument('--regression', action="store_true",
109
+ help="Train a regression model")
110
+ parser.add_argument("--always_use_discrete_condition", action="store_true",
111
+ help="Discrete tokens are used for every sequence")
112
+ parser.add_argument("--regression_dir", type=str, default=None,
113
+ help="The path of folder with generations, to perform regression on")
114
+
115
+ args = parser.parse_args()
116
+
117
+ if args.regression_dir is not None:
118
+ args.regression = True
119
+
120
+ if args.conditioning != "continuous_concat":
121
+ args.d_condition = -1
122
+
123
+ assert not (args.exhaustive_eval and args.max_eval_step > 0)
124
+
125
+ if args.full_dataset:
126
+ assert args.conditioning in ["discrete_token", "none"] and not args.regression, "LPD-full has NaN features"
127
+
128
+ if args.regression:
129
+ args.n_layer = 8
130
+ print("Using 8 layers for regression")
131
+
132
+ args.batch_chunk = -1
133
+
134
+ if args.debug or args.overfit:
135
+ args.num_workers = 0
136
+
137
+ if args.find_lr:
138
+ args.debug = True
139
+
140
+ args.d_embed = args.d_model
141
+
142
+ if args.eval_tgt_len < 0:
143
+ args.eval_tgt_len = args.tgt_len
144
+
145
+ if args.scheduler == "cyclic":
146
+ args.lr = args.lr_min
147
+
148
+ if args.restart_dir:
149
+ args.restart_dir = os.path.join(args.work_dir, args.restart_dir)
150
+
151
+ if args.debug:
152
+ args.work_dir = os.path.join(args.work_dir, "DEBUG_" + time.strftime('%Y%m%d-%H%M%S'))
153
+ elif args.no_cuda:
154
+ args.work_dir = os.path.join(args.work_dir, "CPU_" + time.strftime('%Y%m%d-%H%M%S'))
155
+ else:
156
+ args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
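The ```--d_condition``` option above only matters for the ```continuous_concat``` mode described in the readme, where an embedded emotion vector is concatenated to every token embedding along the channel dimension. The following is a shape-level sketch of that idea using the defaults ```d_model=768``` and ```d_condition=192```; it is an illustration, not the repository's model code.

```python
# Shape-level illustration of continuous_concat conditioning (not the repo's model).
# Uses the config.py defaults d_model=768 and d_condition=192.
import torch
import torch.nn as nn

d_model, d_condition, seq_len, batch = 768, 192, 16, 2

token_emb = torch.randn(seq_len, batch, d_model)   # embedded music tokens
valence_arousal = torch.randn(batch, 2)            # continuous emotion values
to_condition = nn.Linear(2, d_condition)           # embed emotions as a vector

cond = to_condition(valence_arousal)               # (batch, d_condition)
cond = cond.unsqueeze(0).expand(seq_len, -1, -1)   # broadcast over the sequence
conditioned = torch.cat([token_emb, cond], dim=-1) # concatenate in channel dim
print(conditioned.shape)                           # torch.Size([16, 2, 960])
```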
midi_emotion/src/create_dataset/hdf5_getters.py ADDED
@@ -0,0 +1,476 @@
1
+ """
2
+ Thierry Bertin-Mahieux (2010) Columbia University
3
4
+
5
+
6
+ This code contains a set of getters functions to access the fields
7
+ from an HDF5 song file (regular file with one song or
8
+ aggregate / summary file with many songs)
9
+
10
+ This is part of the Million Song Dataset project from
11
+ LabROSA (Columbia University) and The Echo Nest.
12
+
13
+
14
+ Copyright 2010, Thierry Bertin-Mahieux
15
+
16
+ This program is free software: you can redistribute it and/or modify
17
+ it under the terms of the GNU General Public License as published by
18
+ the Free Software Foundation, either version 3 of the License, or
19
+ (at your option) any later version.
20
+
21
+ This program is distributed in the hope that it will be useful,
22
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
23
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24
+ GNU General Public License for more details.
25
+
26
+ You should have received a copy of the GNU General Public License
27
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
28
+ """
29
+
30
+
31
+ import tables
32
+
33
+
34
+ def open_h5_file_read(h5filename):
35
+ """
36
+ Open an existing H5 in read mode.
37
+ Same function as in hdf5_utils, here so we avoid one import
38
+ """
39
+ return tables.open_file(h5filename, mode='r')
40
+
41
+
42
+ def get_num_songs(h5):
43
+ """
44
+ Return the number of songs contained in this h5 file, i.e. the number of rows
45
+ for all basic information like name, artist, ...
46
+ """
47
+ return h5.root.metadata.songs.nrows
48
+
49
+ def get_artist_familiarity(h5,songidx=0):
50
+ """
51
+ Get artist familiarity from a HDF5 song file, by default the first song in it
52
+ """
53
+ return h5.root.metadata.songs.cols.artist_familiarity[songidx]
54
+
55
+ def get_artist_hotttnesss(h5,songidx=0):
56
+ """
57
+ Get artist hotttnesss from a HDF5 song file, by default the first song in it
58
+ """
59
+ return h5.root.metadata.songs.cols.artist_hotttnesss[songidx]
60
+
61
+ def get_artist_id(h5,songidx=0):
62
+ """
63
+ Get artist id from a HDF5 song file, by default the first song in it
64
+ """
65
+ return h5.root.metadata.songs.cols.artist_id[songidx]
66
+
67
+ def get_artist_mbid(h5,songidx=0):
68
+ """
69
+ Get artist musicbrainz id from a HDF5 song file, by default the first song in it
70
+ """
71
+ return h5.root.metadata.songs.cols.artist_mbid[songidx]
72
+
73
+ def get_artist_playmeid(h5,songidx=0):
74
+ """
75
+ Get artist playme id from a HDF5 song file, by default the first song in it
76
+ """
77
+ return h5.root.metadata.songs.cols.artist_playmeid[songidx]
78
+
79
+ def get_artist_7digitalid(h5,songidx=0):
80
+ """
81
+ Get artist 7digital id from a HDF5 song file, by default the first song in it
82
+ """
83
+ return h5.root.metadata.songs.cols.artist_7digitalid[songidx]
84
+
85
+ def get_artist_latitude(h5,songidx=0):
86
+ """
87
+ Get artist latitude from a HDF5 song file, by default the first song in it
88
+ """
89
+ return h5.root.metadata.songs.cols.artist_latitude[songidx]
90
+
91
+ def get_artist_longitude(h5,songidx=0):
92
+ """
93
+ Get artist longitude from a HDF5 song file, by default the first song in it
94
+ """
95
+ return h5.root.metadata.songs.cols.artist_longitude[songidx]
96
+
97
+ def get_artist_location(h5,songidx=0):
98
+ """
99
+ Get artist location from a HDF5 song file, by default the first song in it
100
+ """
101
+ return h5.root.metadata.songs.cols.artist_location[songidx]
102
+
103
+ def get_artist_name(h5,songidx=0):
104
+ """
105
+ Get artist name from a HDF5 song file, by default the first song in it
106
+ """
107
+ return h5.root.metadata.songs.cols.artist_name[songidx]
108
+
109
+ def get_release(h5,songidx=0):
110
+ """
111
+ Get release from a HDF5 song file, by default the first song in it
112
+ """
113
+ return h5.root.metadata.songs.cols.release[songidx]
114
+
115
+ def get_release_7digitalid(h5,songidx=0):
116
+ """
117
+ Get release 7digital id from a HDF5 song file, by default the first song in it
118
+ """
119
+ return h5.root.metadata.songs.cols.release_7digitalid[songidx]
120
+
121
+ def get_song_id(h5,songidx=0):
122
+ """
123
+ Get song id from a HDF5 song file, by default the first song in it
124
+ """
125
+ return h5.root.metadata.songs.cols.song_id[songidx]
126
+
127
+ def get_song_hotttnesss(h5,songidx=0):
128
+ """
129
+ Get song hotttnesss from a HDF5 song file, by default the first song in it
130
+ """
131
+ return h5.root.metadata.songs.cols.song_hotttnesss[songidx]
132
+
133
+ def get_title(h5,songidx=0):
134
+ """
135
+ Get title from a HDF5 song file, by default the first song in it
136
+ """
137
+ return h5.root.metadata.songs.cols.title[songidx]
138
+
139
+ def get_track_7digitalid(h5,songidx=0):
140
+ """
141
+ Get track 7digital id from a HDF5 song file, by default the first song in it
142
+ """
143
+ return h5.root.metadata.songs.cols.track_7digitalid[songidx]
144
+
145
+ def get_similar_artists(h5,songidx=0):
146
+ """
147
+ Get similar artists array. Takes care of the proper indexing if we are in aggregate
148
+ file. By default, return the array for the first song in the h5 file.
149
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
150
+ """
151
+ if h5.root.metadata.songs.nrows == songidx + 1:
152
+ return h5.root.metadata.similar_artists[h5.root.metadata.songs.cols.idx_similar_artists[songidx]:]
153
+ return h5.root.metadata.similar_artists[h5.root.metadata.songs.cols.idx_similar_artists[songidx]:
154
+ h5.root.metadata.songs.cols.idx_similar_artists[songidx+1]]
155
+
156
+ def get_artist_terms(h5,songidx=0):
157
+ """
158
+ Get artist terms array. Takes care of the proper indexing if we are in aggregate
159
+ file. By default, return the array for the first song in the h5 file.
160
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
161
+ """
162
+ if h5.root.metadata.songs.nrows == songidx + 1:
163
+ return h5.root.metadata.artist_terms[h5.root.metadata.songs.cols.idx_artist_terms[songidx]:]
164
+ return h5.root.metadata.artist_terms[h5.root.metadata.songs.cols.idx_artist_terms[songidx]:
165
+ h5.root.metadata.songs.cols.idx_artist_terms[songidx+1]]
166
+
167
+ def get_artist_terms_freq(h5,songidx=0):
168
+ """
169
+ Get artist terms array frequencies. Takes care of the proper indexing if we are in aggregate
170
+ file. By default, return the array for the first song in the h5 file.
171
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
172
+ """
173
+ if h5.root.metadata.songs.nrows == songidx + 1:
174
+ return h5.root.metadata.artist_terms_freq[h5.root.metadata.songs.cols.idx_artist_terms[songidx]:]
175
+ return h5.root.metadata.artist_terms_freq[h5.root.metadata.songs.cols.idx_artist_terms[songidx]:
176
+ h5.root.metadata.songs.cols.idx_artist_terms[songidx+1]]
177
+
178
+ def get_artist_terms_weight(h5,songidx=0):
179
+ """
180
+ Get artist terms array weights. Takes care of the proper indexing if we are in aggregate
181
+ file. By default, return the array for the first song in the h5 file.
182
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
183
+ """
184
+ if h5.root.metadata.songs.nrows == songidx + 1:
185
+ return h5.root.metadata.artist_terms_weight[h5.root.metadata.songs.cols.idx_artist_terms[songidx]:]
186
+ return h5.root.metadata.artist_terms_weight[h5.root.metadata.songs.cols.idx_artist_terms[songidx]:
187
+ h5.root.metadata.songs.cols.idx_artist_terms[songidx+1]]
188
+
189
+ def get_analysis_sample_rate(h5,songidx=0):
190
+ """
191
+ Get analysis sample rate from a HDF5 song file, by default the first song in it
192
+ """
193
+ return h5.root.analysis.songs.cols.analysis_sample_rate[songidx]
194
+
195
+ def get_audio_md5(h5,songidx=0):
196
+ """
197
+ Get audio MD5 from a HDF5 song file, by default the first song in it
198
+ """
199
+ return h5.root.analysis.songs.cols.audio_md5[songidx]
200
+
201
+ def get_danceability(h5,songidx=0):
202
+ """
203
+ Get danceability from a HDF5 song file, by default the first song in it
204
+ """
205
+ return h5.root.analysis.songs.cols.danceability[songidx]
206
+
207
+ def get_duration(h5,songidx=0):
208
+ """
209
+ Get duration from a HDF5 song file, by default the first song in it
210
+ """
211
+ return h5.root.analysis.songs.cols.duration[songidx]
212
+
213
+ def get_end_of_fade_in(h5,songidx=0):
214
+ """
215
+ Get end of fade in from a HDF5 song file, by default the first song in it
216
+ """
217
+ return h5.root.analysis.songs.cols.end_of_fade_in[songidx]
218
+
219
+ def get_energy(h5,songidx=0):
220
+ """
221
+ Get energy from a HDF5 song file, by default the first song in it
222
+ """
223
+ return h5.root.analysis.songs.cols.energy[songidx]
224
+
225
+ def get_key(h5,songidx=0):
226
+ """
227
+ Get key from a HDF5 song file, by default the first song in it
228
+ """
229
+ return h5.root.analysis.songs.cols.key[songidx]
230
+
231
+ def get_key_confidence(h5,songidx=0):
232
+ """
233
+ Get key confidence from a HDF5 song file, by default the first song in it
234
+ """
235
+ return h5.root.analysis.songs.cols.key_confidence[songidx]
236
+
237
+ def get_loudness(h5,songidx=0):
238
+ """
239
+ Get loudness from a HDF5 song file, by default the first song in it
240
+ """
241
+ return h5.root.analysis.songs.cols.loudness[songidx]
242
+
243
+ def get_mode(h5,songidx=0):
244
+ """
245
+ Get mode from a HDF5 song file, by default the first song in it
246
+ """
247
+ return h5.root.analysis.songs.cols.mode[songidx]
248
+
249
+ def get_mode_confidence(h5,songidx=0):
250
+ """
251
+ Get mode confidence from a HDF5 song file, by default the first song in it
252
+ """
253
+ return h5.root.analysis.songs.cols.mode_confidence[songidx]
254
+
255
+ def get_start_of_fade_out(h5,songidx=0):
256
+ """
257
+ Get start of fade out from a HDF5 song file, by default the first song in it
258
+ """
259
+ return h5.root.analysis.songs.cols.start_of_fade_out[songidx]
260
+
261
+ def get_tempo(h5,songidx=0):
262
+ """
263
+ Get tempo from a HDF5 song file, by default the first song in it
264
+ """
265
+ return h5.root.analysis.songs.cols.tempo[songidx]
266
+
267
+ def get_time_signature(h5,songidx=0):
268
+ """
269
+ Get signature from a HDF5 song file, by default the first song in it
270
+ """
271
+ return h5.root.analysis.songs.cols.time_signature[songidx]
272
+
273
+ def get_time_signature_confidence(h5,songidx=0):
274
+ """
275
+ Get signature confidence from a HDF5 song file, by default the first song in it
276
+ """
277
+ return h5.root.analysis.songs.cols.time_signature_confidence[songidx]
278
+
279
+ def get_track_id(h5,songidx=0):
280
+ """
281
+ Get track id from a HDF5 song file, by default the first song in it
282
+ """
283
+ return h5.root.analysis.songs.cols.track_id[songidx]
284
+
285
+ def get_segments_start(h5,songidx=0):
286
+ """
287
+ Get segments start array. Takes care of the proper indexing if we are in aggregate
288
+ file. By default, return the array for the first song in the h5 file.
289
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
290
+ """
291
+ if h5.root.analysis.songs.nrows == songidx + 1:
292
+ return h5.root.analysis.segments_start[h5.root.analysis.songs.cols.idx_segments_start[songidx]:]
293
+ return h5.root.analysis.segments_start[h5.root.analysis.songs.cols.idx_segments_start[songidx]:
294
+ h5.root.analysis.songs.cols.idx_segments_start[songidx+1]]
295
+
296
+ def get_segments_confidence(h5,songidx=0):
297
+ """
298
+ Get segments confidence array. Takes care of the proper indexing if we are in aggregate
299
+ file. By default, return the array for the first song in the h5 file.
300
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
301
+ """
302
+ if h5.root.analysis.songs.nrows == songidx + 1:
303
+ return h5.root.analysis.segments_confidence[h5.root.analysis.songs.cols.idx_segments_confidence[songidx]:]
304
+ return h5.root.analysis.segments_confidence[h5.root.analysis.songs.cols.idx_segments_confidence[songidx]:
305
+ h5.root.analysis.songs.cols.idx_segments_confidence[songidx+1]]
306
+
307
+ def get_segments_pitches(h5,songidx=0):
308
+ """
309
+ Get segments pitches array. Takes care of the proper indexing if we are in aggregate
310
+ file. By default, return the array for the first song in the h5 file.
311
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
312
+ """
313
+ if h5.root.analysis.songs.nrows == songidx + 1:
314
+ return h5.root.analysis.segments_pitches[h5.root.analysis.songs.cols.idx_segments_pitches[songidx]:,:]
315
+ return h5.root.analysis.segments_pitches[h5.root.analysis.songs.cols.idx_segments_pitches[songidx]:
316
+ h5.root.analysis.songs.cols.idx_segments_pitches[songidx+1],:]
317
+
318
+ def get_segments_timbre(h5,songidx=0):
319
+ """
320
+ Get segments timbre array. Takes care of the proper indexing if we are in aggregate
321
+ file. By default, return the array for the first song in the h5 file.
322
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
323
+ """
324
+ if h5.root.analysis.songs.nrows == songidx + 1:
325
+ return h5.root.analysis.segments_timbre[h5.root.analysis.songs.cols.idx_segments_timbre[songidx]:,:]
326
+ return h5.root.analysis.segments_timbre[h5.root.analysis.songs.cols.idx_segments_timbre[songidx]:
327
+ h5.root.analysis.songs.cols.idx_segments_timbre[songidx+1],:]
328
+
329
+ def get_segments_loudness_max(h5,songidx=0):
330
+ """
331
+ Get segments loudness max array. Takes care of the proper indexing if we are in aggregate
332
+ file. By default, return the array for the first song in the h5 file.
333
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
334
+ """
335
+ if h5.root.analysis.songs.nrows == songidx + 1:
336
+ return h5.root.analysis.segments_loudness_max[h5.root.analysis.songs.cols.idx_segments_loudness_max[songidx]:]
337
+ return h5.root.analysis.segments_loudness_max[h5.root.analysis.songs.cols.idx_segments_loudness_max[songidx]:
338
+ h5.root.analysis.songs.cols.idx_segments_loudness_max[songidx+1]]
339
+
340
+ def get_segments_loudness_max_time(h5,songidx=0):
341
+ """
342
+ Get segments loudness max time array. Takes care of the proper indexing if we are in aggregate
343
+ file. By default, return the array for the first song in the h5 file.
344
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
345
+ """
346
+ if h5.root.analysis.songs.nrows == songidx + 1:
347
+ return h5.root.analysis.segments_loudness_max_time[h5.root.analysis.songs.cols.idx_segments_loudness_max_time[songidx]:]
348
+ return h5.root.analysis.segments_loudness_max_time[h5.root.analysis.songs.cols.idx_segments_loudness_max_time[songidx]:
349
+ h5.root.analysis.songs.cols.idx_segments_loudness_max_time[songidx+1]]
350
+
351
+ def get_segments_loudness_start(h5,songidx=0):
352
+ """
353
+ Get segments loudness start array. Takes care of the proper indexing if we are in aggregate
354
+ file. By default, return the array for the first song in the h5 file.
355
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
356
+ """
357
+ if h5.root.analysis.songs.nrows == songidx + 1:
358
+ return h5.root.analysis.segments_loudness_start[h5.root.analysis.songs.cols.idx_segments_loudness_start[songidx]:]
359
+ return h5.root.analysis.segments_loudness_start[h5.root.analysis.songs.cols.idx_segments_loudness_start[songidx]:
360
+ h5.root.analysis.songs.cols.idx_segments_loudness_start[songidx+1]]
361
+
362
+ def get_sections_start(h5,songidx=0):
363
+ """
364
+ Get sections start array. Takes care of the proper indexing if we are in aggregate
365
+ file. By default, return the array for the first song in the h5 file.
366
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
367
+ """
368
+ if h5.root.analysis.songs.nrows == songidx + 1:
369
+ return h5.root.analysis.sections_start[h5.root.analysis.songs.cols.idx_sections_start[songidx]:]
370
+ return h5.root.analysis.sections_start[h5.root.analysis.songs.cols.idx_sections_start[songidx]:
371
+ h5.root.analysis.songs.cols.idx_sections_start[songidx+1]]
372
+
373
+ def get_sections_confidence(h5,songidx=0):
374
+ """
375
+ Get sections confidence array. Takes care of the proper indexing if we are in aggregate
376
+ file. By default, return the array for the first song in the h5 file.
377
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
378
+ """
379
+ if h5.root.analysis.songs.nrows == songidx + 1:
380
+ return h5.root.analysis.sections_confidence[h5.root.analysis.songs.cols.idx_sections_confidence[songidx]:]
381
+ return h5.root.analysis.sections_confidence[h5.root.analysis.songs.cols.idx_sections_confidence[songidx]:
382
+ h5.root.analysis.songs.cols.idx_sections_confidence[songidx+1]]
383
+
384
+ def get_beats_start(h5,songidx=0):
385
+ """
386
+ Get beats start array. Takes care of the proper indexing if we are in aggregate
387
+ file. By default, return the array for the first song in the h5 file.
388
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
389
+ """
390
+ if h5.root.analysis.songs.nrows == songidx + 1:
391
+ return h5.root.analysis.beats_start[h5.root.analysis.songs.cols.idx_beats_start[songidx]:]
392
+ return h5.root.analysis.beats_start[h5.root.analysis.songs.cols.idx_beats_start[songidx]:
393
+ h5.root.analysis.songs.cols.idx_beats_start[songidx+1]]
394
+
395
+ def get_beats_confidence(h5,songidx=0):
396
+ """
397
+ Get beats confidence array. Takes care of the proper indexing if we are in aggregate
398
+ file. By default, return the array for the first song in the h5 file.
399
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
400
+ """
401
+ if h5.root.analysis.songs.nrows == songidx + 1:
402
+ return h5.root.analysis.beats_confidence[h5.root.analysis.songs.cols.idx_beats_confidence[songidx]:]
403
+ return h5.root.analysis.beats_confidence[h5.root.analysis.songs.cols.idx_beats_confidence[songidx]:
404
+ h5.root.analysis.songs.cols.idx_beats_confidence[songidx+1]]
405
+
406
+ def get_bars_start(h5,songidx=0):
407
+ """
408
+ Get bars start array. Takes care of the proper indexing if we are in aggregate
409
+ file. By default, return the array for the first song in the h5 file.
410
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
411
+ """
412
+ if h5.root.analysis.songs.nrows == songidx + 1:
413
+ return h5.root.analysis.bars_start[h5.root.analysis.songs.cols.idx_bars_start[songidx]:]
414
+ return h5.root.analysis.bars_start[h5.root.analysis.songs.cols.idx_bars_start[songidx]:
415
+ h5.root.analysis.songs.cols.idx_bars_start[songidx+1]]
416
+
417
+ def get_bars_confidence(h5,songidx=0):
418
+ """
419
+ Get bars confidence array. Takes care of the proper indexing if we are in aggregate
420
+ file. By default, return the array for the first song in the h5 file.
421
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
422
+ """
423
+ if h5.root.analysis.songs.nrows == songidx + 1:
424
+ return h5.root.analysis.bars_confidence[h5.root.analysis.songs.cols.idx_bars_confidence[songidx]:]
425
+ return h5.root.analysis.bars_confidence[h5.root.analysis.songs.cols.idx_bars_confidence[songidx]:
426
+ h5.root.analysis.songs.cols.idx_bars_confidence[songidx+1]]
427
+
428
+ def get_tatums_start(h5,songidx=0):
429
+ """
430
+ Get tatums start array. Takes care of the proper indexing if we are in aggregate
431
+ file. By default, return the array for the first song in the h5 file.
432
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
433
+ """
434
+ if h5.root.analysis.songs.nrows == songidx + 1:
435
+ return h5.root.analysis.tatums_start[h5.root.analysis.songs.cols.idx_tatums_start[songidx]:]
436
+ return h5.root.analysis.tatums_start[h5.root.analysis.songs.cols.idx_tatums_start[songidx]:
437
+ h5.root.analysis.songs.cols.idx_tatums_start[songidx+1]]
438
+
439
+ def get_tatums_confidence(h5,songidx=0):
440
+ """
441
+ Get tatums confidence array. Takes care of the proper indexing if we are in aggregate
442
+ file. By default, return the array for the first song in the h5 file.
443
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
444
+ """
445
+ if h5.root.analysis.songs.nrows == songidx + 1:
446
+ return h5.root.analysis.tatums_confidence[h5.root.analysis.songs.cols.idx_tatums_confidence[songidx]:]
447
+ return h5.root.analysis.tatums_confidence[h5.root.analysis.songs.cols.idx_tatums_confidence[songidx]:
448
+ h5.root.analysis.songs.cols.idx_tatums_confidence[songidx+1]]
449
+
450
+ def get_artist_mbtags(h5,songidx=0):
451
+ """
452
+ Get artist musicbrainz tag array. Takes care of the proper indexing if we are in aggregate
453
+ file. By default, return the array for the first song in the h5 file.
454
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
455
+ """
456
+ if h5.root.musicbrainz.songs.nrows == songidx + 1:
457
+ return h5.root.musicbrainz.artist_mbtags[h5.root.musicbrainz.songs.cols.idx_artist_mbtags[songidx]:]
458
+ return h5.root.musicbrainz.artist_mbtags[h5.root.metadata.songs.cols.idx_artist_mbtags[songidx]:
459
+ h5.root.metadata.songs.cols.idx_artist_mbtags[songidx+1]]
460
+
461
+ def get_artist_mbtags_count(h5,songidx=0):
462
+ """
463
+ Get artist musicbrainz tag count array. Takes care of the proper indexing if we are in aggregate
464
+ file. By default, return the array for the first song in the h5 file.
465
+ To get a regular numpy ndarray, cast the result to: numpy.array( )
466
+ """
467
+ if h5.root.musicbrainz.songs.nrows == songidx + 1:
468
+ return h5.root.musicbrainz.artist_mbtags_count[h5.root.musicbrainz.songs.cols.idx_artist_mbtags[songidx]:]
469
+ return h5.root.musicbrainz.artist_mbtags_count[h5.root.metadata.songs.cols.idx_artist_mbtags[songidx]:
470
+ h5.root.metadata.songs.cols.idx_artist_mbtags[songidx+1]]
471
+
472
+ def get_year(h5,songidx=0):
473
+ """
474
+ Get release year from a HDF5 song file, by default the first song in it
475
+ """
476
+ return h5.root.musicbrainz.songs.cols.year[songidx]
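A minimal usage sketch for the getters above, assuming ```msd_summary_file.h5``` has been downloaded into ```data_files``` as the readme describes (```run.py```, the next file, iterates over the same summary file in this way):

```python
# Minimal sketch: read a few fields from the MSD summary file with hdf5_getters.
# Assumes msd_summary_file.h5 sits at ../../data_files, the path used by run.py.
import hdf5_getters

h5 = hdf5_getters.open_h5_file_read("../../data_files/msd_summary_file.h5")
try:
    n_songs = hdf5_getters.get_num_songs(h5)
    for i in range(min(5, n_songs)):                       # peek at the first rows
        track_id = hdf5_getters.get_track_id(h5, i).decode("utf-8")
        title = hdf5_getters.get_title(h5, i).decode("utf-8")
        tempo = hdf5_getters.get_tempo(h5, i)
        print(track_id, title, tempo)
finally:
    h5.close()
```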
midi_emotion/src/create_dataset/run.py ADDED
@@ -0,0 +1,476 @@
1
+ import json
2
+ import pretty_midi
3
+ import pypianoroll
4
+ import hdf5_getters
5
+ from tqdm import tqdm
6
+ import os
7
+ import concurrent.futures
8
+ import collections
9
+ import utils
10
+ from glob import glob
11
+ import pandas as pd
12
+ import csv
13
+ from copy import deepcopy
14
+
15
+ """
16
+ Written by Serkan Sulun
17
+
18
+ Creates labels for Lakh MIDI (or pianoroll) dataset.
19
+ Labels include low-level MIDI features such as tempo, note density and number of instruments.
20
+ They also include high-level features obtained from Spotify Developer API, such as valence, energy, etc.
21
+
22
+ See utils.py and fill in the variables client_id and client_secret.
23
+
24
+ When the user quota is exceeded, Spotify blocks access and the script gets stuck.
25
+ In that case, you may need to re-run the script some time later,
26
+ or use a different account with different client ID and secret.
27
+ """
28
+
29
+ def run_parallel(func, my_iter):
30
+ # Parallel processing visualized with tqdm
31
+ with concurrent.futures.ProcessPoolExecutor() as executor:
32
+ results = list(tqdm(executor.map(func, my_iter), total=len(my_iter)))
33
+ return results
34
+
35
+ write = False
36
+ redo = True
37
+
38
+ main_output_dir = "../../data_files/features"
39
+ os.makedirs(main_output_dir, exist_ok=True)
40
+
41
+ match_scores_path = "../../data_files/match_scores.json"
42
+ msd_summary_path = "../../data_files/msd_summary_file.h5"
43
+ echonest_folder_path = "../../data_files/millionsongdataset_echonest"
44
+
45
+ use_pianoroll_dataset = True
46
+ if use_pianoroll_dataset:
47
+ midi_dataset_path = "../../data_files/lpd_full/lpd/lpd_full"
48
+ extension = ".npz"
49
+ output_dir = os.path.join(main_output_dir, "pianoroll")
50
+ else:
51
+ midi_dataset_path = "lmd_full"
52
+ extension = ".mid"
53
+ output_dir = os.path.join(main_output_dir, "midi")
54
+ os.makedirs(output_dir, exist_ok=True)
55
+
56
+ ### PART I: Map track_ids (in midi dataset) to Spotify features
57
+
58
+ ### 1- Create mappings track_id (in midi dataset) -> metadata (in Echonest)
59
+
60
+ output_path = os.path.join(output_dir, "trackid_to_songid.json")
61
+
62
+ with open(match_scores_path, "r") as f:
63
+ match_scores = json.load(f)
64
+
65
+ track_ids = sorted(list(match_scores.keys()))
66
+
67
+ if os.path.exists(output_path) and not redo:
68
+ with open(output_path, "r") as f:
69
+ trackid_to_songid = json.load(f)
70
+ else:
71
+ h5_msd = hdf5_getters.open_h5_file_read(msd_summary_path)
72
+ n_msd = hdf5_getters.get_num_songs(h5_msd)
73
+
74
+ trackid_to_songid = {}
75
+ print("Adding metadata to each track in Lakh dataset")
76
+
77
+ for i in tqdm(range(n_msd)):
78
+ track_id = hdf5_getters.get_track_id(h5_msd, i).decode("utf-8")
79
+ if track_id in track_ids:
80
+ # get data from MSD
81
+ song_id = hdf5_getters.get_song_id(h5_msd, i).decode("utf-8")
82
+ artist = hdf5_getters.get_artist_name(h5_msd, i).decode("utf-8")
83
+ title = hdf5_getters.get_title(h5_msd, i).decode("utf-8")
84
+ release = hdf5_getters.get_release(h5_msd, i).decode("utf-8")
85
+ trackid_to_songid[track_id] = {"song_id": song_id,"title": title,
86
+ "artist": artist, "release": release}
87
+
88
+ # sort
89
+ trackid_to_songid = collections.OrderedDict(sorted(trackid_to_songid.items()))
90
+ if write:
91
+ with open(output_path, "w") as f:
92
+ json.dump(trackid_to_songid, f, indent=4)
93
+ print(f"Output saved to {output_path}")
94
+
95
+ ### 2- Create mappings metadata (in Echonest) -> Spotify IDs
96
+ output_path = os.path.join(output_dir, "songid_to_spotify.json")
97
+ if os.path.exists(output_path) and not redo:
98
+ with open(output_path, "r") as f:
99
+ songid_to_spotify = json.load(f)
100
+ else:
101
+ song_ids = sorted([val["song_id"] for val in trackid_to_songid.values()])
102
+ songid_to_spotify = {}
103
+ print("Mapping Echonest song IDs to Spotify song IDs")
104
+ for song_id in tqdm(song_ids):
105
+ file_path = os.path.join(echonest_folder_path, song_id[2:4], song_id + ".json")
106
+ spotify_ids = utils.get_spotify_ids(file_path)
107
+ songid_to_spotify[song_id] = spotify_ids
108
+ if write:
109
+ with open(output_path, "w") as f:
110
+ json.dump(songid_to_spotify, f, indent=4)
111
+ print(f"Output saved to {output_path}")
112
+
113
+
114
+ ### 3- Merge and add Spotify features
115
+ output_path = os.path.join(output_dir, "trackid_to_spotify_features.json")
116
+ # When user quota is exceeded, Spotify blocks access and the script gets stuck.
117
+ # In that case, you may need to re-run the script some time later,
118
+ # or use a different account with different client ID and secret.
119
+ # So we keep an incomplete csv file, so that we can continue later from where we left off.
120
+ output_path_incomplete = os.path.join(output_dir, "incomplete_trackid_to_spotify_features.csv")
121
+
122
+ if os.path.exists(output_path) and not redo:
123
+ with open(output_path, "r") as f:
124
+ trackid_to_spotify_features = json.load(f)
125
+ else:
126
+ fieldnames = ["track_id", "song_id", "title", "artist", "release",
127
+ "spotify_id", "spotify_title", "spotify_artist", "spotify_album", "spotify_audio_features"]
128
+
129
+ data_to_process = deepcopy(trackid_to_songid)
130
+ write_header = True
131
+
132
+ if os.path.exists(output_path_incomplete):
133
+ # Continue from where we left off
134
+ data_already_processed = utils.read_csv(output_path_incomplete)
135
+ track_ids_already_processed = [entry["track_id"] for entry in data_already_processed]
136
+ data_to_process = {key: value for key, value in data_to_process.items() if key not in track_ids_already_processed}
137
+ write_header = False
138
+
139
+ with open(output_path_incomplete, "a") as f_out:
140
+ csv_writer = csv.DictWriter(f_out, fieldnames=fieldnames)
141
+ if write_header:
142
+ csv_writer.writeheader()
143
+
144
+ print("Adding Spotify features")
145
+ for track_id, data in tqdm(data_to_process.items()):
146
+ data["track_id"] = track_id
147
+ album = data["release"]
148
+ spotify_ids = songid_to_spotify[data["song_id"]]
149
+ if spotify_ids == []:
150
+ # use metadata to search spotify
151
+ best_spotify_track = utils.search_spotify_flexible(data["title"], data["artist"], data["release"])
152
+ else:
153
+ spotify_tracks = utils.get_spotify_tracks(spotify_ids)
154
+ if spotify_tracks is None:
155
+ for key in ["id", "title", "artist", "album", "audio_features"]:
156
+ data["spotify_" + key] = None
157
+ elif len(spotify_tracks) > 1:
158
+ # find best spotify id by comparing album names
159
+ best_match_score = 0
160
+ best_match_ind = 0
161
+ for i, track in enumerate(spotify_tracks):
162
+ if track is not None:
163
+ spotify_album = track["album"]["name"] if track is not None else ""
164
+ match_score = utils.matching_strings_flexible(album, spotify_album)
165
+
166
+ if match_score > best_match_score:
167
+ best_match_score = match_score
168
+ best_match_ind = i
169
+
170
+ best_spotify_track = spotify_tracks[best_match_ind]
171
+ else:
172
+ best_spotify_track = spotify_tracks[0]
173
+
174
+ if best_spotify_track is not None:
175
+ spotify_id = best_spotify_track["uri"].split(":")[-1]
176
+ spotify_audio_features = utils.get_spotify_features(spotify_id)[0]
177
+
178
+ # if spotify_audio_features["valence"] == 0.0:
179
+ # # A large portion of files have 0.0 valence, although they are NaNs
180
+ # spotify_audio_features["valence"] = float("nan")
181
+ spotify_artists = ", ".join([artist["name"] for artist in best_spotify_track["artists"]])
182
+
183
+ data["spotify_id"] = spotify_id
184
+ data["spotify_title"] = best_spotify_track['name']
185
+ data["spotify_artist"] = spotify_artists
186
+ data["spotify_album"] = best_spotify_track["album"]["name"]
187
+ data["spotify_audio_features"] = spotify_audio_features
188
+ else:
189
+ for key in ["id", "title", "artist", "album", "audio_features"]:
190
+ data["spotify_" + key] = None
191
+
192
+ csv_writer.writerow(data)
193
+
194
+ # Now write final data to json
195
+ trackid_to_spotify_features_list = utils.read_csv(output_path_incomplete)
196
+ trackid_to_spotify_features = {}
197
+ # unlike json, csv doesn't support a dict within a dict, so convert it back to a dict manually
198
+ for item in trackid_to_spotify_features_list:
199
+ spotify_audio_features = item["spotify_audio_features"]
200
+ if spotify_audio_features != "":
201
+ spotify_audio_features = eval(spotify_audio_features)
202
+ item["spotify_audio_features"] = spotify_audio_features
203
+ track_id = deepcopy(item["track_id"])
204
+ del item["track_id"]
205
+ trackid_to_spotify_features[track_id] = item
206
+
207
+ if write:
208
+ with open(output_path, "w") as f:
209
+ json.dump(trackid_to_spotify_features, f, indent=4)
210
+ print(f"Output saved to {output_path}")
211
+
212
+
213
+ ### PART II: Dealing with symbolic music data
214
+ ### 4- Revert matching scores
215
+ """ Matched data has the format: track_ID -> midi_file
216
+ where multiple tracks could be mapped to a single midi file.
217
+ We want to revert this mapping and then keep unique midi files
218
+ Revert match scores file to have mapping midi_file -> track_ID
219
+ """
220
+
221
+ output_path = os.path.join(output_dir, "match_scores_reverse.json")
222
+ if os.path.exists(output_path) and not redo:
223
+ with open(output_path, "r") as f:
224
+ match_scores_reversed = json.load(f)
225
+ else:
226
+ with open(match_scores_path, "r") as f:
227
+ in_data = json.load(f)
228
+ match_scores_reversed = {}
229
+ print("Reversing match scores.")
230
+ for track_id, matching in tqdm(in_data.items()):
231
+ for file_, score in matching.items():
232
+ if file_ not in match_scores_reversed.keys():
233
+ match_scores_reversed[file_] = {track_id: score}
234
+ else:
235
+ match_scores_reversed[file_][track_id] = score
236
+
237
+ # order match scores
238
+ for k in match_scores_reversed.keys():
239
+ match_scores_reversed[k] = collections.OrderedDict(sorted(match_scores_reversed[k].items(), reverse=True, key=lambda x: x[-1]))
240
+
241
+ # order filenames
242
+ match_scores_reversed = collections.OrderedDict(sorted(match_scores_reversed.items(), key=lambda x: x[0]))
243
+ if write:
244
+ with open(output_path, "w") as f:
245
+ json.dump(match_scores_reversed, f, indent=4)
246
+ print(f"Output saved to {output_path}")
247
+
248
+ # 5- Filter match scores to only keep best match
249
+ output_path = os.path.join(output_dir, "best_match_scores.json")
250
+ if os.path.exists(output_path) and not redo:
251
+ with open(output_path, "r") as f:
252
+ best_match_scores_reversed = json.load(f)
253
+ else:
254
+ best_match_scores_reversed = {}
255
+ print("Selecting best matching tracks.")
256
+ for midi_file, match in tqdm(match_scores_reversed.items()):
257
+ best_match_scores_reversed[midi_file] = list(match.items())[0]
258
+ if write:
259
+ with open(output_path, "w") as f:
260
+ json.dump(best_match_scores_reversed, f, indent=4)
261
+ print(f"Output saved to {output_path}")
262
+
263
+ ### 6- Filter unique midis
264
+ """LMD was created by creating hashes for the entire files
265
+ and then keeping files with unique hashes.
266
+ However, some files' musical content are the same, and only their metadata are different.
267
+ So we hash the content (pianoroll array), and further filter out the unique ones."""
268
+ # Create hashes for midis
269
+
270
+ output_path = os.path.join(output_dir, "hashes.json")
271
+
272
+ if os.path.exists(output_path) and not redo:
273
+ with open(output_path, "r") as f:
274
+ midi_file_to_hash = json.load(f)
275
+ else:
276
+ def get_hash_and_file(path):
277
+ hash_ = utils.get_hash(path)
278
+ file_ = os.path.basename(path)
279
+ file_ = file_[:-4]
280
+ return [file_, hash_]
281
+
282
+ file_paths = sorted(glob(midi_dataset_path + "/**/*" + extension, recursive=True))
283
+ assert len(file_paths) > 0, f"No MIDI files found at {midi_dataset_path}"
284
+ print("Getting hashes for MIDIs.")
285
+ midi_file_to_hash = run_parallel(get_hash_and_file, file_paths)
286
+ midi_file_to_hash = sorted(midi_file_to_hash, key=lambda x:x[0])
287
+ midi_file_to_hash = dict(midi_file_to_hash)
288
+ if write:
289
+ with open(output_path, "w") as f:
290
+ json.dump(midi_file_to_hash, f, indent=4)
291
+ print(f"Output saved to {output_path}")
292
+
293
+ # also do the reverse hash -> midi
294
+ output_path = os.path.join(output_dir, "unique_files.json")
295
+ if os.path.exists(output_path) and not redo:
296
+ with open(output_path, "r") as f:
297
+ midi_files_unique = json.load(f)
298
+ else:
299
+ hash_to_midi_file = {}
300
+ for midi_file, hash in midi_file_to_hash.items():
301
+ try:
302
+ best_match_score = best_match_scores_reversed[midi_file][1]
303
+ except:
304
+ best_match_score = 0
305
+ if hash in hash_to_midi_file.keys():
306
+ hash_to_midi_file[hash].append((midi_file, best_match_score))
307
+ else:
308
+ hash_to_midi_file[hash] = [(midi_file, best_match_score)]
309
+
310
+ midi_files_unique = []
311
+ # Get unique midis (with highest match score)
312
+ print("Getting unique MIDIs.")
313
+ for hash, midi_files_and_match_scores in hash_to_midi_file.items():
314
+ if hash != "empty_pianoroll":
315
+ midi_files_and_match_scores = sorted(midi_files_and_match_scores, key=lambda x: x[1], reverse=True)
316
+ midi_files_unique.append(midi_files_and_match_scores[0][0])
317
+ if write:
318
+ with open(output_path, "w") as f:
319
+ json.dump(midi_files_unique, f, indent=4)
320
+ print(f"Output saved to {output_path}")
321
+
322
+ # create unique matched midis list
323
+ midi_files_matched = list(match_scores_reversed.keys())
324
+
325
+ output_path = os.path.join(output_dir, "midis_matched_unique.json")
326
+ if os.path.exists(output_path) and not redo:
327
+ with open(output_path, "r") as f:
328
+ midi_files_matched_unique = json.load(f)
329
+ else:
330
+ print("Getting unique matched MIDIs.")
331
+ midi_files_matched_unique = sorted(list(set(midi_files_matched).intersection(midi_files_unique)))
332
+ if write:
333
+ with open(output_path, "w") as f:
334
+ json.dump(midi_files_matched_unique, f, indent=4)
335
+ print(f"Output saved to {output_path}")
336
+
337
+ # create unique unmatched midis list
338
+ output_path = os.path.join(output_dir, "midis_unmatched_unique.json")
339
+ if os.path.exists(output_path) and not redo:
340
+ with open(output_path, "r") as f:
341
+ midi_files_unmatched_unique = json.load(f)
342
+ else:
343
+ print("Getting unique unmatched MIDIs.")
344
+ midi_files_unmatched_unique = sorted(list(set(midi_files_unique) - set(midi_files_matched_unique)))
345
+ if write:
346
+ with open(output_path, "w") as f:
347
+ json.dump(midi_files_unmatched_unique, f, indent=4)
348
+ print(f"Output saved to {output_path}")
349
+
350
+ ### 7- Create mappings: midi -> best matching track ID, spotify features
351
+ output_path = os.path.join(output_dir, "spotify_features.json")
352
+ if os.path.exists(output_path) and not redo:
353
+ with open(output_path, "r") as f:
354
+ midi_file_to_spotify_features = json.load(f)
355
+ else:
356
+ midi_file_to_spotify_features = {}
357
+ print("Adding Spotify for matched unique MIDIs.")
358
+ for pr in tqdm(midi_files_matched_unique):
359
+ sample_data = {}
360
+ sample_data["track_id"], sample_data["match_score"] = best_match_scores_reversed[pr]
361
+ metadata_and_spotify = trackid_to_spotify_features[sample_data["track_id"]]
362
+ sample_data.update(metadata_and_spotify)
363
+ midi_file_to_spotify_features[pr] = sample_data
364
+ if write:
365
+ with open(output_path, "w") as f:
366
+ json.dump(midi_file_to_spotify_features, f, indent=4)
367
+ print(f"Output saved to {output_path}")
368
+
369
+ ### 8- For all midis, get low level features
370
+ # (tempo, note density, number of instruments)
371
+
372
+ output_path = os.path.join(output_dir, "midi_features.json")
373
+ if os.path.exists(output_path) and not redo:
374
+ with open(output_path, "r") as f:
375
+ midi_file_to_midi_features = json.load(f)
376
+ else:
377
+ def get_midi_features(midi_file):
378
+ midi_path = os.path.join(midi_dataset_path, midi_file[0], midi_file + extension)
379
+ if use_pianoroll_dataset:
380
+ mid = pypianoroll.load(midi_path).to_pretty_midi()
381
+ else:
382
+ mid = pretty_midi.PrettyMIDI(midi_path)
383
+ note_density = utils.get_note_density(mid)
384
+ tempo = utils.get_tempo(mid)
385
+ n_instruments = utils.get_n_instruments(mid)
386
+ duration = mid.get_end_time()
387
+ midi_features = {
388
+ "note_density": note_density,
389
+ "tempo": tempo,
390
+ "n_instruments": n_instruments,
391
+ "duration": duration,
392
+ }
393
+ return [midi_file, midi_features]
394
+ print("Getting low-level MIDI features")
395
+ midi_file_to_midi_features = run_parallel(get_midi_features, midi_files_unique)
396
+ midi_file_to_midi_features = dict(midi_file_to_midi_features)
397
+ if write:
398
+ with open(output_path, "w") as f:
399
+ json.dump(midi_file_to_midi_features, f, indent=4)
400
+ print(f"Output saved to {output_path}")
401
+
402
+ ### 9- Merge MIDI features and matched (Spotify) features
403
+ output_path = os.path.join(output_dir, "full_dataset_features.json")
404
+ if os.path.exists(output_path) and not redo:
405
+ with open(output_path, "r") as f:
406
+ midi_file_to_merged_features = json.load(f)
407
+ else:
408
+ midi_file_to_merged_features = {}
409
+ print("Merging MIDI features and Spotify features for full dataset.")
410
+ for midi_file in tqdm(midi_file_to_midi_features.keys()):
411
+ midi_file_to_merged_features[midi_file] = {}
412
+ midi_file_to_merged_features[midi_file]["midi_features"] = midi_file_to_midi_features[midi_file]
413
+ if midi_file in midi_file_to_spotify_features.keys():
414
+ matched_features = midi_file_to_spotify_features[midi_file]
415
+ else:
416
+ matched_features = {}
417
+ midi_file_to_merged_features[midi_file]["matched_features"] = matched_features
418
+ if write:
419
+ with open(output_path, "w") as f:
420
+ json.dump(midi_file_to_merged_features, f, indent=4)
421
+ print(f"Output saved to {output_path}")
422
+
423
+ ### Do the same for matched dataset
424
+ output_path = os.path.join(output_dir, "matched_dataset_features.json")
425
+ if os.path.exists(output_path) and not redo:
426
+ with open(output_path, "r") as f:
427
+ matched_midi_file_to_merged_features = json.load(f)
428
+ else:
429
+ print("Merging MIDI features and Spotify features for the matched dataset.")
430
+ matched_midi_file_to_merged_features = \
431
+ {file_: midi_file_to_merged_features[file_] for file_ in tqdm(midi_files_matched_unique)}
432
+ if write:
433
+ with open(output_path, "w") as f:
434
+ json.dump(matched_midi_file_to_merged_features, f, indent=4)
435
+ print(f"Output saved to {output_path}")
436
+
437
+ ### PART III: Constructing training dataset
438
+ ### 10- Summarize matched dataset features by only taking valence and note densities per instrument,
439
+ # number of instruments, durations, is_matched
440
+
441
+ output_path = os.path.join(output_dir, "full_dataset_features_summarized.csv")
442
+ if not os.path.exists(output_path) or redo:
443
+ print("Constructing training dataset (final file)")
444
+ dataset_summarized = []
445
+ for midi_file, features in tqdm(midi_file_to_merged_features.items()):
446
+ midi_features = features["midi_features"]
447
+ n_instruments = midi_features["n_instruments"]
448
+ note_density_per_instrument = midi_features["note_density"] / n_instruments
449
+ matched_features = features["matched_features"]
450
+ if matched_features == {}:
451
+ is_matched = False
452
+ valence = float("nan")
453
+ else:
454
+ is_matched = True
455
+ spotify_audio_features = matched_features["spotify_audio_features"]
456
+ if spotify_audio_features is None or spotify_audio_features == "":
457
+ valence = float("nan")
458
+ else:
459
+ if spotify_audio_features["valence"] == 0.0:
460
+ # An unusual number of samples have a valence of 0.0
461
+ # which is possibly due to an error. Feel free to comment out.
462
+ valence = float("nan")
463
+ else:
464
+ valence = spotify_audio_features["valence"]
465
+
466
+ dataset_summarized.append({
467
+ "file": midi_file,
468
+ "is_matched": is_matched,
469
+ "n_instruments": n_instruments,
470
+ "note_density_per_instrument": note_density_per_instrument,
471
+ "valence": valence
472
+ })
473
+ dataset_summarized = pd.DataFrame(dataset_summarized)
474
+ if write:
475
+ dataset_summarized.to_csv(output_path, index=False)
476
+ print(f"Output saved to {output_path}")
midi_emotion/src/create_dataset/utils.py ADDED
@@ -0,0 +1,216 @@
1
+ import spotipy
2
+ from spotipy.oauth2 import SpotifyClientCredentials
3
+ import re
4
+ import hashlib
5
+ import json
6
+ import pypianoroll
7
+ import numpy as np
8
+ import pretty_midi
9
+ import csv
10
+
11
+ """
12
+ You'll need a client ID and a client secret:
13
+ https://developer.spotify.com/dashboard/applications
14
+ Then, fill in the variables client_id and client_secret
15
+ """
16
+
17
+ client_id = 'c520641b167a4cd0872d48e5232a41e6'
18
+ client_secret = 'a455993eda164da2b67462c2e1382e91'
19
+ client_credentials_manager = SpotifyClientCredentials(client_id=client_id, client_secret=client_secret)
20
+ sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager)
21
+
22
+ def get_drums_note_density(mid):
23
+ drum_mid = pretty_midi.PrettyMIDI()
24
+ for instrument in mid.instruments:
25
+ if instrument.is_drum:
26
+ drum_mid.instruments.append(instrument)
27
+ if len(drum_mid.instruments) != 1 or len(drum_mid.instruments[0].notes) == 0:
28
+ return float("nan")
29
+ else:
30
+ start_time = drum_mid.instruments[0].notes[0].start
31
+ end_time = drum_mid.instruments[0].notes[-1].end
32
+ duration = end_time - start_time
33
+ n_notes = len(drum_mid.instruments[0].notes)
34
+ density = n_notes / duration
35
+ return density
36
+
37
+ def get_md5(path):
38
+ with open(path, "rb") as f:
39
+ md5 = hashlib.md5(f.read()).hexdigest()
40
+ return md5
41
+
42
+ def get_hash(path):
43
+ if path[-4:] == ".mid":
44
+ try:
45
+ mid = pretty_midi.PrettyMIDI(path)
46
+ except:
47
+ return "empty_pianoroll"
48
+ try:
49
+ rolls = mid.get_piano_roll()
50
+ except:
51
+ return "empty_pianoroll"
52
+ if rolls.size == 0:
53
+ return "empty_pianoroll"
54
+ else:
55
+ pr = pypianoroll.load(path)
56
+ tracks = sorted(pr.tracks, key=lambda x: x.name)
57
+ rolls = [track.pianoroll for track in tracks if track.pianoroll.shape[0] > 0]
58
+ if rolls == []:
59
+ return "empty_pianoroll"
60
+ rolls = np.concatenate(rolls, axis=-1)
61
+ hash_ = hashlib.sha1(np.ascontiguousarray(rolls)).hexdigest()
62
+ return hash_
63
+
64
+ def get_note_density(mid):
65
+ duration = mid.get_end_time()
66
+ n_notes = sum([1 for instrument in mid.instruments for note in instrument.notes])
67
+ density = n_notes / duration
68
+ return density
69
+
70
+ def get_tempo(mid):
71
+ tick_scale = mid._tick_scales[-1][-1]
72
+ resolution = mid.resolution
73
+ beat_duration = tick_scale * resolution
74
+ mid_tempo = 60 / beat_duration
75
+ return mid_tempo
76
+
77
+ def get_n_instruments(mid):
78
+ n_instruments = sum([1 for instrument in mid.instruments if instrument.notes != []])
79
+ return n_instruments
80
+
81
+ def try_multiple(func, *args, **kwargs):
82
+ n_max = 29
83
+ n = 0
84
+ failed = True
85
+ while failed:
86
+ if n > n_max:
87
+ return None
88
+ try:
89
+ if args:
90
+ out = func(*args)
91
+ elif kwargs:
92
+ out = func(**kwargs)
93
+ failed = False
94
+ except Exception as e:
95
+ # print(e.error_description)
96
+ if e.args[0] == 404:
97
+ return None
98
+ else:
99
+ n += 1
100
+ return out
101
+
102
+ def search_spotify(title, artist, album=None):
103
+ query = '"{}"+artist:"{}"'.format(title, artist)
104
+ if album is not None:
105
+ query += '+album:"{}"'.format(album)
106
+ if len(query) <= 250:
107
+ result = try_multiple(sp.search, q=query, type='track')
108
+ items = result['tracks']['items']
109
+ else: # Spotify doesn't search with a query longer than 250 characters
110
+ items = []
111
+ return items
112
+
113
+
114
+ def search_spotify_flexible(title, artist, album):
115
+ # Find Spotify URI based on metadata
116
+ items = search_spotify(title, artist, album)
117
+ if items == []:
118
+ items = search_spotify(title, artist)
119
+ if items == []:
120
+ title = fix_string(title)
121
+ items = search_spotify(title, artist)
122
+ if items == []:
123
+ artist = fix_string(artist)
124
+ items = search_spotify(title, artist)
125
+ if items == []:
126
+ artist = strip_artist(artist)
127
+ items = search_spotify(title, artist)
128
+ if items == []:
129
+ return None
130
+
131
+ elif len(items) == 1:
132
+ item = items[0]
133
+ else:
134
+ # Return most popular
135
+ max_popularity = 0
136
+ best_ind = 0
137
+ for i, item in enumerate(items):
138
+ if item is not None:
139
+ if item["popularity"] > max_popularity:
140
+ max_popularity = item["popularity"]
141
+ best_ind = i
142
+ item = items[best_ind]
143
+ return item
144
+
145
+ def matching_strings_flexible(a, b):
146
+ if a == "" or b == "":
147
+ matches = 0.0
148
+ else:
149
+ a = fix_string(a)
150
+ b = fix_string(b)
151
+ a = a.replace("'", "")
152
+ b = b.replace("'", "")
153
+ min_len = min(len(a), len(b))
154
+ matches = 0
155
+ for i in range(min_len):
156
+ if a[i] == b[i]:
157
+ matches += 1
158
+ matches /= min_len
159
+ return matches
160
+
161
+ def get_spotify_features(uri_list):
162
+ features = try_multiple(sp.audio_features, uri_list)
163
+ return features
164
+
165
+ def get_spotify_tracks(uri_list):
166
+ if len(uri_list) > 50:
167
+ uri_list = uri_list[:50]
168
+ tracks = try_multiple(sp.tracks, uri_list)
169
+ if tracks == None:
170
+ return None
171
+ else:
172
+ return tracks["tracks"]
173
+
174
+
175
+ def strip_artist(s):
176
+ s = s.lower() # lowercase
177
+ s = s.replace("the ", "")
178
+ keys = [' - ', '/', ' ft', 'feat', 'featuring', ' and ', ' with ', '_', ' vs', '&', ';', '+']
179
+ for key in keys:
180
+ loc = s.find(key)
181
+ if loc != -1:
182
+ s = s[:loc]
183
+ return s
184
+
185
+ def fix_string(s):
186
+ if s != "":
187
+ s = s.lower() # lowercase
188
+ s = s.replace('\'s', '') # remove 's
189
+ s = s.replace('_', ' ') # remove _
190
+ s = re.sub("[\(\[].*?[\)\]]", "", s) # remove everything in parantheses
191
+ if s[-1] == " ": # remove space at the end
192
+ s = s[:-1]
193
+ return s
194
+
195
+ def logprint(s, f):
196
+ f.write(s + '\n')
197
+
198
+ def get_spotify_ids(json_path):
199
+ with open(json_path) as f_json:
200
+ json_data = json.load(f_json)
201
+ json_data = json_data["response"]["songs"]
202
+ if len(json_data) == 0:
203
+ spotify_ids = []
204
+ else:
205
+ json_data = json_data[0]
206
+ spotify_ids = []
207
+ for track in json_data["tracks"]:
208
+ if track["catalog"] == "spotify" and "foreign_id" in list(track.keys()):
209
+ spotify_ids.append(track["foreign_id"].split(":")[-1])
210
+ return spotify_ids
211
+
212
+ def read_csv(input_file_path, delimiter=","):
213
+ with open(input_file_path, "r") as f_in:
214
+ reader = csv.DictReader(f_in, delimiter=delimiter)
215
+ data = [{key: value for key, value in row.items()} for row in reader]
216
+ return data
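The Track-to-MIDI matching above relies on fix_string normalization plus the character-level agreement score in matching_strings_flexible. Below is a self-contained sketch of that scoring, reimplemented here so it runs without Spotify credentials; the helper names only mirror the ones above.

import re

def fix_string(s):
    # same normalization idea as utils.fix_string: lowercase, drop 's, _, parenthesized text
    if s != "":
        s = s.lower().replace("'s", "").replace("_", " ")
        s = re.sub(r"[\(\[].*?[\)\]]", "", s)
        s = s.rstrip(" ")
    return s

def match_score(a, b):
    # character-by-character agreement over the shorter string,
    # as computed by utils.matching_strings_flexible
    if a == "" or b == "":
        return 0.0
    a, b = fix_string(a).replace("'", ""), fix_string(b).replace("'", "")
    n = min(len(a), len(b))
    return sum(a[i] == b[i] for i in range(n)) / n

print(match_score("Bohemian Rhapsody (Remastered 2011)", "Bohemian Rhapsody"))  # 1.0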
midi_emotion/src/data/collate.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ import re
3
+ # from torch._six import container_abcs, string_classes, int_classes
4
+ from torch._six import string_classes
5
+ import collections
6
+ """
7
+ Modified by Serkan Sulun
8
+ Filters out None samples
9
+ """
10
+
11
+ """"Contains definitions of the methods used by the _DataLoaderIter workers to
12
+ collate samples fetched from dataset into Tensor(s).
13
+
14
+ These **need** to be in global scope since Py2 doesn't support serializing
15
+ static methods.
16
+ """
17
+
18
+ _use_shared_memory = False
19
+ r"""Whether to use shared memory in batch_collate"""
20
+
21
+ np_str_obj_array_pattern = re.compile(r'[SaUO]')
22
+
23
+ error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
24
+
25
+ numpy_type_map = {
26
+ 'float64': torch.DoubleTensor,
27
+ 'float32': torch.FloatTensor,
28
+ 'float16': torch.HalfTensor,
29
+ 'int64': torch.LongTensor,
30
+ 'int32': torch.IntTensor,
31
+ 'int16': torch.ShortTensor,
32
+ 'int8': torch.CharTensor,
33
+ 'uint8': torch.ByteTensor,
34
+ }
35
+
36
+
37
+ def filter_collate(batch):
38
+ r"""Puts each data field into a tensor with outer dimension batch size"""
39
+
40
+ if isinstance(batch, list) or isinstance(batch, tuple):
41
+ batch = [i for i in batch if i is not None] # filter out None s
42
+
43
+ if batch != []:
44
+ elem_type = type(batch[0])
45
+ if isinstance(batch[0], torch.Tensor):
46
+ out = None
47
+ if _use_shared_memory:
48
+ # If we're in a background process, concatenate directly into a
49
+ # shared memory tensor to avoid an extra copy
50
+ numel = sum([x.numel() for x in batch])
51
+ storage = batch[0].storage()._new_shared(numel)
52
+ out = batch[0].new(storage)
53
+ return torch.stack(batch, 0, out=out)
54
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
55
+ and elem_type.__name__ != 'string_':
56
+ elem = batch[0]
57
+ if elem_type.__name__ == 'ndarray':
58
+ # array of string classes and object
59
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
60
+ raise TypeError(error_msg_fmt.format(elem.dtype))
61
+
62
+ return filter_collate([torch.from_numpy(b) for b in batch])
63
+ if elem.shape == (): # scalars
64
+ py_type = float if elem.dtype.name.startswith('float') else int
65
+ return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
66
+ elif isinstance(batch[0], float):
67
+ return torch.tensor(batch, dtype=torch.float64)
68
+ elif isinstance(batch[0], int):
69
+ return torch.tensor(batch)
70
+ elif isinstance(batch[0], string_classes):
71
+ return batch
72
+ elif isinstance(batch[0], collections.abc.Mapping):
73
+ return {key: filter_collate([d[key] for d in batch]) for key in batch[0]}
74
+ elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
75
+ return type(batch[0])(*(filter_collate(samples) for samples in zip(*batch)))
76
+ elif isinstance(batch[0], collections.abc.Sequence):
77
+ transposed = zip(*batch)
78
+ return [filter_collate(samples) for samples in transposed]
79
+
80
+ raise TypeError((error_msg_fmt.format(type(batch[0]))))
81
+ else:
82
+ return batch
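filter_collate is meant to be passed as collate_fn so that samples rejected by the loaders (returned as None) are silently dropped from the batch. A hypothetical usage sketch, assuming midi_emotion/src is on sys.path and a torch version that still ships torch._six.string_classes (which this module imports):

import torch
from torch.utils.data import DataLoader, Dataset
from data.collate import filter_collate

class ToyDataset(Dataset):
    # returns None for odd indices to mimic samples the Loader rejects
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        if idx % 2 == 1:
            return None
        return torch.ones(3) * idx

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=filter_collate)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 3]) -- the two None samples were dropped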
midi_emotion/src/data/data_processing.py ADDED
@@ -0,0 +1,247 @@
1
+ import pypianoroll
2
+ from operator import attrgetter
3
+ import torch
4
+ from copy import deepcopy
5
+ import numpy as np
6
+
7
+ # Forward processing. (Midi to indices)
8
+
9
+ def read_pianoroll(fp, return_tempo=False):
10
+ # Reads pianoroll file and converts to PrettyMidi
11
+ pr = pypianoroll.load(fp)
12
+ mid = pr.to_pretty_midi()
13
+ if return_tempo:
14
+ tempo = np.mean(pr.tempo)
15
+ return mid, tempo
16
+ else:
17
+ return mid
18
+
19
+ def trim_midi(mid_orig, start, end, strict=True):
20
+ """Trims midi file
21
+
22
+ Args:
23
+ mid (PrettyMidi): input midi file
24
+ start (float): start time
25
+ end (float): end time
26
+ strict (bool, optional):
27
+ If false, includes notes that starts earlier than start time,
28
+ and ends later than start time. Or ends later than end time,
29
+ but starts earlier than end time. The start and end times
30
+ are readjusted so they fit into the given boundaries.
31
+ Defaults to True.
32
+
33
+ Returns:
34
+ (PrettyMidi): Trimmed output MIDI.
35
+ """
36
+ eps = 1e-3
37
+ mid = deepcopy(mid_orig)
38
+ for ins in mid.instruments:
39
+ if strict:
40
+ ins.notes = [note for note in ins.notes if note.start >= start and note.end <= end]
41
+ else:
42
+ ins.notes = [note for note in ins.notes \
43
+ if note.end > start + eps and note.start < end - eps]
44
+
45
+ for note in ins.notes:
46
+ if not strict:
47
+ # readjustment
48
+ note.start = max(start, note.start)
49
+ note.end = min(end, note.end)
50
+ # Make the excerpt start at time zero
51
+ note.start -= start
52
+ note.end -= start
53
+ # Filter out empty tracks
54
+ mid.instruments = [ins for ins in mid.instruments if ins.notes]
55
+ return mid
56
+
57
+
58
+ def mid_to_timed_tuples(music, event_sym2idx, min_pitch: int = 21, max_pitch: int = 108):
59
+ # for sorting (though not absolutely necessary)
60
+ on_off_priority = ["ON", "OFF"]
61
+ ins_priority = ["DRUMS", "BASS", "GUITAR", "PIANO", "STRINGS"]
62
+
63
+ on_off_priority = {val: i for i, val in enumerate(on_off_priority)}
64
+ ins_priority = {val: i for i, val in enumerate(ins_priority)}
65
+
66
+ # Add instrument info to notes
67
+ for i, track in enumerate(music.instruments):
68
+ for note in track.notes:
69
+ note.instrument = track.name
70
+
71
+ # Collect notes
72
+ notes = []
73
+ for track in music.instruments:
74
+ notes.extend(track.notes)
75
+
76
+ # Raise an error if no notes is found
77
+ if not notes:
78
+ raise RuntimeError("No notes found.")
79
+
80
+ # Sort the notes
81
+ notes.sort(key=attrgetter("start", "pitch", "duration", "velocity", "instrument"))
82
+
83
+ # Collect note-related events
84
+ note_events = []
85
+
86
+ for note in notes:
87
+ if note.pitch >= min_pitch and note.pitch <= max_pitch:
88
+
89
+ start = round(note.start, 6)
90
+ end = round(note.end, 6)
91
+
92
+ ins = note.instrument.upper()
93
+
94
+ note_events.append((start, on_off_priority["ON"],
95
+ ins_priority[ins], (event_sym2idx["_".join(["ON", ins])], note.pitch)))
96
+ note_events.append((end, on_off_priority["OFF"],
97
+ ins_priority[ins], (event_sym2idx["_".join(["OFF", ins])], note.pitch)))
98
+
99
+ # Sort events by time
100
+ note_events = sorted(note_events)
101
+ note_events = [(note[0], note[-1]) for note in note_events]
102
+ return note_events
103
+
104
+ def timed_tuples_to_tuples(note_events, event_sym2idx, max_timeshift: int = 1000,
105
+ timeshift_step: int = 8):
106
+
107
+ # Create a list for all events
108
+ events = []
109
+ # Initialize the time cursor
110
+ time_cursor = int(round(note_events[0][0] * 1000))
111
+ # Iterate over note events
112
+ for time, symbol in note_events:
113
+ time = int(round(time * 1000))
114
+ if time > time_cursor:
115
+ timeshift = time - time_cursor
116
+ # First split timeshifts longer than max
117
+ n_max = timeshift // max_timeshift
118
+ for _ in range(n_max):
119
+ events.append((event_sym2idx["TIMESHIFT"], max_timeshift))
120
+ # quantize and add remaining
121
+ rem = timeshift % max_timeshift
122
+ if rem > 0:
123
+ # do not round to zero
124
+ rem = int(timeshift_step * round(float(rem) / timeshift_step))
125
+ if rem == 0:
126
+ rem = timeshift_step # do not round to zero
127
+ events.append((event_sym2idx["TIMESHIFT"], rem))
128
+ time_cursor = time
129
+ if symbol[0] != "<": # if not special symbol
130
+ events.append(symbol)
131
+ return events
132
+
133
+
134
+ def list_to_tensor(list_, sym2idx):
135
+ indices = [sym2idx[sym] for sym in list_]
136
+ indices = torch.LongTensor(indices)
137
+ return indices
138
+
139
+
140
+ def mid_to_bars(mid, event_sym2idx):
141
+ """Takes MIDI, extracts bars
142
+ returns ndarray where each row is a token
143
+ each token has two elements,
144
+ first is an index of event, such as DRUMS_OFF, or TIMESHIFT
145
+ second is the value (pitch for note or time for timeshift)
146
+ """
147
+ try:
148
+ bar_times = [round(bar, 6) for bar in mid.get_downbeats()]
149
+ bar_times.append(bar_times[-1] + (bar_times[-1] - bar_times[-2])) # to end
150
+ bar_times.append(bar_times[-1] + (bar_times[-1] - bar_times[-2])) # to end
151
+
152
+ note_events = mid_to_timed_tuples(mid, event_sym2idx)
153
+ i_bar = -1
154
+ i_note = 0
155
+ bars = []
156
+ cur_bar_note_events = []
157
+
158
+ cur_bar_end = -float("inf")
159
+ while i_note < len(note_events):
160
+ time, note = note_events[i_note]
161
+ if time < cur_bar_end:
162
+ cur_bar_note_events.append((time, note))
163
+ i_note += 1
164
+ else:
165
+ cur_bar_note_events.append((cur_bar_end, "<BAR_END>"))
166
+ if len(cur_bar_note_events) > 2:
167
+ events = timed_tuples_to_tuples(cur_bar_note_events, event_sym2idx)
168
+ events = tuples_to_array(events)
169
+ bars.append(events)
170
+ i_bar += 1
171
+ cur_bar_start = bar_times[i_bar]
172
+ cur_bar_end = bar_times[i_bar+1]
173
+ cur_bar_note_events = [(cur_bar_start, "<BAR_START>")]
174
+ except:
175
+ bars = None
176
+ return bars
177
+
178
+ def tuples_to_array(x):
179
+ x = [list(el) for el in x]
180
+ x = np.asarray(x, dtype=np.int16)
181
+ return x
182
+
183
+ def get_maps(min_pitch=21,max_pitch=108,max_timeshift=1000,timeshift_step=8):
184
+ # Get mapping dictionary
185
+ instruments = ["DRUMS", "GUITAR", "BASS", "PIANO", "STRINGS"]
186
+ special_symbols = ["<PAD>", "<START>"]
187
+ on_offs = ["OFF", "ON"]
188
+
189
+ token_syms = deepcopy(special_symbols)
190
+ event_syms = []
191
+ transposable_event_syms = []
192
+
193
+ for ins in instruments:
194
+ for on_off in on_offs:
195
+ event_syms.append(f"{on_off}_{ins}")
196
+ if ins != "DRUMS":
197
+ transposable_event_syms.append(f"{on_off}_{ins}")
198
+ for pitch in range(min_pitch, max_pitch + 1):
199
+ token_syms.append((f"{on_off}_{ins}", pitch))
200
+
201
+ for timeshift in range(timeshift_step, max_timeshift + timeshift_step, timeshift_step):
202
+ token_syms.append(("TIMESHIFT", timeshift))
203
+ event_syms.append("TIMESHIFT")
204
+
205
+ map = {}
206
+
207
+ map["event2idx"] = {sym: idx for idx, sym in enumerate(event_syms)}
208
+ map["idx2event"] = {idx: sym for idx, sym in enumerate(event_syms)}
209
+
210
+ map["tuple2idx"] = {}
211
+ map["idx2tuple"] = {}
212
+ for idx, sym in enumerate(token_syms):
213
+ if isinstance(sym, tuple):
214
+ indexed_tuple = (map["event2idx"][sym[0]], sym[1])
215
+ else:
216
+ indexed_tuple = sym
217
+ map["tuple2idx"][indexed_tuple] = idx
218
+ map["idx2tuple"][idx] = indexed_tuple
219
+
220
+ transposable_event_inds = [map["event2idx"][sym] for sym in transposable_event_syms]
221
+ map["transposable_event_inds"] = transposable_event_inds
222
+ return map
223
+
224
+
225
+ def transpose(x, n, transposable_event_inds, min_pitch = 21, max_pitch = 108):
226
+ # Transpose melody
227
+ for i in range(x.size(0)):
228
+ if x[i, 0].item() in transposable_event_inds and \
229
+ x[i, 1].item() + n <= max_pitch and \
230
+ x[i, 1].item() + n >= min_pitch:
231
+ x[i, 1] += n
232
+ return x
233
+
234
+ def tuples_to_ind_tensor(x, tuple2idx):
235
+ # Tuples to indices
236
+ x = [tuple2idx[el] for el in x]
237
+ x = torch.tensor(x, dtype=torch.int16)
238
+ return x
239
+
240
+ def tensor_to_tuples(x):
241
+ x = [tuple(row.tolist()) for row in x]
242
+ return x
243
+
244
+ def tensor_to_ind_tensor(x, tuple2idx):
245
+ x = tensor_to_tuples(x)
246
+ x = tuples_to_ind_tensor(x, tuple2idx)
247
+ return x
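A hypothetical end-to-end sketch of the forward path above (get_maps → mid_to_bars → tensor_to_ind_tensor), assuming midi_emotion/src is on sys.path; the tiny in-memory MIDI is an illustration only.

import pretty_midi
import torch
from data.data_processing import get_maps, mid_to_bars, tensor_to_ind_tensor

maps = get_maps()

mid = pretty_midi.PrettyMIDI()
mid.time_signature_changes.append(pretty_midi.TimeSignature(4, 4, 0))
# track name must be one of drums/bass/guitar/piano/strings (looked up after upper-casing)
piano = pretty_midi.Instrument(program=0, name="piano")
piano.notes += [
    pretty_midi.Note(velocity=100, pitch=60, start=0.0, end=0.5),
    pretty_midi.Note(velocity=100, pitch=64, start=0.6, end=1.0),
    pretty_midi.Note(velocity=100, pitch=67, start=2.2, end=2.6),  # note in bar 2, so bar 1 gets flushed
]
mid.instruments.append(piano)

bars = mid_to_bars(mid, maps["event2idx"])                  # list of (n_tokens, 2) int16 arrays, one per bar
first_bar = torch.from_numpy(bars[0])                       # (event index, value) pairs
inds = tensor_to_ind_tensor(first_bar, maps["tuple2idx"])   # final integer token indices
print(len(bars), inds[:8])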
midi_emotion/src/data/data_processing_reverse.py ADDED
@@ -0,0 +1,81 @@
1
+ import pretty_midi
2
+ import csv
3
+
4
+ # For reverse processing (TOKENS TO MIDI)
5
+
6
+ def tensor_to_tuples(x):
7
+ x = x.tolist()
8
+ x = [tuple(el) for el in x]
9
+ return x
10
+
11
+
12
+ def tuples_to_mid(x, idx2event, verbose=False):
13
+ # Tuples to midi
14
+ instrument_to_program = {"DRUMS": (0, True), "PIANO": (0, False), "GUITAR": (24, False),
15
+ "BASS": (32, False), "STRINGS": (48, False)}
16
+ velocities = {
17
+ "BASS": 127,
18
+ "DRUMS": 120,
19
+ "GUITAR": 95,
20
+ "PIANO": 110,
21
+ "STRINGS": 85,
22
+ }
23
+
24
+ tracks = {}
25
+ for key, val in instrument_to_program.items():
26
+ track = pretty_midi.Instrument(program=val[0], is_drum=val[1], name=key.lower())
27
+ track.notes = []
28
+ tracks.update({key: track})
29
+
30
+ active_notes = {}
31
+
32
+ time_cursor = 0
33
+ for el in x:
34
+ if el[0] != "<": # if not special token
35
+ event = idx2event[el[0]]
36
+ if "TIMESHIFT" == event:
37
+ timeshift = float(el[1])
38
+ time_cursor += timeshift / 1000.0
39
+ else:
40
+ on_off, instrument = event.split("_")
41
+ pitch = int(el[1])
42
+ if on_off == "ON":
43
+ active_notes.update({(instrument, pitch): time_cursor})
44
+ elif (instrument, pitch) in active_notes:
45
+ start = active_notes[(instrument, pitch)]
46
+ end = time_cursor
47
+ tracks[instrument].notes.append(pretty_midi.Note(velocities[instrument], pitch, start, end))
48
+ elif verbose:
49
+ print("Ignoring {:>15s} {:4} because there was no previos ""ON"" event".format(event, pitch))
50
+
51
+ mid = pretty_midi.PrettyMIDI()
52
+ mid.instruments += tracks.values()
53
+ return mid
54
+
55
+
56
+ def ind_tensor_to_tuples(x, ind2tuple):
57
+ # Indices to tuples
58
+ x = [ind2tuple[el.item()] for el in x]
59
+ return x
60
+
61
+ def tuples_to_str(x, idx2event):
62
+ # Tuples to strings
63
+ str_list = []
64
+ for el in x:
65
+ if el[0] == "<": # special token
66
+ str_list.append(el)
67
+ else:
68
+ str_list.append(idx2event[el[0]] + "_" + str(el[1]))
69
+ return str_list
70
+
71
+ def ind_tensor_to_mid(x, idx2tuple, idx2event, verbose=False):
72
+ # Indices to midi
73
+ x = ind_tensor_to_tuples(x, idx2tuple)
74
+ x = tuples_to_mid(x, idx2event, verbose=verbose)
75
+ return x
76
+
77
+ def ind_tensor_to_str(x, idx2tuple, idx2event):
78
+ # Indices to string
79
+ x = ind_tensor_to_tuples(x, idx2tuple)
80
+ x = tuples_to_str(x, idx2event)
81
+ return x
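The reverse path can be sketched the same way: hand-written events are mapped to indices with the maps from get_maps, then decoded back to a PrettyMIDI with ind_tensor_to_mid (same import assumptions as the sketch above).

import torch
from data.data_processing import get_maps
from data.data_processing_reverse import ind_tensor_to_mid

maps = get_maps()
e2i, t2i = maps["event2idx"], maps["tuple2idx"]

events = [
    (e2i["ON_PIANO"], 60),     # note-on C4 on the piano track
    (e2i["TIMESHIFT"], 496),   # wait 496 ms (timeshifts are multiples of 8 ms, up to 1000 ms)
    (e2i["OFF_PIANO"], 60),    # note-off C4
]
inds = torch.LongTensor([t2i[ev] for ev in events])

mid = ind_tensor_to_mid(inds, maps["idx2tuple"], maps["idx2event"])
piano = [track for track in mid.instruments if track.name == "piano"][0]
print(piano.notes)  # one piano note of roughly half a second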
midi_emotion/src/data/loader.py ADDED
@@ -0,0 +1,206 @@
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from data.data_processing import transpose, tensor_to_ind_tensor
5
+ from data.data_processing_reverse import tuples_to_str
6
+ import sys
7
+ sys.path.append("..")
8
+ from utils import get_n_instruments
9
+ import os
10
+
11
+ """
12
+ Main data loader
13
+ """
14
+
15
+ class Loader:
16
+
17
+ def __init__(self, data_folder, data, input_len, conditioning, save_input_dir=None, pad=True,
18
+ use_start_token=True, use_end_token=False, max_transpose=3, n_try=5,
19
+ bar_start_prob=0.5, debug=False, overfit=False, regression=False,
20
+ max_samples=None, min_n_instruments=3, use_cls_token=True,
21
+ always_use_discrete_condition=False):
22
+
23
+ self.data_folder = data_folder
24
+ self.bar_start_prob = bar_start_prob
25
+ self.save_input_dir = save_input_dir
26
+ self.input_len = input_len
27
+ self.n_try = n_try # max number of trials to find suitable sample
28
+ self.min_n_instruments = min_n_instruments
29
+ self.overfit = overfit
30
+ self.one_sample = None
31
+ self.transpose_options = list(range(-max_transpose, max_transpose + 1))
32
+ self.conditioning = conditioning
33
+ self.regression = regression
34
+ self.use_cls_token = use_cls_token
35
+ self.pad = pad
36
+ self.always_use_discrete_condition = always_use_discrete_condition
37
+
38
+ self.pad_token = '<PAD>' if pad else None
39
+ self.start_token = '<START>' if use_start_token else None
40
+ self.end_token = '<END>' if use_end_token else None
41
+ self.cls_token = "<CLS>"
42
+
43
+ if debug or overfit:
44
+ data_folder = data_folder + "_debug"
45
+
46
+ self.data = data
47
+
48
+ data_files = os.listdir(self.data_folder)
49
+ self.data = [sample for sample in self.data if sample["file"] + '.pt' in data_files]
50
+
51
+ maps_file = os.path.join(os.path.abspath(data_folder + "/.."), "maps.pt")
52
+ self.maps = torch.load(maps_file)
53
+
54
+ extra_tokens = []
55
+ if self.conditioning == "continuous_token":
56
+ # two condition tokens will be concatenated later
57
+ self.input_len -= 2
58
+ elif self.conditioning == "discrete_token":
59
+ # add emotion tokens to mappings
60
+ for sample in self.data:
61
+ for label in ["valence", "arousal"]:
62
+ token = sample[label]
63
+ if token not in extra_tokens:
64
+ extra_tokens.append(token)
65
+ extra_tokens = sorted(extra_tokens)
66
+
67
+ if self.regression and self.use_cls_token:
68
+ extra_tokens.append(self.cls_token)
69
+
70
+ if extra_tokens != []:
71
+ # add to maps
72
+ maps_list = list(self.maps["idx2tuple"].values())
73
+ maps_list += extra_tokens
74
+ self.maps["idx2tuple"] = {i: val for i, val in enumerate(maps_list)}
75
+ self.maps["tuple2idx"] = {val: i for i, val in enumerate(maps_list)}
76
+
77
+ if max_samples is not None and not debug and not overfit:
78
+ self.data = self.data[:max_samples]
79
+
80
+ # roughly / 256, but *4 for flexibility. it is later cut anyway
81
+ self.n_bars = max(round(input_len / 256 * 4), 1)
82
+
83
+
84
+ def get_vocab_len(self):
85
+ return len(self.maps["tuple2idx"])
86
+
87
+ def get_maps(self):
88
+ return self.maps
89
+
90
+ def get_pad_idx(self):
91
+ return self.maps["tuple2idx"][self.pad_token]
92
+
93
+ def __len__(self):
94
+ return len(self.data)
95
+
96
+ def __getitem__(self, idx):
97
+
98
+ if not self.overfit or self.one_sample is None:
99
+ data_path = os.path.join(self.data_folder, self.data[idx]["file"] + ".pt")
100
+ item = torch.load(data_path)
101
+ all_bars = item["bars"]
102
+
103
+ n_instruments = 0
104
+ j = 0
105
+ while j < self.n_try and n_instruments < self.min_n_instruments:
106
+ # make sure to have n many instruments
107
+ # choose random bar
108
+ max_bar_start_idx = max(0, len(all_bars) - self.n_bars - 1)
109
+ bar_start_idx = random.randint(0, max_bar_start_idx)
110
+ bar_end_idx = min(len(all_bars), bar_start_idx + self.n_bars)
111
+ bars = all_bars[bar_start_idx:bar_end_idx]
112
+ # flatten
113
+ if bars != []:
114
+ bars = torch.cat(bars, dim=0)
115
+ symbols = tuples_to_str(bars.cpu().numpy(), self.maps["idx2event"])
116
+ n_instruments = get_n_instruments(symbols)
117
+ else:
118
+ n_instruments = 0
119
+
120
+ j += 1
121
+ if n_instruments < self.min_n_instruments:
122
+ return None, None, None
123
+
124
+ # transpose
125
+ if self.transpose_options != []:
126
+ n_transpose = random.choice(self.transpose_options)
127
+ bars = transpose(bars, n_transpose,
128
+ self.maps["transposable_event_inds"])
129
+
130
+ # convert to indices (final input)
131
+ bars = tensor_to_ind_tensor(bars, self.maps["tuple2idx"])
132
+
133
+ # Decide taking the sample from the start of a bar or not
134
+ r = np.random.uniform()
135
+
136
+ start_at_beginning = not (r > self.bar_start_prob and bars.size(0) > self.input_len)
137
+
138
+ if start_at_beginning:
139
+ # starts exactly at bar location
140
+ if self.start_token is not None:
141
+ # add start token
142
+ start_idx = torch.ShortTensor(
143
+ [self.maps["tuple2idx"][self.start_token]])
144
+ bars = torch.cat((start_idx, bars), dim=0)
145
+ else:
146
+ # it doesn't have to start at bar location so shift arbitrarily
147
+ start = np.random.randint(0, bars.size(0)-self.input_len)
148
+ bars = bars[start:start+self.input_len+1]
149
+
150
+ if self.regression and self.use_cls_token:
151
+ # prepend <CLS> token
152
+ cls_idx = torch.ShortTensor(
153
+ [self.maps["tuple2idx"][self.cls_token]])
154
+ bars = torch.cat((cls_idx, bars), 0)
155
+
156
+ # for now, no auxiliary conditions
157
+ condition = torch.FloatTensor([np.nan, np.nan])
158
+ if self.conditioning == "discrete_token" and \
159
+ (start_at_beginning or self.always_use_discrete_condition):
160
+ # add emotion tokens
161
+ valence, arousal = self.data[idx]["valence"], self.data[idx]["arousal"]
162
+ valence = torch.ShortTensor([self.maps["tuple2idx"][valence]])
163
+ arousal = torch.ShortTensor([self.maps["tuple2idx"][arousal]])
164
+ bars = torch.cat((valence, arousal, bars), dim=0)
165
+ elif self.conditioning in ("continuous_token", "continuous_concat") or self.regression:
166
+ # continuous conditions
167
+ condition = torch.FloatTensor([self.data[idx]["valence"], self.data[idx]["arousal"]])
168
+
169
+ bars = bars[:self.input_len + 1] # trim to length, +1 to include target
170
+
171
+ if self.pad_token is not None:
172
+ n_pad = self.input_len + 1 - bars.shape[0]
173
+ if n_pad > 0:
174
+ # pad if necessary
175
+ bars = torch.nn.functional.pad(bars, (0, n_pad), value=self.get_pad_idx())
176
+
177
+ bars = bars.long() # to int32
178
+ input_ = bars[:-1]
179
+
180
+ if self.regression:
181
+ target = None # will use condition as target
182
+ else:
183
+ target = bars[1:]
184
+ if self.conditioning == "continuous_token":
185
+ # pad target from left, because input will get conditions concatenated
186
+ # their sizes should match
187
+ target = torch.nn.functional.pad(target, (condition.size(0), 0), value=self.get_pad_idx())
188
+
189
+ if self.overfit:
190
+ self.one_sample = [input_, condition, target]
191
+ else:
192
+ # sanity check, using one sample repeatedly
193
+ input_, condition, target = self.one_sample
194
+
195
+ return input_, condition, target
196
+
197
+
198
+
199
+
200
+
201
+
202
+
203
+
204
+
205
+
206
+
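A hypothetical wiring sketch for training-time use of Loader; every path below is an assumption and presumes the preprocessing scripts in this commit have already produced the summarized feature CSV, the per-song .pt files, and maps.pt one level above them.

from torch.utils.data import DataLoader
from data.loader import Loader
from data.collate import filter_collate
from data.preprocess_features import preprocess_features

train_data, test_data = preprocess_features(
    "../data_files/features/pianoroll/full_dataset_features_summarized.csv",
    n_bins=None)  # None keeps continuous valence/arousal values

train_set = Loader(
    data_folder="../datasets/lpd_5/lpd_5_full_transposable",  # folder of per-song .pt files; maps.pt sits one level up
    data=train_data,
    input_len=1024,
    conditioning="continuous_concat")

train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                          collate_fn=filter_collate)
input_, condition, target = next(iter(train_loader))
# input_: token indices, condition: (valence, arousal) pairs, target: inputs shifted by one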
midi_emotion/src/data/loader_exhaustive.py ADDED
@@ -0,0 +1,173 @@
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+ from data.data_processing import tensor_to_ind_tensor
5
+ import sys
6
+ sys.path.append("..")
7
+
8
+ import os
9
+
10
+ """
11
+ Loads ALL data for exhaustive evaluation
12
+ """
13
+
14
+ class LoaderExhaustive:
15
+
16
+ def __init__(self, data_folder, data, input_len, conditioning, save_input_dir=None, pad=True,
17
+ use_start_token=True, use_end_token=False, always_use_discrete_condition=False,
18
+ debug=False, overfit=False, regression=False,
19
+ max_samples=None, use_cls_token=True):
20
+
21
+ self.data_folder = data_folder
22
+ self.save_input_dir = save_input_dir
23
+ self.input_len = input_len
24
+ self.overfit = overfit
25
+ self.one_sample = None
26
+ self.conditioning = conditioning
27
+ self.regression = regression
28
+
29
+
30
+ if debug or overfit:
31
+ data_folder = data_folder + "_debug"
32
+
33
+ self.data = data
34
+
35
+ maps_file = os.path.join(data_folder, "maps.pt")
36
+ self.maps = torch.load(maps_file)
37
+
38
+ self.pad_token = '<PAD>' if pad else None
39
+ self.start_token = '<START>' if use_start_token else None
40
+ self.end_token = '<END>' if use_end_token else None
41
+ self.cls_token = "<CLS>"
42
+
43
+
44
+ extra_tokens = []
45
+ if self.conditioning == "continuous_token":
46
+ # two condition tokens will be concatenated later
47
+ self.input_len -= 2
48
+ elif self.conditioning == "discrete_token":
49
+ # two condition tokens will be concatenated later
50
+ self.input_len -= 2
51
+ # add emotion tokens to mappings
52
+ for sample in self.data:
53
+ for label in ["valence", "arousal"]:
54
+ token = sample[label]
55
+ if token not in extra_tokens:
56
+ extra_tokens.append(token)
57
+ extra_tokens = sorted(extra_tokens)
58
+
59
+ if self.regression and use_cls_token:
60
+ extra_tokens.append(self.cls_token)
61
+ self.input_len -= 1 # cls token
62
+
63
+ if self.regression:
64
+ chunk_len = self.input_len
65
+ else:
66
+ # +1 for target
67
+ chunk_len = self.input_len + 1
68
+
69
+ if extra_tokens != []:
70
+ # add to maps
71
+ maps_list = list(self.maps["idx2tuple"].values())
72
+ maps_list += extra_tokens
73
+ self.maps["idx2tuple"] = {i: val for i, val in enumerate(maps_list)}
74
+ self.maps["tuple2idx"] = {val: i for i, val in enumerate(maps_list)}
75
+
76
+ if max_samples is not None and not debug and not overfit:
77
+ self.data = self.data[:max_samples]
78
+
79
+ # Chunk entire data
80
+ chunked_data = []
81
+ print('Constructing data loader...')
82
+ for i in tqdm(range(len(self.data))):
83
+
84
+ data_path = os.path.join(data_folder, "lpd_5_full_transposable", self.data[i]["file"] + ".pt")
85
+ item = torch.load(data_path)
86
+ song = item["bars"]
87
+
88
+ if self.conditioning != 'none' or self.regression:
89
+ valence = self.data[i]["valence"]
90
+ arousal = self.data[i]["arousal"]
91
+
92
+ if self.conditioning in ("continuous_token", "continuous_concat") or self.regression:
93
+ condition = torch.FloatTensor([valence, arousal])
94
+ else:
95
+ condition = torch.FloatTensor([np.nan, np.nan])
96
+
97
+ song = torch.cat(song, 0)
98
+ song = tensor_to_ind_tensor(song, self.maps["tuple2idx"])
99
+ if self.start_token is not None:
100
+ # add start token
101
+ start_idx = torch.ShortTensor(
102
+ [self.maps["tuple2idx"][self.start_token]])
103
+ song = torch.cat((start_idx, song), 0)
104
+
105
+ if self.conditioning == "discrete_token":
106
+ condition_tokens = torch.ShortTensor([
107
+ self.maps["tuple2idx"][valence],
108
+ self.maps["tuple2idx"][arousal]])
109
+ if not always_use_discrete_condition:
110
+ song = torch.cat((condition_tokens, song), 0)
111
+
112
+ # split song into chunks
113
+ song = list(torch.split(song, chunk_len)) # +1 for target
114
+ if song[-1].size(0) != chunk_len:
115
+ song.pop(-1)
116
+
117
+ if self.regression and use_cls_token:
118
+ # prepend <CLS> token
119
+ cls_idx = torch.ShortTensor(
120
+ [self.maps["tuple2idx"][self.cls_token]])
121
+
122
+ song = [torch.cat((cls_idx, x), 0) for x in song]
123
+
124
+ if self.conditioning == "discrete_token" and always_use_discrete_condition:
125
+ song = [torch.cat((condition_tokens, x), 0) for x in song]
126
+
127
+ song = [(x, condition) for x in song]
128
+
129
+ chunked_data += song
130
+
131
+ self.data = chunked_data
132
+ print('Data loader constructed.')
133
+
134
+ def get_vocab_len(self):
135
+ return len(self.maps["tuple2idx"])
136
+
137
+ def get_maps(self):
138
+ return self.maps
139
+
140
+ def get_pad_idx(self):
141
+ return self.maps["tuple2idx"][self.pad_token]
142
+
143
+ def __len__(self):
144
+ return len(self.data)
145
+
146
+ def __getitem__(self, idx):
147
+ chunk, condition = self.data[idx]
148
+ chunk = chunk.long()
149
+
150
+ if self.regression:
151
+ input_ = chunk
152
+ target = None # will use condition as target
153
+ else:
154
+ input_ = chunk[:-1]
155
+ target = chunk[1:]
156
+
157
+ if self.conditioning == "continuous_token":
158
+ # pad target from left, because input will get conditions concatenated
159
+ # their sizes should match
160
+ target = torch.nn.functional.pad(target, (condition.size(0), 0), value=self.get_pad_idx())
161
+
162
+ return input_, condition, target
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
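LoaderExhaustive is wired the same way for evaluation, with one difference worth noting: it expects the dataset root (it appends lpd_5_full_transposable itself and reads maps.pt from that root), whereas Loader takes the token folder directly. A sketch under the same path assumptions as above:

from torch.utils.data import DataLoader
from data.loader_exhaustive import LoaderExhaustive
from data.collate import filter_collate
from data.preprocess_features import preprocess_features

_, test_data = preprocess_features(
    "../data_files/features/pianoroll/full_dataset_features_summarized.csv",
    n_bins=None)

test_set = LoaderExhaustive(
    data_folder="../datasets/lpd_5",   # dataset root, not the token folder
    data=test_data,
    input_len=1024,
    conditioning="continuous_concat")

test_loader = DataLoader(test_set, batch_size=8, shuffle=False,
                         collate_fn=filter_collate)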
midi_emotion/src/data/loader_generations.py ADDED
@@ -0,0 +1,107 @@
1
+ from glob import glob
2
+ import os
3
4
+ import torch
5
+ import sys
6
+ sys.path.append("..")
7
+
8
+ """
9
+ Data loader to perform regression on a folder with generations
10
+ """
11
+
12
+ class LoaderGenerations:
13
+
14
+ def __init__(self, gen_folder, seq_len, pad=True, use_start_token=True, use_end_token=False,
15
+ use_cls_token=True, overlap=0.5):
16
+
17
+ self.seq_len = seq_len
18
+ self.one_sample = None
19
+
20
+ self.pad = pad
21
+
22
+ self.pad_token = '<PAD>' if pad else None
23
+ self.start_token = '<START>' if use_start_token else None
24
+ self.end_token = '<END>' if use_end_token else None
25
+ self.cls_token = "<CLS>" if use_cls_token else None
26
+
27
+ data_paths = glob(os.path.join("../output", gen_folder, "*.pt"), recursive=True)
28
+
29
+ maps = torch.load("../datasets/lpd_5/w_emotion_transposable/maps.pt")
30
+ n_vocab = len(maps["tuple2idx"])
31
+
32
+ self.data = []
33
+
34
+ if self.cls_token is not None:
35
+ seq_len -= 1
36
+ if self.cls_token not in maps["tuple2idx"].keys():
37
+ # add <CLS> token to vocab
38
+ maps["tuple2idx"][self.cls_token] = len(maps["idx2tuple"])
39
+ maps["idx2tuple"][len(maps["idx2tuple"])] = self.cls_token
40
+ # prepend <CLS> token
41
+ cls_idx = torch.ShortTensor(
42
+ [maps["tuple2idx"][self.cls_token]])
43
+
44
+ for data_path in data_paths:
45
+ generation = torch.load(data_path)
46
+ inds = generation["inds"]
47
+ # remove special tokens
48
+ inds = inds[inds < n_vocab]
49
+ # split with overlap
50
+ inds = inds.unfold(0, seq_len, int(seq_len*(1-overlap)))
51
+ inds = list(torch.split(inds, 1, dim=0))
52
+ inds = [sample.squeeze() for sample in inds]
53
+
54
+ if self.cls_token is not None:
55
+ inds = [torch.cat((cls_idx, sample), dim=0) for sample in inds]
56
+
57
+ condition = generation["condition"]
58
+ if inds[-1].size(0) != seq_len:
59
+ inds.pop()
60
+ self.data += [(sample, condition) for sample in inds]
61
+
62
+
63
+ self.discrete2continuous = {
64
+ "-2": -0.8,
65
+ "-1": -0.4,
66
+ "0": 0,
67
+ "1": 0.4,
68
+ "2": 0.8
69
+ }
70
+
71
+
72
+ def get_vocab_len(self):
73
+ return None
74
+
75
+ def get_maps(self):
76
+ return None
77
+
78
+ def get_pad_idx(self):
79
+ return None
80
+
81
+ def __len__(self):
82
+ return len(self.data)
83
+
84
+ def __getitem__(self, idx):
85
+
86
+ input_, condition = self.data[idx]
87
+ if input_.size(0) != self.seq_len:
88
+ Warning(f"Input length is {input_.size(0)}")
89
+ return None, None, None
90
+ if isinstance(condition[0], str):
91
+ condition = condition[:2]
92
+ for i in range(len(condition)):
93
+ condition[i] = self.discrete2continuous[condition[i][2:-1]]
94
+ condition = torch.Tensor(condition)
95
+
96
+ input_ = input_.cpu()
97
+ condition = condition.cpu()
98
+ return input_, condition, None
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
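The __getitem__ above recovers continuous targets from saved discrete condition tokens by slicing off the token decoration and looking the value up in discrete2continuous; a tiny sketch of that parsing:

# A token such as "<V-2>" or "<A1>" is sliced with [2:-1] and mapped back
# to the continuous valence/arousal scale.
discrete2continuous = {"-2": -0.8, "-1": -0.4, "0": 0, "1": 0.4, "2": 0.8}
condition = ["<V-2>", "<A1>"]
values = [discrete2continuous[token[2:-1]] for token in condition]
print(values)  # [-0.8, 0.4]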
midi_emotion/src/data/preprocess_features.py ADDED
@@ -0,0 +1,107 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+ def preprocess_features(feature_file, n_bins=None, min_n_instruments=3,
5
+ test_ratio=0.05, outlier_range=1.5, conditional=True,
6
+ use_labeled_only=True):
7
+
8
+ # Preprocess data
9
+ data = pd.read_csv(feature_file)
10
+ mapper = {"valence": "valence", "note_density_per_instrument": "arousal"}
11
+ data = data.rename(columns=mapper)
12
+ columns = data.columns.to_list()
13
+
14
+ # filter out ones with less instruments
15
+ data = data[data["n_instruments"] >= min_n_instruments]
16
+ # filter out ones with zero valence
17
+ data = data[data["valence"] != 0]
18
+
19
+ # filter out outliers
20
+ feature_labels = list(mapper.values())
21
+ outlier_indices = []
22
+ for label in feature_labels:
23
+ series = data[label]
24
+ q1 = series.quantile(0.25)
25
+ q3 = series.quantile(0.75)
26
+ iqr = q3 - q1
27
+ upper_limit = q3 + outlier_range * iqr
28
+ lower_limit = q1 - outlier_range * iqr
29
+
30
+ outlier_indices += series[series < lower_limit].index.to_list()
31
+ outlier_indices += series[series > upper_limit].index.to_list()
32
+ data.drop(outlier_indices, inplace=True)
33
+
34
+ # shift and scale features between -1 and 1
35
+ for label in feature_labels:
36
+ series = data[label]
37
+ min_ = series.min()
38
+ max_ = series.max()
39
+
40
+ data[label] = (data[label] - min_) / (max_ - min_) * 2 - 1
41
+
42
+ if n_bins is not None:
43
+ # digitize into bins using quantiles
44
+ quantile_indices = np.linspace(0, 1, n_bins+1)
45
+ for label in feature_labels:
46
+
47
+ # create token labels
48
+ if n_bins % 2 == 0:
49
+ bin_ids = list(range(-n_bins//2, 0)) + list(range(1, n_bins//2+1))
50
+ else:
51
+ bin_ids = list(range(-(n_bins-1)//2, (n_bins-1)//2 + 1))
52
+ token_labels = ["<{}{}>".format(label[0].upper(), bin_id) \
53
+ for bin_id in bin_ids]
54
+ # additional label for NaN (missing) values: <V>
55
+ token_labels.append(None) # to handle NaNs
56
+
57
+ series = data[label]
58
+ quantiles = [series.quantile(q) for q in quantile_indices]
59
+ quantiles[-1] += 1e-6
60
+ series = series.to_numpy()
61
+ series_digitized = np.digitize(series, quantiles)
62
+ series_tokenized = [token_labels[i-1] for i in series_digitized]
63
+
64
+ data[label] = series_tokenized
65
+ else:
66
+ # convert NaN into None
67
+ data = data.where(pd.notnull(data), None)
68
+
69
+ # Create train and test splits
70
+ matched = data[data["is_matched"]]
71
+ unmatched = data[~data["is_matched"]]
72
+
73
+ # reserve a portion of matched data for testing
74
+ matched = matched.sort_values("file")
75
+ matched = matched.reset_index(drop=True)
76
+ n_test_samples = round(len(matched) * test_ratio)
77
+
78
+ test_split = matched.loc[len(matched)-n_test_samples:len(matched)]
79
+
80
+ train_split = matched.loc[:len(matched)-n_test_samples]
81
+
82
+ if not use_labeled_only:
83
+ train_split = pd.concat([train_split, unmatched])
84
+ train_split = train_split.sort_values("file").reset_index(drop=True)
85
+
86
+ splits = [train_split, test_split]
87
+
88
+ # summarize
89
+ columns_to_drop = [col for col in columns if col not in ["file", "valence", "arousal"]]
90
+ if not conditional:
91
+ columns_to_drop += ["valence", "arousal"]
92
+
93
+ # filter data so all features are valid (not None = matched data)
94
+ for label in feature_labels:
95
+ # test split has to be identical across vanilla and conditional models
96
+ splits[1] = splits[1][~splits[1][label].isnull()]
97
+
98
+ # filter train split only for conditional models
99
+ if use_labeled_only:
100
+ splits[0] = splits[0][~splits[0][label].isnull()]
101
+
102
+ for i in range(len(splits)):
103
+ # summarize
104
+ splits[i] = splits[i].drop(columns=columns_to_drop, errors="ignore")
105
+ splits[i] = splits[i].to_dict("records")
106
+
107
+ return splits
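A toy illustration of the n_bins branch above: values are cut at their own quantiles and mapped to symmetric token labels such as <V-2> … <V2> (arousal gets <A…> labels the same way).

import numpy as np
import pandas as pd

n_bins = 5
values = pd.Series([-0.9, -0.3, 0.0, 0.2, 0.8])
bin_ids = list(range(-(n_bins - 1) // 2, (n_bins - 1) // 2 + 1))   # [-2, -1, 0, 1, 2]
token_labels = [f"<V{b}>" for b in bin_ids]

quantiles = [values.quantile(q) for q in np.linspace(0, 1, n_bins + 1)]
quantiles[-1] += 1e-6                                              # keep the max inside the last bin
digitized = np.digitize(values.to_numpy(), quantiles)              # bin indices 1..n_bins
print([token_labels[i - 1] for i in digitized])
# ['<V-2>', '<V-1>', '<V0>', '<V1>', '<V2>']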
midi_emotion/src/data/preprocess_pianorolls.py ADDED
@@ -0,0 +1,82 @@
1
+ import json
2
+ from data_processing import read_pianoroll, mid_to_bars, get_maps
3
+ import torch
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ from concurrent.futures import ProcessPoolExecutor
7
+ import time
8
+ from functools import partial
9
+ import os
10
+
11
+ """ Preprocessing Lakh MIDI pianoroll dataset.
12
+ Divides into bars. Encodes into tuples. Makes transposing easier. """
13
+
14
+ def run(f, my_iter):
15
+ with ProcessPoolExecutor(max_workers=16) as executor:
16
+ results = list(tqdm(executor.map(f, my_iter), total=len(my_iter)))
17
+ return results
18
+
19
+ def get_emotion_dict(path):
20
+ table = pd.read_csv(path)
21
+ table = table.to_dict(orient="records")
22
+ table = {item["path"].split("/")[-2]: \
23
+ {"valence": item["valence"], "energy": item["energy"], "tempo": item["tempo"]} \
24
+ for item in table}
25
+ return table
26
+
27
+ def process(pr_path, event_sym2idx):
28
+ time.sleep(0.001)
29
+ mid = read_pianoroll(pr_path)
30
+
31
+ bars = mid_to_bars(mid, event_sym2idx)
32
+
33
+ file_ = pr_path.split("/")[-1]
34
+
35
+ item_data = {
36
+ "file": file_,
37
+ "bars": bars,
38
+ }
39
+
40
+ return item_data
41
+
42
+ def main():
43
+
44
+ main_dir = "../../data_files/lpd_5"
45
+ input_dir = "../../data_files/lpd_5/lpd_5_full"
46
+ unique_pr_list_file = "../../data_files/features/pianoroll/unique_files.json"
47
+
48
+ output_dir = os.path.join(main_dir, "lpd_5_full_transposable")
49
+
50
+ os.makedirs(output_dir, exist_ok=True)
51
+ output_maps_path = os.path.join(main_dir, "maps.pt")
52
+
53
+ with open(unique_pr_list_file, "r") as f:
54
+ pr_paths = json.load(f)
55
+
56
+ pr_paths = [os.path.join(input_dir, pr_path[0], pr_path + ".npz") for pr_path in pr_paths]
57
+
58
+ maps = get_maps()
59
+
60
+ func = partial(process, event_sym2idx=maps["event2idx"])
61
+
62
+ os.makedirs(output_dir, exist_ok=True)
63
+
64
+ x = run(func, pr_paths)
65
+ x = [item for item in x if item["bars"] is not None]
66
+ for i in tqdm(range(len(x))):
67
+ for j in range(len(x[i]["bars"])):
68
+ x[i]["bars"][j] = torch.from_numpy(x[i]["bars"][j])
69
+ fname = x[i]["file"]
70
+ output_path = os.path.join(output_dir, fname.replace(".npz", ".pt"))
71
+ torch.save(x[i], output_path)
72
+
73
+ torch.save(maps, output_maps_path)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()
78
+
79
+
80
+
81
+
82
+
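A hypothetical single-file check for this preprocessing script, run from midi_emotion/src/data (the script uses bare `from data_processing import ...` imports); the .npz path is a placeholder for any LPD-5 pianoroll file.

from data_processing import get_maps
from preprocess_pianorolls import process

maps = get_maps()
# Placeholder path: LPD-5 files live in a subfolder named after their first character.
item = process("../../data_files/lpd_5/lpd_5_full/S/SOME_ID.npz",
               event_sym2idx=maps["event2idx"])
print(item["file"], None if item["bars"] is None else len(item["bars"]))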
midi_emotion/src/generate.py ADDED
@@ -0,0 +1,403 @@
1
+ from argparse import ArgumentParser
2
+ from copy import deepcopy
3
+ import os
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import datetime
9
+ from tqdm import tqdm
10
+ from .utils import get_n_instruments
11
+ from .models.build_model import build_model
12
+ from .data.data_processing_reverse import ind_tensor_to_mid, ind_tensor_to_str
13
+
14
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
15
+
16
+ def chunks(lst, n):
17
+ """Yield successive n-sized chunks from lst."""
18
+ for i in range(0, len(lst), n):
19
+ yield lst[i:i + n]
20
+
21
+ def generate(model, maps, device, out_dir, conditioning, short_filename=False,
22
+ penalty_coeff=0.5, discrete_conditions=None, continuous_conditions=None,
23
+ max_input_len=1024, amp=True, step=None,
24
+ gen_len=2048, temperatures=[1.2,1.2], top_k=-1,
25
+ top_p=0.7, debug=False, varying_condition=None, seed=-1,
26
+ verbose=False, primers=[["<START>"]], min_n_instruments=2):
27
+
28
+ if not debug:
29
+ os.makedirs(out_dir, exist_ok=True)
30
+
31
+ model = model.to(device)
32
+ model.eval()
33
+
34
+ assert len(temperatures) in (1, 2)
35
+
36
+ if varying_condition is not None:
37
+ batch_size = varying_condition[0].size(0)
38
+ else:
39
+ try:
40
+ continuous_conditions = torch.FloatTensor(continuous_conditions).to(device)
41
+ except:
42
+ continuous_conditions = None
43
+ if conditioning == "none":
44
+ batch_size = len(primers)
45
+ elif conditioning == "discrete_token":
46
+ assert discrete_conditions is not None
47
+ discrete_conditions_tensor = [[maps["tuple2idx"][symbol] for symbol in condition_sample] \
48
+ for condition_sample in discrete_conditions]
49
+ discrete_conditions_tensor = torch.LongTensor(discrete_conditions_tensor).t().to(device)
50
+ batch_size = discrete_conditions_tensor.size(1)
51
+
52
+ elif conditioning in ("continuous_token", "continuous_concat"):
53
+ batch_size = len(continuous_conditions)
54
+
55
+ # will be used to penalize repeats
56
+ repeat_counts = [0 for _ in range(batch_size)]
57
+
58
+ exclude_symbols = [symbol for symbol in maps["tuple2idx"].keys() if symbol[0] == "<"]
59
+
60
+ # will have generated symbols and indices
61
+ gen_song_tensor = torch.LongTensor([]).to(device)
62
+
63
+ if not isinstance(primers, list):
64
+ primers = [[primers]]
65
+ primer_inds = [[maps["tuple2idx"][symbol] for symbol in primer] \
66
+ for primer in primers]
67
+
68
+ gen_inds = torch.LongTensor(primer_inds)
69
+
70
+ null_conditions_tensor = torch.FloatTensor([np.nan, np.nan]).to(device)
71
+
72
+ if len(primers) == 1:
73
+ gen_inds = gen_inds.repeat(batch_size, 1)
74
+ null_conditions_tensor = null_conditions_tensor.repeat(batch_size, 1)
75
+
76
+ if conditioning == "continuous_token":
77
+ max_input_len -= 2
78
+ conditions_tensor = continuous_conditions
79
+ elif conditioning == "continuous_concat":
80
+ conditions_tensor = continuous_conditions
81
+ elif conditioning == "discrete_token":
82
+ max_input_len -= discrete_conditions_tensor.size(0)
83
+ conditions_tensor = null_conditions_tensor
84
+ else:
85
+ conditions_tensor = null_conditions_tensor
86
+
87
+ if varying_condition is not None:
88
+ varying_condition[0] = varying_condition[0].to(device)
89
+ varying_condition[1] = varying_condition[1].to(device)
90
+
91
+ gen_inds = gen_inds.t().to(device)
92
+
93
+ with torch.no_grad():
94
+ pbar = tqdm(total=gen_len, desc="Generating tokens", leave=True)
95
+ i = 0
96
+ while i < gen_len:
97
+ i += 1
98
+ pbar.update(1)
99
+
100
+ gen_song_tensor = torch.cat((gen_song_tensor, gen_inds), 0)
101
+
102
+ input_ = gen_song_tensor
103
+ if len(gen_song_tensor) > max_input_len:
104
+ input_ = input_[-max_input_len:, :]
105
+
106
+ if conditioning == "discrete_token":
107
+ # concat with conditions
108
+ input_ = torch.cat((discrete_conditions_tensor, input_), 0)
109
+
110
+ # INTERPOLATED CONDITIONS
111
+ if varying_condition is not None:
112
+ valences = varying_condition[0][:, i-1]
113
+ arousals = varying_condition[1][:, i-1]
114
+ conditions_tensor = torch.cat([valences[:, None], arousals[:, None]], dim=-1)
115
+
116
+ # Run model
117
+ with torch.cuda.amp.autocast(enabled=amp):
118
+ input_ = input_.t()
119
+ output = model(input_, conditions_tensor)
120
+ output = output.permute((1, 0, 2))
121
+
122
+ # Process output, get predicted token
123
+ output = output[-1, :, :] # Select last timestep
124
+ output[output != output] = 0 # zeroing nans
125
+
126
+ if torch.all(output == 0) and verbose:
127
+ # if everything becomes zero
128
+ print("All predictions were NaN during generation")
129
+ output = torch.ones(output.shape).to(device)
130
+
131
+ # exclude certain symbols
132
+ for symbol_exclude in exclude_symbols:
133
+ try:
134
+ idx_exclude = maps["tuple2idx"][symbol_exclude]
135
+ output[:, idx_exclude] = -float("inf")
136
+ except KeyError:
137
+ pass
138
+
139
+ effective_temps = []
140
+ for j in range(batch_size):
141
+ gen_idx = gen_inds[0, j].item()
142
+ gen_tuple = maps["idx2tuple"][gen_idx]
143
+ effective_temp = temperatures[1]
144
+ if isinstance(gen_tuple, tuple):
145
+ gen_event = maps["idx2event"][gen_tuple[0]]
146
+ if "TIMESHIFT" in gen_event:
147
+ # switch from rest temperature to note temperature
148
+ effective_temp = temperatures[0]
149
+ effective_temps.append(effective_temp)
150
+
151
+ temp_tensor = torch.Tensor([effective_temps]).to(device)
152
+
153
+ output = F.log_softmax(output, dim=-1)
154
+
155
+ # Add repeat penalty to temperature
156
+ if penalty_coeff > 0:
157
+ repeat_counts_array = torch.Tensor(repeat_counts).to(device)
158
+ temp_multiplier = torch.maximum(torch.zeros_like(repeat_counts_array, device=device),
159
+ torch.log((repeat_counts_array+1)/4)*penalty_coeff)
160
+ repeat_penalties = temp_multiplier * temp_tensor
161
+ temp_tensor += repeat_penalties
162
+
163
+ # Apply temperature
164
+ output /= temp_tensor.t()
165
+
166
+ # top-k
167
+ if top_k <= 0 or top_k > output.size(-1):
168
+ top_k_eff = output.size(-1)
169
+ else:
170
+ top_k_eff = top_k
171
+ output, top_inds = torch.topk(output, top_k_eff)
172
+
173
+ # top-p
174
+ if top_p > 0 and top_p < 1:
175
+ cumulative_probs = torch.cumsum(F.softmax(output, dim=-1), dim=-1)
176
+ remove_inds = cumulative_probs > top_p
177
+ remove_inds[:, 0] = False # at least keep top value
178
+ output[remove_inds] = -float("inf")
179
+
180
+ output = F.softmax(output, dim=-1)
181
+
182
+ # Sample from probabilities
183
+ inds_sampled = torch.multinomial(output, 1, replacement=True)
184
+ gen_inds = top_inds.gather(1, inds_sampled).t()
185
+
186
+ # Update repeat counts
187
+ num_choices = torch.sum((output > 0).int(), -1)
188
+ for j in range(batch_size):
189
+ if num_choices[j] <= 2: repeat_counts[j] += 1
190
+ else: repeat_counts[j] = repeat_counts[j] // 2
191
+
192
+ pbar.close()
193
+
194
+ # Convert to midi and save
195
+ print("\nConverting to MIDI...")
196
+
197
+ # If a sample has fewer than min_n_instruments, queue its condition so generation is repeated
198
+ redo_primers, redo_discrete_conditions, redo_continuous_conditions = [], [], []
199
+ for i in range(gen_song_tensor.size(-1)):
200
+ if short_filename:
201
+ out_file_path = f"{i}"
202
+ else:
203
+ if step is None:
204
+ now = datetime.datetime.now()
205
+ out_file_path = now.strftime("%Y_%m_%d_%H_%M_%S")
206
+ else:
207
+ out_file_path = step
208
+
209
+ out_file_path += f"_{i}"
210
+
211
+ if seed > 0:
212
+ out_file_path += f"_s{seed}"
213
+
214
+ if continuous_conditions is not None:
215
+ condition = continuous_conditions[i, :].tolist()
216
+ # convert to string
217
+ condition = [str(round(c, 2)).replace(".", "") for c in condition]
218
+ out_file_path += f"_V{condition[0]}_A{condition[1]}"
219
+
220
+ out_file_path += ".mid"
221
+ out_path_mid = os.path.join(out_dir, out_file_path)
222
+
223
+ symbols = ind_tensor_to_str(gen_song_tensor[:, i], maps["idx2tuple"], maps["idx2event"])
224
+ n_instruments = get_n_instruments(symbols)
225
+
226
+ if n_instruments >= min_n_instruments:
227
+ mid = ind_tensor_to_mid(gen_song_tensor[:, i], maps["idx2tuple"], maps["idx2event"], verbose=False)
228
+ out_path_txt = "txt_" + out_file_path.replace(".mid", ".txt")
229
+ out_path_txt = os.path.join(out_dir, out_path_txt)
230
+ out_path_inds = "inds_" + out_file_path.replace(".mid", ".pt")
231
+ out_path_inds = os.path.join(out_dir, out_path_inds)
232
+
233
+ if not debug:
234
+ mid.write(out_path_mid)
235
+ if verbose:
236
+ print(f"Saved to {out_path_mid}")
237
+ else:
238
+ print(f"Only has {n_instruments} instruments, not saving.")
239
+ if conditioning == "none":
240
+ redo_primers.append(primers[i])
241
+ redo_discrete_conditions = None
242
+ redo_continuous_conditions = None
243
+ elif conditioning == "discrete_token":
244
+ redo_discrete_conditions.append(discrete_conditions[i])
245
+ redo_continuous_conditions = None
246
+ redo_primers = primers
247
+ else:
248
+ redo_discrete_conditions = None
249
+ redo_continuous_conditions.append(continuous_conditions[i, :].tolist())
250
+ redo_primers = primers
251
+
252
+ return redo_primers, redo_discrete_conditions, redo_continuous_conditions
253
+
254
+
255
+ if __name__ == '__main__':
256
+ script_dir = os.path.dirname(os.path.abspath(__file__))
257
+ code_model_dir = os.path.abspath(os.path.join(script_dir, 'model'))
258
+ code_utils_dir = os.path.join(code_model_dir, 'utils')
259
+ sys.path.extend([code_model_dir, code_utils_dir])
260
+
261
+ parser = ArgumentParser()
262
+
263
+ parser.add_argument('--model_dir', type=str, help='Directory with model', required=True)
264
+ parser.add_argument('--no_cuda', action='store_true', help="Use CPU")
265
+ parser.add_argument('--num_runs', type=int, help='Number of runs', default=1)
266
+ parser.add_argument('--gen_len', type=int, help='Max generation len', default=4096)
267
+ parser.add_argument('--max_input_len', type=int, help='Max input len', default=1216)
268
+ parser.add_argument('--temp', type=float, nargs='+', help='Generation temperature', default=[1.2, 1.2])
269
+ parser.add_argument('--topk', type=int, help='Top-k sampling', default=-1)
270
+ parser.add_argument('--topp', type=float, help='Top-p sampling', default=0.7)
271
+ parser.add_argument('--debug', action='store_true', help="Do not save anything")
272
+ parser.add_argument('--seed', type=int, default=0, help="Random seed")
273
+ parser.add_argument('--no_amp', action='store_true', help="Disable automatic mixed precision")
274
+ parser.add_argument("--conditioning", type=str, required=True,
275
+ choices=["none", "discrete_token", "continuous_token",
276
+ "continuous_concat"], help='Conditioning type')
277
+ parser.add_argument('--penalty_coeff', type=float, default=0.5,
278
+ help="Coefficient for penalizing repeating notes")
279
+ parser.add_argument("--quiet", action='store_true', help="Not verbose")
280
+ parser.add_argument("--short_filename", action='store_true')
281
+ parser.add_argument('--batch_size', type=int, help='Batch size', default=4)
282
+ parser.add_argument('--min_n_instruments', type=int, help='Minimum number of instruments', default=1)
283
+ parser.add_argument('--valence', type=float, help='Conditioning valence value', default=[None], nargs='+')
284
+ parser.add_argument('--arousal', type=float, help='Conditioning arousal value', default=[None], nargs='+')
285
+ parser.add_argument("--batch_gen_dir", type=str, default="")
286
+
287
+ args = parser.parse_args()
288
+
289
+ assert len(args.valence) == len(args.arousal), "Lengths of valence and arousal must be equal"
290
+ assert (args.conditioning == "none") == (args.valence == [None] or args.arousal == [None]), \
291
+ "If conditioning is used, specify valence and arousal; if not, don't"
292
+
293
+ if args.seed > 0:
294
+ torch.manual_seed(args.seed)
295
+ torch.cuda.manual_seed(args.seed)
296
+
297
+ main_output_dir = "../output"
298
+ assert os.path.exists(os.path.join(main_output_dir, args.model_dir))
299
+ midi_output_dir = os.path.join(main_output_dir, args.model_dir, "generations", "inference")
300
+
301
+ new_dir = ""
302
+ if args.batch_gen_dir != "":
303
+ new_dir = new_dir + "_" + args.batch_gen_dir
304
+ if new_dir != "":
305
+ midi_output_dir = os.path.join(midi_output_dir, new_dir)
306
+ if not args.debug:
307
+ os.makedirs(midi_output_dir, exist_ok=True)
308
+
309
+ model_fp = os.path.join(main_output_dir, args.model_dir, 'model.pt')
310
+ mappings_fp = os.path.join(main_output_dir, args.model_dir, 'mappings.pt')
311
+ config_fp = os.path.join(main_output_dir, args.model_dir, 'model_config.pt')
312
+
313
+ if os.path.exists(mappings_fp):
314
+ maps = torch.load(mappings_fp)
315
+ else:
316
+ raise ValueError("Mapping file not found.")
317
+
318
+ start_symbol = "<START>"
319
+ n_emotion_bins = 5
320
+ valence_symbols, arousal_symbols = [], []
321
+
322
+ emotion_bins = np.linspace(-1-1e-12, 1+1e-12, num=n_emotion_bins+1)
323
+ if n_emotion_bins % 2 == 0:
324
+ bin_ids = list(range(-n_emotion_bins//2, 0)) + list(range(1, n_emotion_bins//2+1))
325
+ else:
326
+ bin_ids = list(range(-(n_emotion_bins-1)//2, (n_emotion_bins-1)//2 + 1))
327
+
328
+ for bin_id in bin_ids:
329
+ valence_symbols.append(f"<V{bin_id}>")
330
+ arousal_symbols.append(f"<A{bin_id}>")
331
+
332
+ device = torch.device('cuda' if not args.no_cuda and torch.cuda.is_available() else 'cpu')
333
+
334
+ verbose = not args.quiet
335
+ if verbose:
336
+ if device == torch.device("cuda"):
337
+ print("Using GPU")
338
+ else:
339
+ print("Using CPU")
340
+
341
+ # Load model
342
+ config = torch.load(config_fp)
343
+ model, _ = build_model(None, load_config_dict=config)
344
+ model = model.to(device)
345
+ if os.path.exists(model_fp):
346
+ model.load_state_dict(torch.load(model_fp, map_location=device))
347
+ elif os.path.exists(model_fp.replace("best_", "")):
348
+ model.load_state_dict(torch.load(model_fp.replace("best_", ""), map_location=device))
349
+ else:
350
+ raise ValueError("Model not found")
351
+
352
+ # Process conditions
353
+ null_condition = torch.FloatTensor([np.nan, np.nan]).to(device)
354
+
355
+ varying_condition = None
356
+ label_conditions = None
357
+
358
+ conditions = []
359
+ if args.valence == [None]:
360
+ conditions = None
361
+ elif len(args.valence) == 1:
362
+ for _ in range(args.batch_size):
363
+ conditions.append([args.valence[0], args.arousal[0]])
364
+ else:
365
+ for i in range(len(args.valence)):
366
+ conditions.append([args.valence[i], args.arousal[i]])
367
+
368
+ primers = [["<START>"]]
369
+ continuous_conditions = conditions
370
+ if args.conditioning == "discrete_token":
371
+
372
+ discrete_conditions = []
373
+ for condition in conditions:
374
+ valence_val, arousal_val = condition
375
+ valence_symbol = valence_symbols[np.searchsorted(
376
+ emotion_bins, valence_val, side="right") - 1]
377
+ arousal_symbol = arousal_symbols[np.searchsorted(
378
+ emotion_bins, arousal_val, side="right") - 1]
379
+ discrete_conditions.append([valence_symbol, arousal_symbol])
380
+
381
+ conditions = null_condition
382
+
383
+ elif args.conditioning == "none":
384
+ discrete_conditions = None
385
+ primers = [["<START>"] for _ in range(args.batch_size)]
386
+
387
+ elif args.conditioning in ["continuous_token", "continuous_concat"]:
388
+ primers = [["<START>"]]
389
+ discrete_conditions = None
390
+
391
+ for i in range(args.num_runs):
392
+ primers_run = deepcopy(primers)
393
+ discrete_conditions_run = deepcopy(discrete_conditions)
394
+ continuous_conditions_run = deepcopy(continuous_conditions)
395
+ while not (primers_run == [] or discrete_conditions_run == [] or continuous_conditions_run == []):
396
+ primers_run, discrete_conditions_run, continuous_conditions_run = generate(
397
+ model, maps, device,
398
+ midi_output_dir, args.conditioning, discrete_conditions=discrete_conditions_run,
399
+ min_n_instruments=args.min_n_instruments,continuous_conditions=continuous_conditions_run,
400
+ penalty_coeff=args.penalty_coeff, short_filename=args.short_filename, top_p=args.topp,
401
+ gen_len=args.gen_len, max_input_len=args.max_input_len,
402
+ amp=not args.no_amp, primers=primers_run, temperatures=args.temp, top_k=args.topk,
403
+ debug=args.debug, verbose=not args.quiet, seed=args.seed)
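Aside (not part of the committed file): the inference script above buckets continuous valence/arousal values into discrete emotion tokens with np.searchsorted over evenly spaced bin edges. A minimal, self-contained sketch of that binning, assuming the same 5-bin setup:

import numpy as np

n_emotion_bins = 5
emotion_bins = np.linspace(-1 - 1e-12, 1 + 1e-12, num=n_emotion_bins + 1)
bin_ids = list(range(-(n_emotion_bins - 1) // 2, (n_emotion_bins - 1) // 2 + 1))
valence_symbols = [f"<V{b}>" for b in bin_ids]
arousal_symbols = [f"<A{b}>" for b in bin_ids]

def to_symbols(valence, arousal):
    # searchsorted(..., side="right") - 1 gives the bin each value falls into
    v_idx = np.searchsorted(emotion_bins, valence, side="right") - 1
    a_idx = np.searchsorted(emotion_bins, arousal, side="right") - 1
    return valence_symbols[v_idx], arousal_symbols[a_idx]

print(to_symbols(-0.8, 0.8))  # ('<V-2>', '<A2>')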
midi_emotion/src/models/build_model.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch.nn as nn
2
+ def set_dropout(model, rate):
3
+ for name, child in model.named_children():
4
+ if isinstance(child, nn.Dropout):
5
+ child.p = rate
6
+ set_dropout(child, rate)
7
+ return model
8
+
9
+ def build_model(args, load_config_dict=None):
10
+
11
+ if load_config_dict is not None:
12
+ args = load_config_dict
13
+
14
+ config = {
15
+ "vocab_size": args["vocab_size"],
16
+ "num_layer": args["n_layer"],
17
+ "num_head": args["n_head"],
18
+ "embedding_dim": args["d_model"],
19
+ "d_inner": args["d_inner"],
20
+ "dropout": args["dropout"],
21
+ "d_condition": args["d_condition"],
22
+ "max_seq": 2048,
23
+ "pad_token": 0,
24
+ }
25
+
26
+ if "regression" not in args:
27
+ args["regression"] = False
28
+
29
+ if args["regression"]:
30
+ config["output_size"] = 2
31
+ from models.music_regression \
32
+ import MusicRegression as MusicTransformer
33
+
34
+ elif args["conditioning"] == "continuous_token":
35
+ from models.music_continuous_token \
36
+ import MusicTransformerContinuousToken as MusicTransformer
37
+ del config["d_condition"]
38
+ else:
39
+ from .music_multi \
40
+ import MusicTransformerMulti as MusicTransformer
41
+
42
+ model = MusicTransformer(**config)
43
+ if load_config_dict is not None and args is not None:
44
+ if args["overwrite_dropout"]:
45
+ model = set_dropout(model, args["dropout"])
46
+ rate = args["dropout"]
47
+ print(f"Dropout rate changed to {rate}")
48
+ return model, args
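Aside (not part of the committed file): build_model selects the model class from a saved config dict. A hypothetical usage sketch; the key names mirror those read above, but the concrete values are invented for illustration:

example_config = {
    "vocab_size": 400,          # size of the event vocabulary (made up here)
    "n_layer": 8,
    "n_head": 8,
    "d_model": 512,
    "d_inner": 2048,
    "dropout": 0.1,
    "d_condition": 192,         # > 0 selects the continuous_concat model
    "conditioning": "continuous_concat",
    "regression": False,
    "overwrite_dropout": False,
}
# model, args = build_model(None, load_config_dict=example_config)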
midi_emotion/src/models/music_continuous_token.py ADDED
@@ -0,0 +1,275 @@
1
+ import torch
2
+ import math as m
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+
7
+ """
8
+ MUSIC TRANSFORMER
9
+
10
+ CONTINUOUS TOKEN
11
+ Takes continuous conditions separately, embeds them and
12
+ then inserts them before the embedded sequence
13
+ Hence, they are like continuous tokens
14
+ """
15
+
16
+ def generate_mask(x, pad_token=None, batch_first=True):
17
+
18
+ batch_size = x.size(0)
19
+ seq_len = x.size(1)
20
+
21
+ subsequent_mask = torch.logical_not(torch.triu(torch.ones(seq_len, seq_len, device=x.device)).t()).unsqueeze(
22
+ -1).repeat(1, 1, batch_size)
23
+ pad_mask = x == pad_token
24
+ if batch_first:
25
+ pad_mask = pad_mask.t()
26
+ mask = torch.logical_or(subsequent_mask, pad_mask)
27
+ if batch_first:
28
+ mask = mask.permute(2, 0, 1)
29
+ return mask
30
+
31
+
32
+ class MusicTransformerContinuousToken(torch.nn.Module):
33
+ def __init__(self, embedding_dim=None, d_inner=None, vocab_size=None, num_layer=None, num_head=None,
34
+ max_seq=None, dropout=None, pad_token=None, has_start_token=True, n_conditions=2,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.max_seq = max_seq
39
+ self.num_layer = num_layer
40
+ self.embedding_dim = embedding_dim
41
+ self.vocab_size = vocab_size
42
+
43
+ self.pad_token = pad_token
44
+ self.has_start_token = has_start_token
45
+ self.n_conditions = n_conditions
46
+
47
+
48
+ self.embedding = torch.nn.Embedding(num_embeddings=vocab_size,
49
+ embedding_dim=self.embedding_dim,
50
+ padding_idx=pad_token)
51
+
52
+ # two vectors for two types of emotion (valence, energy/tempo)
53
+ # just like token embedding
54
+ self.fc_condition = torch.nn.ModuleList([torch.nn.Linear(1, self.embedding_dim) \
55
+ for _ in range(self.n_conditions)])
56
+
57
+ self.pos_encoding = DynamicPositionEmbedding(self.embedding_dim, max_seq=max_seq)
58
+
59
+ self.enc_layers = torch.nn.ModuleList(
60
+ [EncoderLayer(embedding_dim, d_inner, dropout, h=num_head, additional=False, max_seq=max_seq)
61
+ for _ in range(num_layer)])
62
+ self.dropout = torch.nn.Dropout(dropout)
63
+
64
+ self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)
65
+
66
+ self.init_weights()
67
+
68
+ def init_weights(self):
69
+ initrange = 0.1
70
+ self.embedding.weight.data.uniform_(-initrange, initrange)
71
+ self.fc.bias.data.zero_()
72
+ self.fc.weight.data.uniform_(-initrange, initrange)
73
+ for i in range(len(self.fc_condition)):
74
+ self.fc_condition[i].weight.data.uniform_(-initrange, initrange)
75
+ self.fc_condition[i].bias.data.zero_()
76
+
77
+ def forward(self, x_tokens, condition):
78
+ # takes batch first
79
+ # x.shape = [batch_size, sequence_length]
80
+
81
+ # embed input
82
+ x = self.embedding(x_tokens) # (batch_size, input_seq_len, d_model)
83
+ x *= math.sqrt(self.embedding_dim)
84
+
85
+ # pad input sequence to represent continuous emotion vectors
86
+ x_tokens_padded = torch.nn.functional.pad(x_tokens, (condition.size(-1), 0), value=-1)
87
+ mask = generate_mask(x_tokens_padded, self.pad_token)
88
+
89
+ # embed conditions one by one, using different linear layers,
90
+ # just like token embedding
91
+ c = []
92
+ for i in range(self.n_conditions):
93
+ c.append(self.fc_condition[i](condition[:, i, None]))
94
+ c = torch.stack(c, dim=1)
95
+
96
+ # concatenate with conditions
97
+ x = torch.cat((c, x), dim=1)
98
+
99
+ x = self.pos_encoding(x)
100
+ x = self.dropout(x)
101
+ for i in range(len(self.enc_layers)):
102
+ x = self.enc_layers[i](x, mask)
103
+
104
+ x = self.fc(x)
105
+ return x
106
+
107
+ class EncoderLayer(torch.nn.Module):
108
+ def __init__(self, d_model, d_inner, rate=0.1, h=16, additional=False, max_seq=2048):
109
+ super(EncoderLayer, self).__init__()
110
+
111
+ self.d_model = d_model
112
+ self.rga = RelativeGlobalAttention(h=h, d=d_model, max_seq=max_seq, add_emb=additional)
113
+
114
+ self.FFN_pre = torch.nn.Linear(self.d_model, d_inner)
115
+ self.FFN_suf = torch.nn.Linear(d_inner, self.d_model)
116
+
117
+ self.layernorm1 = torch.nn.LayerNorm(self.d_model, eps=1e-6)
118
+ self.layernorm2 = torch.nn.LayerNorm(self.d_model, eps=1e-6)
119
+
120
+ self.dropout1 = torch.nn.Dropout(rate)
121
+ self.dropout2 = torch.nn.Dropout(rate)
122
+
123
+ def forward(self, x, mask=None, **kwargs):
124
+ attn_out = self.rga([x,x,x], mask)
125
+ attn_out = self.dropout1(attn_out)
126
+ out1 = self.layernorm1(attn_out+x)
127
+
128
+ ffn_out = F.relu(self.FFN_pre(out1))
129
+ ffn_out = self.FFN_suf(ffn_out)
130
+ ffn_out = self.dropout2(ffn_out)
131
+ out2 = self.layernorm2(out1+ffn_out)
132
+ return out2
133
+
134
+ def sinusoid(max_seq, embedding_dim):
135
+ return np.array([[
136
+ [
137
+ m.sin(
138
+ pos * m.exp(-m.log(10000) * i / embedding_dim) * m.exp(
139
+ m.log(10000) / embedding_dim * (i % 2)) + 0.5 * m.pi * (i % 2)
140
+ )
141
+ for i in range(embedding_dim)
142
+ ]
143
+ for pos in range(max_seq)
144
+ ]])
145
+
146
+ def sinusoid2(max_seq, embedding_dim):
147
+ pos_emb = np.zeros((1, max_seq, embedding_dim))
148
+ for index in range(0, embedding_dim, 2):
149
+ pos_emb[0, :, index] = np.array([m.sin(pos/10000**(index/embedding_dim))
150
+ for pos in range(max_seq)])
151
+ pos_emb[0, :, index+1] = np.array([m.cos(pos/10000**(index/embedding_dim))
152
+ for pos in range(max_seq)])
153
+ return pos_emb
154
+
155
+
156
+ class DynamicPositionEmbedding(torch.nn.Module):
157
+ def __init__(self, embedding_dim, max_seq=2048):
158
+ super().__init__()
159
+ self.device = torch.device("cpu")
160
+ self.dtype = torch.float32
161
+ embed_sinusoid_list = sinusoid(max_seq, embedding_dim)
162
+
163
+ self.positional_embedding = torch.from_numpy(embed_sinusoid_list).to(
164
+ self.device, dtype=self.dtype)
165
+
166
+ def forward(self, x):
167
+ if x.device != self.device or x.dtype != self.dtype:
168
+ self.positional_embedding = self.positional_embedding.to(x.device, dtype=x.dtype)
169
+ x += self.positional_embedding[:, :x.size(1), :]
170
+ return x
171
+
172
+
173
+ class RelativeGlobalAttention(torch.nn.Module):
174
+ """
175
+ from Music Transformer ( Huang et al, 2018 )
176
+ [paper link](https://arxiv.org/pdf/1809.04281.pdf)
177
+ """
178
+ def __init__(self, h=4, d=256, add_emb=False, max_seq=2048, **kwargs):
179
+ super().__init__()
180
+ self.len_k = None
181
+ self.max_seq = max_seq
182
+ self.E = None
183
+ self.h = h
184
+ self.d = d
185
+ self.dh = d // h
186
+ self.Wq = torch.nn.Linear(self.d, self.d)
187
+ self.Wk = torch.nn.Linear(self.d, self.d)
188
+ self.Wv = torch.nn.Linear(self.d, self.d)
189
+ self.fc = torch.nn.Linear(d, d)
190
+ self.additional = add_emb
191
+ self.E = torch.nn.Parameter(torch.randn([self.max_seq, int(self.dh)]))
192
+ if self.additional:
193
+ self.Radd = None
194
+
195
+ def forward(self, inputs, mask=None, **kwargs):
196
+ """
197
+ :param inputs: a list of tensors. i.e) [Q, K, V]
198
+ :param mask: mask tensor
199
+ :param kwargs:
200
+ :return: final tensor ( output of attention )
201
+ """
202
+ q = inputs[0]
203
+ q = self.Wq(q)
204
+ q = torch.reshape(q, (q.size(0), q.size(1), self.h, -1))
205
+ q = q.permute(0, 2, 1, 3) # batch, h, seq, dh
206
+
207
+ k = inputs[1]
208
+ k = self.Wk(k)
209
+ k = torch.reshape(k, (k.size(0), k.size(1), self.h, -1))
210
+ k = k.permute(0, 2, 1, 3)
211
+
212
+ v = inputs[2]
213
+ v = self.Wv(v)
214
+ v = torch.reshape(v, (v.size(0), v.size(1), self.h, -1))
215
+ v = v.permute(0, 2, 1, 3)
216
+
217
+ self.len_k = k.size(2)
218
+ self.len_q = q.size(2)
219
+
220
+ E = self._get_left_embedding(self.len_q, self.len_k).to(q.device)
221
+ QE = torch.einsum('bhld,md->bhlm', [q, E])
222
+ QE = self._qe_masking(QE)
223
+ Srel = self._skewing(QE)
224
+
225
+ Kt = k.permute(0, 1, 3, 2)
226
+ QKt = torch.matmul(q, Kt)
227
+ logits = QKt + Srel
228
+ logits = logits / math.sqrt(self.dh)
229
+
230
+ if mask is not None:
231
+ mask = mask.unsqueeze(1)
232
+ new_mask = torch.zeros_like(mask, dtype=torch.float)
233
+ new_mask.masked_fill_(mask, float("-inf"))
234
+ mask = new_mask
235
+ logits += mask
236
+
237
+ attention_weights = F.softmax(logits, -1)
238
+ attention = torch.matmul(attention_weights, v)
239
+
240
+ out = attention.permute(0, 2, 1, 3)
241
+ out = torch.reshape(out, (out.size(0), -1, self.d))
242
+
243
+ out = self.fc(out)
244
+ return out
245
+
246
+ def _get_left_embedding(self, len_q, len_k):
247
+ starting_point = max(0,self.max_seq-len_q)
248
+ e = self.E[starting_point:,:]
249
+ return e
250
+
251
+ def _skewing(self, tensor: torch.Tensor):
252
+ padded = F.pad(tensor, [1, 0, 0, 0, 0, 0, 0, 0])
253
+ reshaped = torch.reshape(padded, shape=[padded.size(0), padded.size(1), padded.size(-1), padded.size(-2)])
254
+ Srel = reshaped[:, :, 1:, :]
255
+ if self.len_k > self.len_q:
256
+ Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q])
257
+ elif self.len_k < self.len_q:
258
+ Srel = Srel[:, :, :, :self.len_k]
259
+
260
+ return Srel
261
+
262
+ @staticmethod
263
+ def _qe_masking(qe):
264
+ mask = sequence_mask(
265
+ torch.arange(qe.size()[-1] - 1, qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device),
266
+ qe.size()[-1])
267
+ mask = ~mask.to(mask.device)
268
+ return mask.to(qe.dtype) * qe
269
+
270
+ def sequence_mask(length, max_length=None):
271
+ """Implements TensorFlow's sequence_mask."""
272
+ if max_length is None:
273
+ max_length = length.max()
274
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
275
+ return x.unsqueeze(0) < length.unsqueeze(1)
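Aside (not part of the committed file): a shape sketch of the continuous-token conditioning above. Each scalar condition is projected to the embedding dimension and prepended to the token embeddings, so the sequence grows by n_conditions; the dimensions below are arbitrary.

import torch

batch, seq_len, d_model = 4, 10, 16
tok_emb = torch.randn(batch, seq_len, d_model)          # embedded token sequence
condition = torch.rand(batch, 2) * 2 - 1                # (valence, arousal) in [-1, 1]
fc_valence = torch.nn.Linear(1, d_model)                # one projection per condition
fc_arousal = torch.nn.Linear(1, d_model)
c = torch.stack([fc_valence(condition[:, 0, None]),
                 fc_arousal(condition[:, 1, None])], dim=1)   # (batch, 2, d_model)
x = torch.cat((c, tok_emb), dim=1)
print(x.shape)  # torch.Size([4, 12, 16]), two extra "continuous tokens"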
midi_emotion/src/models/music_multi.py ADDED
@@ -0,0 +1,269 @@
1
+ import torch
2
+ import math as m
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+ import sys
7
+
8
+ sys.path.append("..")
9
+
10
+
11
+ """
12
+ MUSIC TRANSFORMER
13
+ Multi use, can handle following conditioning methods:
14
+ none (vanilla), continuous_concat, discrete_token
15
+
16
+ CONTINUOUS CONCAT
17
+ Takes continuous conditions as a vector of length 2, embeds it and
18
+ then concatenates it with every input token
19
+
20
+ If d_condition <= 0, it becomes the VANILLA music transformer
21
+ If d_condition <= 0 and discrete condition tokens are fed,
22
+ it becomes the "DISCRETE TOKEN" music transformer
23
+ """
24
+
25
+ def generate_mask(x, pad_token=None, batch_first=True):
26
+
27
+ batch_size = x.size(0)
28
+ seq_len = x.size(1)
29
+
30
+ subsequent_mask = torch.logical_not(torch.triu(torch.ones(seq_len, seq_len, device=x.device)).t()).unsqueeze(
31
+ -1).repeat(1, 1, batch_size)
32
+ pad_mask = x == pad_token
33
+ if batch_first:
34
+ pad_mask = pad_mask.t()
35
+ mask = torch.logical_or(subsequent_mask, pad_mask)
36
+ if batch_first:
37
+ mask = mask.permute(2, 0, 1)
38
+ return mask
39
+
40
+
41
+ class MusicTransformerMulti(torch.nn.Module):
42
+ def __init__(self, embedding_dim=None, d_inner=None, d_condition=None, vocab_size=None, num_layer=None, num_head=None,
43
+ max_seq=None, dropout=None, pad_token=None,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.max_seq = max_seq
48
+ self.num_layer = num_layer
49
+ self.embedding_dim = embedding_dim
50
+ self.vocab_size = vocab_size
51
+
52
+ self.pad_token = pad_token
53
+
54
+ d_condition = 0 if d_condition < 0 else d_condition
55
+ self.d_condition = d_condition
56
+
57
+ self.embedding = torch.nn.Embedding(num_embeddings=vocab_size,
58
+ embedding_dim=self.embedding_dim-self.d_condition,
59
+ padding_idx=pad_token)
60
+
61
+ if self.d_condition > 0:
62
+ self.fc_condition = torch.nn.Linear(2, self.d_condition)
63
+
64
+ self.pos_encoding = DynamicPositionEmbedding(self.embedding_dim, max_seq=max_seq)
65
+
66
+ self.enc_layers = torch.nn.ModuleList(
67
+ [EncoderLayer(embedding_dim, d_inner, dropout, h=num_head, additional=False, max_seq=max_seq)
68
+ for _ in range(num_layer)])
69
+ self.dropout = torch.nn.Dropout(dropout)
70
+
71
+ self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)
72
+
73
+ self.init_weights()
74
+
75
+ def init_weights(self):
76
+ initrange = 0.1
77
+ self.embedding.weight.data.uniform_(-initrange, initrange)
78
+ self.fc.bias.data.zero_()
79
+ self.fc.weight.data.uniform_(-initrange, initrange)
80
+ if self.d_condition > 0:
81
+ self.fc_condition.bias.data.zero_()
82
+ self.fc_condition.weight.data.uniform_(-initrange, initrange)
83
+
84
+ def forward(self, x, condition):
85
+ # no_conditioning = not torch.equal(condition, condition)
86
+ # assert (self.d_condition > 0) != no_conditioning
87
+ # takes batch first
88
+ # x.shape = [batch_size, sequence_length]
89
+ mask = generate_mask(x, self.pad_token)
90
+ # embed input
91
+ x = self.embedding(x) # (batch_size, input_seq_len, d_model)
92
+ x *= math.sqrt(self.embedding_dim-self.d_condition)
93
+
94
+ if self.d_condition > 0:
95
+ # embed condition using fully connected layer
96
+ condition = self.fc_condition(condition)
97
+ # tile to match input
98
+ condition = condition.unsqueeze(1).expand(-1, x.size(1), -1)
99
+ x = torch.cat([x, condition], dim=-1) # concatenate
100
+
101
+ x = self.pos_encoding(x)
102
+ x = self.dropout(x)
103
+ for i in range(len(self.enc_layers)):
104
+ x = self.enc_layers[i](x, mask)
105
+
106
+ x = self.fc(x)
107
+
108
+ return x
109
+
110
+ class EncoderLayer(torch.nn.Module):
111
+ def __init__(self, d_model, d_inner, rate=0.1, h=16, additional=False, max_seq=2048):
112
+ super(EncoderLayer, self).__init__()
113
+
114
+ self.d_model = d_model
115
+ self.rga = RelativeGlobalAttention(h=h, d=d_model, max_seq=max_seq, add_emb=additional)
116
+
117
+ self.FFN_pre = torch.nn.Linear(self.d_model, d_inner)
118
+ self.FFN_suf = torch.nn.Linear(d_inner, self.d_model)
119
+
120
+ self.layernorm1 = torch.nn.LayerNorm(self.d_model, eps=1e-6)
121
+ self.layernorm2 = torch.nn.LayerNorm(self.d_model, eps=1e-6)
122
+
123
+ self.dropout1 = torch.nn.Dropout(rate)
124
+ self.dropout2 = torch.nn.Dropout(rate)
125
+
126
+ def forward(self, x, mask=None):
127
+ attn_out = self.rga([x,x,x], mask)
128
+ attn_out = self.dropout1(attn_out)
129
+ out1 = self.layernorm1(attn_out+x)
130
+
131
+ ffn_out = F.relu(self.FFN_pre(out1))
132
+ ffn_out = self.FFN_suf(ffn_out)
133
+ ffn_out = self.dropout2(ffn_out)
134
+ out2 = self.layernorm2(out1+ffn_out)
135
+ return out2
136
+
137
+ def sinusoid(max_seq, embedding_dim):
138
+ return np.array([[
139
+ [
140
+ m.sin(
141
+ pos * m.exp(-m.log(10000) * i / embedding_dim) * m.exp(
142
+ m.log(10000) / embedding_dim * (i % 2)) + 0.5 * m.pi * (i % 2)
143
+ )
144
+ for i in range(embedding_dim)
145
+ ]
146
+ for pos in range(max_seq)
147
+ ]])
148
+
149
+
150
+ class DynamicPositionEmbedding(torch.nn.Module):
151
+ def __init__(self, embedding_dim, max_seq=2048):
152
+ super().__init__()
153
+ self.device = torch.device("cpu")
154
+ self.dtype = torch.float32
155
+ embed_sinusoid_list = sinusoid(max_seq, embedding_dim)
156
+
157
+ self.positional_embedding = torch.from_numpy(embed_sinusoid_list).to(
158
+ self.device, dtype=self.dtype)
159
+
160
+ def forward(self, x):
161
+ if x.device != self.device or x.dtype != self.dtype:
162
+ self.positional_embedding = self.positional_embedding.to(x.device, dtype=x.dtype)
163
+ x += self.positional_embedding[:, :x.size(1), :]
164
+ return x
165
+
166
+
167
+ class RelativeGlobalAttention(torch.nn.Module):
168
+ """
169
+ from Music Transformer ( Huang et al, 2018 )
170
+ [paper link](https://arxiv.org/pdf/1809.04281.pdf)
171
+ """
172
+ def __init__(self, h=4, d=256, add_emb=False, max_seq=2048):
173
+ super().__init__()
174
+ self.len_k = None
175
+ self.max_seq = max_seq
176
+ self.E = None
177
+ self.h = h
178
+ self.d = d
179
+ self.dh = d // h
180
+ self.Wq = torch.nn.Linear(self.d, self.d)
181
+ self.Wk = torch.nn.Linear(self.d, self.d)
182
+ self.Wv = torch.nn.Linear(self.d, self.d)
183
+ self.fc = torch.nn.Linear(d, d)
184
+ self.additional = add_emb
185
+ self.E = torch.nn.Parameter(torch.randn([self.max_seq, int(self.dh)]))
186
+ if self.additional:
187
+ self.Radd = None
188
+
189
+ def forward(self, inputs, mask=None):
190
+ """
191
+ :param inputs: a list of tensors. i.e) [Q, K, V]
192
+ :param mask: mask tensor
193
+ :param kwargs:
194
+ :return: final tensor ( output of attention )
195
+ """
196
+ q = inputs[0]
197
+ q = self.Wq(q)
198
+ q = torch.reshape(q, (q.size(0), q.size(1), self.h, -1))
199
+ q = q.permute(0, 2, 1, 3) # batch, h, seq, dh
200
+
201
+ k = inputs[1]
202
+ k = self.Wk(k)
203
+ k = torch.reshape(k, (k.size(0), k.size(1), self.h, -1))
204
+ k = k.permute(0, 2, 1, 3)
205
+
206
+ v = inputs[2]
207
+ v = self.Wv(v)
208
+ v = torch.reshape(v, (v.size(0), v.size(1), self.h, -1))
209
+ v = v.permute(0, 2, 1, 3)
210
+
211
+ self.len_k = k.size(2)
212
+ self.len_q = q.size(2)
213
+
214
+ E = self._get_left_embedding(self.len_q, self.len_k).to(q.device)
215
+ QE = torch.einsum('bhld,md->bhlm', [q, E])
216
+ QE = self._qe_masking(QE)
217
+ Srel = self._skewing(QE)
218
+
219
+ Kt = k.permute(0, 1, 3, 2)
220
+ QKt = torch.matmul(q, Kt)
221
+ logits = QKt + Srel
222
+ logits = logits / math.sqrt(self.dh)
223
+
224
+ if mask is not None:
225
+ mask = mask.unsqueeze(1)
226
+ new_mask = torch.zeros_like(mask, dtype=torch.float)
227
+ new_mask.masked_fill_(mask, float("-inf"))
228
+ mask = new_mask
229
+ logits += mask
230
+
231
+ attention_weights = F.softmax(logits, -1)
232
+ attention = torch.matmul(attention_weights, v)
233
+
234
+ out = attention.permute(0, 2, 1, 3)
235
+ out = torch.reshape(out, (out.size(0), -1, self.d))
236
+
237
+ out = self.fc(out)
238
+ return out
239
+
240
+ def _get_left_embedding(self, len_q, len_k):
241
+ starting_point = max(0,self.max_seq-len_q)
242
+ e = self.E[starting_point:,:]
243
+ return e
244
+
245
+ def _skewing(self, tensor: torch.Tensor):
246
+ padded = F.pad(tensor, [1, 0, 0, 0, 0, 0, 0, 0])
247
+ reshaped = torch.reshape(padded, shape=[padded.size(0), padded.size(1), padded.size(-1), padded.size(-2)])
248
+ Srel = reshaped[:, :, 1:, :]
249
+ if self.len_k > self.len_q:
250
+ Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q])
251
+ elif self.len_k < self.len_q:
252
+ Srel = Srel[:, :, :, :self.len_k]
253
+
254
+ return Srel
255
+
256
+ @staticmethod
257
+ def _qe_masking(qe):
258
+ mask = sequence_mask(
259
+ torch.arange(qe.size()[-1] - 1, qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device),
260
+ qe.size()[-1])
261
+ mask = ~mask.to(mask.device)
262
+ return mask.to(qe.dtype) * qe
263
+
264
+ def sequence_mask(length, max_length=None):
265
+ """Implements TensorFlow's sequence_mask."""
266
+ if max_length is None:
267
+ max_length = length.max()
268
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
269
+ return x.unsqueeze(0) < length.unsqueeze(1)
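Aside (not part of the committed file): a shape sketch of the continuous_concat path above. The 2-dimensional condition is embedded once, tiled along time, and concatenated to every token embedding so the total width stays d_model; the dimensions below are arbitrary.

import torch

batch, seq_len, d_model, d_condition = 4, 10, 512, 192
tok_emb = torch.randn(batch, seq_len, d_model - d_condition)  # token embedding is narrower
condition = torch.rand(batch, 2) * 2 - 1                      # (valence, arousal)
fc_condition = torch.nn.Linear(2, d_condition)
c = fc_condition(condition).unsqueeze(1).expand(-1, seq_len, -1)
x = torch.cat([tok_emb, c], dim=-1)                           # every step carries the condition
print(x.shape)  # torch.Size([4, 10, 512])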
midi_emotion/src/models/music_regression.py ADDED
@@ -0,0 +1,250 @@
1
+ import torch
2
+ import math as m
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+ import sys
7
+
8
+ # from torch.nn.modules.activation import ReLU
9
+
10
+ sys.path.append("..")
11
+ # from utils import memory
12
+
13
+
14
+ """
15
+ MUSIC TRANSFORMER REGRESSION (to output emotion)
16
+ """
17
+
18
+ def generate_mask(x, pad_token=None, batch_first=True):
19
+
20
+ batch_size = x.size(0)
21
+ seq_len = x.size(1)
22
+
23
+ subsequent_mask = torch.logical_not(torch.triu(torch.ones(seq_len, seq_len, device=x.device)).t()).unsqueeze(
24
+ -1).repeat(1, 1, batch_size)
25
+ pad_mask = x == pad_token
26
+ if batch_first:
27
+ pad_mask = pad_mask.t()
28
+ mask = torch.logical_or(subsequent_mask, pad_mask)
29
+ if batch_first:
30
+ mask = mask.permute(2, 0, 1)
31
+ return mask
32
+
33
+
34
+ class MusicRegression(torch.nn.Module):
35
+ def __init__(self, embedding_dim=None, d_inner=None, vocab_size=None, num_layer=None, num_head=None,
36
+ max_seq=None, dropout=None, pad_token=None, output_size=None,
37
+ d_condition=-1, no_mask=True
38
+ ):
39
+ super().__init__()
40
+
41
+ assert d_condition <= 0
42
+
43
+ self.max_seq = max_seq
44
+ self.num_layer = num_layer
45
+ self.embedding_dim = embedding_dim
46
+ self.vocab_size = vocab_size
47
+
48
+ self.pad_token = pad_token
49
+
50
+ self.no_mask = no_mask
51
+
52
+ self.embedding = torch.nn.Embedding(num_embeddings=vocab_size,
53
+ embedding_dim=self.embedding_dim,
54
+ padding_idx=pad_token)
55
+
56
+
57
+ self.pos_encoding = DynamicPositionEmbedding(self.embedding_dim, max_seq=max_seq)
58
+
59
+ self.enc_layers = torch.nn.ModuleList(
60
+ [EncoderLayer(embedding_dim, d_inner, dropout, h=num_head, additional=False, max_seq=max_seq)
61
+ for _ in range(num_layer)])
62
+ self.dropout = torch.nn.Dropout(dropout)
63
+
64
+ self.fc = torch.nn.Sequential(
65
+ torch.nn.Linear(self.embedding_dim, output_size),
66
+ torch.nn.Tanh()
67
+ )
68
+
69
+ self.init_weights()
70
+
71
+ def init_weights(self):
72
+ initrange = 0.1
73
+ self.embedding.weight.data.uniform_(-initrange, initrange)
74
+
75
+ def forward(self, x):
76
+
77
+ mask = None if self.no_mask else generate_mask(x, self.pad_token)
78
+ # embed input
79
+ x = self.embedding(x) # (batch_size, input_seq_len, d_model)
80
+ x *= math.sqrt(self.embedding_dim)
81
+
82
+ x = self.pos_encoding(x)
83
+ x = self.dropout(x)
84
+ for i in range(len(self.enc_layers)):
85
+ x = self.enc_layers[i](x, mask)
86
+
87
+ x = self.fc(x[:, 0, :])
88
+
89
+ return x
90
+
91
+ class EncoderLayer(torch.nn.Module):
92
+ def __init__(self, d_model, d_inner, rate=0.1, h=16, additional=False, max_seq=2048):
93
+ super(EncoderLayer, self).__init__()
94
+
95
+ self.d_model = d_model
96
+ self.rga = RelativeGlobalAttention(h=h, d=d_model, max_seq=max_seq, add_emb=additional)
97
+
98
+ self.FFN_pre = torch.nn.Linear(self.d_model, d_inner)
99
+ self.FFN_suf = torch.nn.Linear(d_inner, self.d_model)
100
+
101
+ self.layernorm1 = torch.nn.LayerNorm(self.d_model, eps=1e-6)
102
+ self.layernorm2 = torch.nn.LayerNorm(self.d_model, eps=1e-6)
103
+
104
+ self.dropout1 = torch.nn.Dropout(rate)
105
+ self.dropout2 = torch.nn.Dropout(rate)
106
+
107
+ def forward(self, x, mask=None):
108
+ attn_out = self.rga([x,x,x], mask)
109
+ attn_out = self.dropout1(attn_out)
110
+ out1 = self.layernorm1(attn_out+x)
111
+
112
+ ffn_out = F.relu(self.FFN_pre(out1))
113
+ ffn_out = self.FFN_suf(ffn_out)
114
+ ffn_out = self.dropout2(ffn_out)
115
+ out2 = self.layernorm2(out1+ffn_out)
116
+ return out2
117
+
118
+ def sinusoid(max_seq, embedding_dim):
119
+ return np.array([[
120
+ [
121
+ m.sin(
122
+ pos * m.exp(-m.log(10000) * i / embedding_dim) * m.exp(
123
+ m.log(10000) / embedding_dim * (i % 2)) + 0.5 * m.pi * (i % 2)
124
+ )
125
+ for i in range(embedding_dim)
126
+ ]
127
+ for pos in range(max_seq)
128
+ ]])
129
+
130
+
131
+ class DynamicPositionEmbedding(torch.nn.Module):
132
+ def __init__(self, embedding_dim, max_seq=2048):
133
+ super().__init__()
134
+ self.device = torch.device("cpu")
135
+ self.dtype = torch.float32
136
+ embed_sinusoid_list = sinusoid(max_seq, embedding_dim)
137
+
138
+ self.positional_embedding = torch.from_numpy(embed_sinusoid_list).to(
139
+ self.device, dtype=self.dtype)
140
+
141
+ def forward(self, x):
142
+ if x.device != self.device or x.dtype != self.dtype:
143
+ self.positional_embedding = self.positional_embedding.to(x.device, dtype=x.dtype)
144
+ x += self.positional_embedding[:, :x.size(1), :]
145
+ return x
146
+
147
+
148
+ class RelativeGlobalAttention(torch.nn.Module):
149
+ """
150
+ from Music Transformer ( Huang et al, 2018 )
151
+ [paper link](https://arxiv.org/pdf/1809.04281.pdf)
152
+ """
153
+ def __init__(self, h=4, d=256, add_emb=False, max_seq=2048):
154
+ super().__init__()
155
+ self.len_k = None
156
+ self.max_seq = max_seq
157
+ self.E = None
158
+ self.h = h
159
+ self.d = d
160
+ self.dh = d // h
161
+ self.Wq = torch.nn.Linear(self.d, self.d)
162
+ self.Wk = torch.nn.Linear(self.d, self.d)
163
+ self.Wv = torch.nn.Linear(self.d, self.d)
164
+ self.fc = torch.nn.Linear(d, d)
165
+ self.additional = add_emb
166
+ self.E = torch.nn.Parameter(torch.randn([self.max_seq, int(self.dh)]))
167
+ if self.additional:
168
+ self.Radd = None
169
+
170
+ def forward(self, inputs, mask=None):
171
+ """
172
+ :param inputs: a list of tensors. i.e) [Q, K, V]
173
+ :param mask: mask tensor
174
+ :param kwargs:
175
+ :return: final tensor ( output of attention )
176
+ """
177
+ q = inputs[0]
178
+ q = self.Wq(q)
179
+ q = torch.reshape(q, (q.size(0), q.size(1), self.h, -1))
180
+ q = q.permute(0, 2, 1, 3) # batch, h, seq, dh
181
+
182
+ k = inputs[1]
183
+ k = self.Wk(k)
184
+ k = torch.reshape(k, (k.size(0), k.size(1), self.h, -1))
185
+ k = k.permute(0, 2, 1, 3)
186
+
187
+ v = inputs[2]
188
+ v = self.Wv(v)
189
+ v = torch.reshape(v, (v.size(0), v.size(1), self.h, -1))
190
+ v = v.permute(0, 2, 1, 3)
191
+
192
+ self.len_k = k.size(2)
193
+ self.len_q = q.size(2)
194
+
195
+ E = self._get_left_embedding(self.len_q, self.len_k).to(q.device)
196
+ QE = torch.einsum('bhld,md->bhlm', [q, E])
197
+ QE = self._qe_masking(QE)
198
+ Srel = self._skewing(QE)
199
+
200
+ Kt = k.permute(0, 1, 3, 2)
201
+ QKt = torch.matmul(q, Kt)
202
+ logits = QKt + Srel
203
+ logits = logits / math.sqrt(self.dh)
204
+
205
+ if mask is not None:
206
+ mask = mask.unsqueeze(1)
207
+ new_mask = torch.zeros_like(mask, dtype=torch.float)
208
+ new_mask.masked_fill_(mask, float("-inf"))
209
+ mask = new_mask
210
+ logits += mask
211
+
212
+ attention_weights = F.softmax(logits, -1)
213
+ attention = torch.matmul(attention_weights, v)
214
+
215
+ out = attention.permute(0, 2, 1, 3)
216
+ out = torch.reshape(out, (out.size(0), -1, self.d))
217
+
218
+ out = self.fc(out)
219
+ return out
220
+
221
+ def _get_left_embedding(self, len_q, len_k):
222
+ starting_point = max(0,self.max_seq-len_q)
223
+ e = self.E[starting_point:,:]
224
+ return e
225
+
226
+ def _skewing(self, tensor: torch.Tensor):
227
+ padded = F.pad(tensor, [1, 0, 0, 0, 0, 0, 0, 0])
228
+ reshaped = torch.reshape(padded, shape=[padded.size(0), padded.size(1), padded.size(-1), padded.size(-2)])
229
+ Srel = reshaped[:, :, 1:, :]
230
+ if self.len_k > self.len_q:
231
+ Srel = F.pad(Srel, [0, 0, 0, 0, 0, 0, 0, self.len_k-self.len_q])
232
+ elif self.len_k < self.len_q:
233
+ Srel = Srel[:, :, :, :self.len_k]
234
+
235
+ return Srel
236
+
237
+ @staticmethod
238
+ def _qe_masking(qe):
239
+ mask = sequence_mask(
240
+ torch.arange(qe.size()[-1] - 1, qe.size()[-1] - qe.size()[-2] - 1, -1).to(qe.device),
241
+ qe.size()[-1])
242
+ mask = ~mask.to(mask.device)
243
+ return mask.to(qe.dtype) * qe
244
+
245
+ def sequence_mask(length, max_length=None):
246
+ """Implements TensorFlow's sequence_mask."""
247
+ if max_length is None:
248
+ max_length = length.max()
249
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
250
+ return x.unsqueeze(0) < length.unsqueeze(1)
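Aside (not part of the committed file): the regression model above reads the encoder state at the first position and maps it through Linear + Tanh to a (valence, arousal) pair in [-1, 1]. A minimal sketch of that head with arbitrary sizes:

import torch

embedding_dim, output_size = 512, 2
head = torch.nn.Sequential(torch.nn.Linear(embedding_dim, output_size), torch.nn.Tanh())
encoder_states = torch.randn(4, 10, embedding_dim)  # (batch, seq, d_model)
pred = head(encoder_states[:, 0, :])                # only the first timestep is used
print(pred.shape)                                   # torch.Size([4, 2]), values in (-1, 1)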
midi_emotion/src/models/transfer_model.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ import torch
3
+ import sys
4
+ sys.path.append("..")
5
+ from models.build_model import build_model
6
+
7
+ """
8
+ Transfers model weights.
9
+ You can create an untrained target model by running:
10
+ python train.py --log_step 1 --max_step 1 ...
11
+ """
12
+
13
+ trained_model_dir = "20220803-130921"
14
+ new_model_dir = "20220803-131016"
15
+
16
+ device = "cuda" if torch.cuda.is_available() else 'cpu'
17
+
18
+ main_dir = "../../output"
19
+
20
+ trained_config = torch.load(os.path.join(main_dir, trained_model_dir, "model_config.pt"))
21
+
22
+ trained_model, _ = build_model(None, load_config_dict=trained_config)
23
+ trained_model = trained_model.to(device)
24
+ trained_model.load_state_dict(torch.load(os.path.join(main_dir, trained_model_dir, 'model.pt'), map_location=device))
25
+
26
+ new_config = torch.load(os.path.join(main_dir, new_model_dir, "model_config.pt"))
27
+ new_model, _ = build_model(None, load_config_dict=new_config)
28
+ new_model = new_model.to(device)
29
+
30
+ trained_params = trained_model.named_parameters()
31
+ new_params = new_model.named_parameters()
32
+ dict_new_params = dict(new_params)
33
+ for name1, param1 in trained_params:
34
+ if name1 in dict_new_params:
35
+
36
+ if name1 == 'embedding.weight':
37
+ # continuous_concat may have different sized embedding
38
+ size1 = dict_new_params[name1].data.shape[1]
39
+ size2 = param1.data.shape[1]
40
+ size_transfer = min((size1, size2))
41
+ dict_new_params[name1].data[:, :size_transfer] = param1.data[:, :size_transfer]
42
+ else:
43
+ dict_new_params[name1].data.copy_(param1.data)
44
+
45
+
46
+ output_path = os.path.join(main_dir, new_model_dir, 'model.pt')
47
+ torch.save(new_model.state_dict(), output_path)
48
+
49
+ print(f"Saved to {output_path}")
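Aside (not part of the committed file): the partial copy above handles embeddings whose widths differ between source and target (for example, continuous_concat reserves d_condition columns). A minimal sketch of that slice copy with made-up sizes:

import torch

old_emb = torch.nn.Embedding(100, 512)   # source model embedding
new_emb = torch.nn.Embedding(100, 320)   # target embedding, e.g. d_model - d_condition
size_transfer = min(old_emb.weight.shape[1], new_emb.weight.shape[1])
with torch.no_grad():
    new_emb.weight[:, :size_transfer] = old_emb.weight[:, :size_transfer]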
midi_emotion/src/models/transformer.py ADDED
@@ -0,0 +1,56 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+
7
+ class Transformer(nn.Module):
8
+
9
+ def __init__(self, n_tokens=None, n_layer=None, n_head=None, d_model=None, d_ff=None,
10
+ dropout=0.0, pad_idx=0):
11
+ super(Transformer, self).__init__()
12
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
13
+ # self.name = 'Transformer'
14
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
15
+ encoder_layers = TransformerEncoderLayer(d_model, n_head, dim_feedforward=d_ff, dropout=dropout)
16
+ norm = nn.LayerNorm(d_model)
17
+ self.transformer_encoder = TransformerEncoder(encoder_layers, n_layer, norm=norm)
18
+ self.encoder = nn.Embedding(n_tokens, d_model, padding_idx=pad_idx)
19
+ self.d_model = d_model
20
+ self.decoder = nn.Linear(d_model, n_tokens)
21
+
22
+ self.init_weights()
23
+
24
+ def init_weights(self):
25
+ initrange = 0.1
26
+ self.encoder.weight.data.uniform_(-initrange, initrange)
27
+ self.decoder.bias.data.zero_()
28
+ self.decoder.weight.data.uniform_(-initrange, initrange)
29
+
30
+ def forward(self, src, src_mask, src_key_padding_mask=None):
31
+
32
+ src = self.encoder(src) * math.sqrt(self.d_model)
33
+ src = self.pos_encoder(src)
34
+ output = self.transformer_encoder(src, src_mask,
35
+ src_key_padding_mask=src_key_padding_mask)
36
+ output = self.decoder(output)
37
+ return output
38
+
39
+
40
+ class PositionalEncoding(nn.Module):
41
+
42
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
43
+ super(PositionalEncoding, self).__init__()
44
+ self.dropout = nn.Dropout(p=dropout)
45
+
46
+ pe = torch.zeros(max_len, d_model)
47
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
48
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
49
+ pe[:, 0::2] = torch.sin(position * div_term)
50
+ pe[:, 1::2] = torch.cos(position * div_term)
51
+ pe = pe.unsqueeze(0).transpose(0, 1)
52
+ self.register_buffer('pe', pe)
53
+
54
+ def forward(self, x):
55
+ x = x + self.pe[:x.size(0), :]
56
+ return self.dropout(x)
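Aside (not part of the committed file): the vanilla Transformer above expects sequence-first token indices plus an additive attention mask. A hypothetical call sketch; the constructor arguments are illustrative and the import refers to this repo's own module:

import torch
# from models.transformer import Transformer   # repo-internal import

seq_len, batch, n_tokens = 16, 2, 100
src = torch.randint(1, n_tokens, (seq_len, batch))                 # sequence-first layout
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
# model = Transformer(n_tokens=n_tokens, n_layer=2, n_head=4, d_model=64, d_ff=256)
# logits = model(src, causal_mask)                                 # (seq_len, batch, n_tokens)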
midi_emotion/src/train.py ADDED
@@ -0,0 +1,477 @@
1
+ import time
2
+ import math
3
+ import datetime
4
+ import os
5
+ import random
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from tqdm import tqdm
11
+ from models.build_model import build_model
12
+ from generate import generate
13
+ from data.preprocess_features import preprocess_features
14
+ from data.loader import Loader
15
+ from data.loader_exhaustive import LoaderExhaustive
16
+ from data.loader_generations import LoaderGenerations
17
+ from data.collate import filter_collate
18
+ from utils import CsvWriter, create_exp_dir, accuracy
19
+ from config import args
20
+
21
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
22
+
23
+ # Set the random seed manually for reproducibility.
24
+ if args.seed > 0:
25
+ np.random.seed(args.seed)
26
+ torch.manual_seed(args.seed)
27
+ torch.cuda.manual_seed(args.seed)
28
+ random.seed(args.seed)
29
+
30
+ class Runner:
31
+ def __init__(self):
32
+ self.logging = create_exp_dir(args.work_dir, debug=args.debug)
33
+ use_cuda = torch.cuda.is_available() and not args.no_cuda
34
+ self.device = torch.device('cuda' if use_cuda else 'cpu')
35
+ if self.device == torch.device("cuda"):
36
+ self.logging("Using GPU")
37
+ else:
38
+ self.logging("Using CPU")
39
+
40
+ self.train_step = 0
41
+ self.n_sequences_total = 0
42
+ self.init_hours = 0
43
+ self.epoch = 0
44
+ self.init_time = time.time()
45
+
46
+ # Load data
47
+ n_bins = args.n_emotion_bins if args.conditioning == "discrete_token" and \
48
+ not args.regression else None
49
+
50
+ conditional = args.conditioning != "none" or args.regression
51
+
52
+ # Preprocessing
53
+ train_feats, test_feats = preprocess_features(
54
+ "../data_files/features/pianoroll/full_dataset_features_summarized.csv",
55
+ n_bins=n_bins, conditional=conditional,
56
+ use_labeled_only=not args.full_dataset)
57
+
58
+ if args.exhaustive_eval:
59
+ # Evaluate using ENTIRE test set
60
+ train_dataset = []
61
+ test_dataset = LoaderExhaustive(args.data_folder, test_feats, args.tgt_len, args.conditioning,
62
+ max_samples=args.n_samples, regression=args.regression,
63
+ always_use_discrete_condition=args.always_use_discrete_condition)
64
+ else:
65
+ train_dataset = Loader(args.data_folder, train_feats, args.tgt_len, args.conditioning,
66
+ regression=args.regression, always_use_discrete_condition=args.always_use_discrete_condition)
67
+ test_dataset = Loader(args.data_folder, test_feats, args.tgt_len, args.conditioning,
68
+ regression=args.regression, always_use_discrete_condition=args.always_use_discrete_condition)
69
+
70
+ if args.regression_dir is not None:
71
+ # Perform emotion regression on generated samples
72
+ train_dataset = []
73
+ test_dataset = LoaderGenerations(args.regression_dir, args.tgt_len)
74
+
75
+ self.null_condition = torch.FloatTensor([np.nan, np.nan]).to(self.device)
76
+
77
+ self.maps = test_dataset.get_maps()
78
+ self.pad_idx = test_dataset.get_pad_idx()
79
+
80
+ self.vocab_size = test_dataset.get_vocab_len()
81
+ args.vocab_size = self.vocab_size
82
+ self.logging(f"Number of tokens: {self.vocab_size}")
83
+
84
+ if args.exhaustive_eval or args.regression_dir is not None:
85
+ self.train_loader = []
86
+ else:
87
+ self.train_loader = torch.utils.data.DataLoader(train_dataset, args.batch_size, shuffle=not args.debug,
88
+ num_workers=args.num_workers, collate_fn=filter_collate,
89
+ pin_memory=not args.no_cuda, drop_last=True)
90
+ self.test_loader = torch.utils.data.DataLoader(test_dataset, args.batch_size, shuffle=False,
91
+ num_workers=args.num_workers, collate_fn=filter_collate,
92
+ pin_memory=not args.no_cuda and args.regression_dir is None,
93
+ drop_last=True)
94
+ print(f"Data loader lengths\nTrain: {len(train_dataset)}")
95
+ if not args.overfit:
96
+ print(f"Test: {len(test_dataset)}")
97
+
98
+ self.gen_dir = os.path.join(args.work_dir, "generations", "training")
99
+
100
+ # Automatic mixed precision
101
+ self.amp = not args.no_amp and self.device == torch.device('cuda')
102
+
103
+ if self.amp:
104
+ self.logging("Using automatic mixed precision")
105
+ else:
106
+ self.logging("Using float32")
107
+
108
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
109
+ self.init_model() # Build the model
110
+
111
+ if not args.debug:
112
+ # Save mappings
113
+ os.makedirs(self.gen_dir, exist_ok=True)
114
+ torch.save(self.maps, os.path.join(args.work_dir, "mappings.pt"))
115
+
116
+ self.csv_writer = CsvWriter(os.path.join(args.work_dir, "performance.csv"),
117
+ ["epoch", "step", "hour", "lr", "trn_loss", "val_loss", "val_l1_v", "val_l1_a"],
118
+ in_path=self.csv_in, debug=args.debug)
119
+
120
+ args.n_all_param = sum([p.nelement() for p in self.model.parameters()])
121
+
122
+ self.model = self.model.to(self.device)
123
+
124
+ self.ce_loss = nn.CrossEntropyLoss(ignore_index=self.pad_idx).to(self.device)
125
+ self.mse_loss = nn.MSELoss()
126
+ self.l1_loss = nn.L1Loss()
127
+
128
+ #### scheduler
129
+ if args.scheduler == '--':
130
+ self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
131
+ args.max_step, eta_min=args.eta_min)
132
+ elif args.scheduler == 'dev_perf':
133
+ self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
134
+ factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
135
+ elif args.scheduler == 'constant':
136
+ pass
137
+ elif args.scheduler == 'cyclic':
138
+ self.scheduler = optim.lr_scheduler.CyclicLR(self.optimizer,
139
+ args.lr_min, args.lr_max, verbose=False, cycle_momentum=False)
140
+
141
+ # Print log
142
+ if not args.debug:
143
+ self.logging('=' * 120)
144
+ for k, v in args.__dict__.items():
145
+ self.logging(' - {} : {}'.format(k, v))
146
+ self.logging('=' * 120)
147
+ self.logging('#params = {}'.format(args.n_all_param))
148
+
149
+ now = datetime.datetime.now()
150
+ now = now.strftime("%d-%m-%Y %H:%M")
151
+ self.logging(f"Run started at {now}")
152
+ self.once = True
153
+
154
+ def init_model(self):
155
+ # Initialize model
156
+ if args.restart_dir:
157
+ # Load existing model
158
+ config = torch.load(os.path.join(args.restart_dir, "model_config.pt"))
159
+ self.model, config = build_model(None, load_config_dict=config)
160
+ self.model = self.model.to(self.device)
161
+
162
+ model_fp = os.path.join(args.restart_dir, 'model.pt')
163
+ optimizer_fp = os.path.join(args.restart_dir, 'optimizer.pt')
164
+ stats_fp = os.path.join(args.restart_dir, 'stats.pt')
165
+ scaler_fp = os.path.join(args.restart_dir, 'scaler.pt')
166
+
167
+ self.model.load_state_dict(
168
+ torch.load(model_fp, map_location=lambda storage, loc: storage))
169
+ self.logging(f"Model loaded from {model_fp}")
170
+
171
+ self.csv_in = os.path.join(args.restart_dir, 'performance.csv')
172
+ else:
173
+ # Build model from scratch
174
+ self.csv_in = None
175
+ self.model, config = build_model(vars(args))
176
+ self.model = self.model.to(self.device)
177
+
178
+ # save model configuration for later load
179
+ if not args.debug:
180
+ torch.save(config, os.path.join(args.work_dir, "model_config.pt"))
181
+
182
+ self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
183
+
184
+ # Load self.optimizer if necessary
185
+ if args.restart_dir:
186
+ if os.path.exists(optimizer_fp):
187
+ try:
188
+ self.optimizer.load_state_dict(
189
+ torch.load(optimizer_fp, map_location=lambda storage, loc: storage))
190
+ except Exception:  # keep the freshly initialized optimizer if loading fails
191
+ pass
192
+ else:
193
+ print('Optimizer was not saved. Start from scratch.')
194
+
195
+ try:
196
+ stats = torch.load(stats_fp)
197
+ self.train_step = stats["step"]
198
+ self.init_hours = stats["hour"]
199
+ self.epoch = stats["epoch"]
200
+ self.n_sequences_total = stats["sample"]
201
+ except Exception:  # stats missing or unreadable; restart counters from zero
202
+ self.train_step = 0
203
+ self.init_hours = 0
204
+ self.epoch = 0
205
+ self.n_sequences_total = 0
206
+
207
+ if os.path.exists(scaler_fp) and not args.reset_scaler:
208
+ try:
209
+ self.scaler.load_state_dict(torch.load(scaler_fp))
210
+ except Exception:  # keep the default GradScaler state if loading fails
211
+ pass
212
+
213
+ if args.overwrite_lr:
214
+ # New learning rate
215
+ for p in self.optimizer.param_groups:
216
+ p['lr'] = args.lr
217
+
218
+ ###############################################################################
219
+ # EVALUATION
220
+ ###############################################################################
221
+
222
+ def evaluate(self):
223
+
224
+ # Turn on evaluation mode which disables dropout.
225
+ self.model.eval()
226
+
227
+ # Evaluation
228
+ topk = (1, 5) # find accuracy for top-1 and top-5
229
+ n_elements_total, n_sequences_total, total_loss = 0, 0, 0.
230
+ total_accs = {"l1_v": 0., "l1_a": 0., "l1_mean": 0., "l1_mean_normal":0
231
+ } if args.regression else {k: 0. for k in topk}
232
+ with torch.no_grad():
233
+ n_batches = len(self.test_loader)
234
+ loader = enumerate(self.test_loader)
235
+ if args.exhaustive_eval or args.regression:
236
+ loader = tqdm(loader, total=n_batches)
237
+ for i, (input_, condition, target) in loader:
238
+ if args.max_eval_step > 0 and i >= args.max_eval_step:
239
+ break
240
+ if input_ != []:
241
+ input_ = input_.to(self.device)
242
+ condition = condition.to(self.device)
243
+ if not args.regression:
244
+ target = target.to(self.device)
245
+ loss, pred = self.forward_pass(input_, condition, target)
246
+ if args.regression:
247
+ pred = torch.clamp(pred, min=-1.0, max=1.0)
248
+ loss = self.l1_loss(pred, condition)
249
+ l1_v = self.l1_loss(pred[:, 0], condition[:, 0]).item()
250
+ l1_a = self.l1_loss(pred[:, 1], condition[:, 1]).item()
251
+ accuracies = {"l1_v": l1_v, "l1_a": l1_a,
252
+ "l1_mean": (l1_v + l1_a) / 2,
253
+ "l1_mean_normal": (l1_v + l1_a) / 2 / 2}
254
+ n_elements = pred[:, 0].numel()
255
+ else:
256
+ accuracies = accuracy(pred, target, topk=topk, ignore_index=self.pad_idx)
257
+ n_elements = input_.numel()
258
+ n_sequences = input_.size(0)
259
+ total_loss += n_elements * loss.item()
260
+ for key, value in accuracies.items():
261
+ total_accs[key] += n_elements * value
262
+ n_elements_total += n_elements
263
+ n_sequences_total += n_sequences
264
+
265
+ if n_elements_total == 0:
266
+ avg_loss = float('nan')
267
+ avg_accs = float('nan')
268
+ else:
269
+ avg_loss = total_loss / n_elements_total
270
+ avg_accs = {k: v/n_elements_total for k, v in total_accs.items()}
271
+ if args.exhaustive_eval:
272
+ print(f"Total number of sequences: {n_sequences_total}")
273
+
274
+ return avg_loss, avg_accs
275
+
276
+ def forward_pass(self, input_, condition, target):
277
+
278
+ input_ = input_.to(self.device)
279
+ condition = condition.to(self.device)
280
+
281
+ with torch.cuda.amp.autocast(enabled=self.amp):
282
+ if args.regression:
283
+ output = self.model(input_)
284
+ loss = self.l1_loss(output, condition)
285
+ else:
286
+ target = target.to(self.device)
287
+ output = self.model(input_, condition)
288
+ output_flat = output.reshape(-1, output.size(-1))
289
+ target = target.reshape(-1)
290
+ loss = self.ce_loss(output_flat, target)
291
+
292
+ return loss, output
293
+
294
+ def train(self):
295
+ # Turn on training mode which enables dropout.
296
+ self.model.train()
297
+
298
+ train_loss = 0
299
+ n_elements_total = 0
300
+ train_interval_start = time.time()
301
+
302
+ while True:
303
+ for input_, condition, target in self.train_loader:
304
+ self.model.train()
305
+ if input_ != []:
306
+
307
+ loss, _ = self.forward_pass(input_, condition, target)
308
+ loss_val = loss.item()
309
+ loss /= args.accumulate_step
310
+
311
+ n_elements = input_.numel()
312
+ if not math.isnan(loss_val):
313
+ train_loss += n_elements * loss_val
314
+ n_elements_total += n_elements
315
+ self.n_sequences_total += input_.size(0)
316
+
317
+ self.scaler.scale(loss).backward()
318
+
319
+ if self.train_step % args.accumulate_step == 0:
320
+ self.scaler.unscale_(self.optimizer)
321
+ if args.clip > 0:
322
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
323
+ self.scaler.step(self.optimizer)
324
+ self.scaler.update()
325
+ self.model.zero_grad()
326
+
327
+ if args.scheduler != "constant":
328
+ # linear warmup stage
329
+ if self.train_step <= args.warmup_step:
330
+ curr_lr = args.lr * self.train_step / args.warmup_step
331
+ self.optimizer.param_groups[0]['lr'] = curr_lr
332
+ else:
333
+ self.scheduler.step()
334
+
335
+ if (self.train_step % args.gen_step == 0) and self.train_step > 0 and not args.regression:
336
+ # Generate and save samples
337
+ with torch.no_grad():
338
+ self.model.eval()
339
+ if args.max_gen_input_len > 0:
340
+ max_input_len = args.max_gen_input_len
341
+ else:
342
+ max_input_len = args.tgt_len
343
+
344
+ primers = [["<START>"]]
345
+ # Use fixed set of conditions
346
+ if args.conditioning == "none":
347
+ discrete_conditions = None
348
+ continuous_conditions = None
349
+ primers = [["<START>"] for _ in range(4)]
350
+
351
+ elif args.conditioning == "discrete_token":
352
+ discrete_conditions = [
353
+ ["<V-2>", "<A-2>"],
354
+ ["<V-2>", "<A2>"],
355
+ ["<V2>", "<A-2>"],
356
+ ["<V2>", "<A2>"],
357
+ ]
358
+ continuous_conditions = None
359
+ elif args.conditioning in ["continuous_token", "continuous_concat"]:
360
+ discrete_conditions = None
361
+ continuous_conditions = [
362
+ [-0.8, -0.8],
363
+ [-0.8, 0.8],
364
+ [0.8, -0.8],
365
+ [0.8, 0.8]
366
+ ]
367
+
368
+ generate(self.model, self.maps, self.device, self.gen_dir, args.conditioning,
369
+ debug=args.debug, verbose=False, amp=self.amp, discrete_conditions=discrete_conditions,
370
+ continuous_conditions=continuous_conditions, min_n_instruments=1,
371
+ gen_len=args.gen_len, max_input_len=max_input_len,
372
+ step=str(self.train_step), primers=primers,
373
+ temperatures=[args.temp_note, args.temp_rest])
374
+
375
+ if (self.train_step % args.log_step == 0):
376
+ # Print log
377
+ if n_elements_total > 0:
378
+ cur_loss = train_loss / n_elements_total
379
+ elapsed_total = time.time() - self.init_time
380
+ elapsed_interval = time.time() - train_interval_start
381
+ hours_elapsed = elapsed_total / 3600.0
382
+ hours_total = self.init_hours + hours_elapsed
383
+ lr = self.optimizer.param_groups[0]['lr']
384
+ log_str = '| Epoch {:3d} step {:>8d} | {:>6d} sequences | {:>3.1f} h | lr {:.2e} ' \
385
+ '| ms/batch {:4.0f} | loss {:7.4f}'.format(
386
+ self.epoch, self.train_step, self.n_sequences_total, hours_total, lr,
387
+ elapsed_interval * 1000 / args.log_step, cur_loss)
388
+ self.logging(log_str)
389
+ self.csv_writer.update({"epoch": self.epoch, "step": self.train_step, "hour": hours_total,
390
+ "lr": lr, "trn_loss": cur_loss, "val_loss": np.nan,
391
+ "val_l1_v": np.nan, "val_l1_a": np.nan})
392
+ train_loss = 0
393
+ n_elements_total = 0
394
+ self.n_good_output, self.n_nan_output = 0, 0
395
+ train_interval_start = time.time()
396
+
397
+ if not args.debug:
398
+ # Save model
399
+ model_fp = os.path.join(args.work_dir, 'model.pt')
400
+ torch.save(self.model.state_dict(), model_fp)
401
+ optimizer_fp = os.path.join(args.work_dir, 'optimizer.pt')
402
+ torch.save(self.optimizer.state_dict(), optimizer_fp)
403
+ scaler_fp = os.path.join(args.work_dir, 'scaler.pt')
404
+ torch.save(self.scaler.state_dict(), scaler_fp)
405
+ torch.save({"step": self.train_step, "hour": hours_total, "epoch": self.epoch,
406
+ "sample": self.n_sequences_total},
407
+ os.path.join(args.work_dir, 'stats.pt'))
408
+
409
+ if (self.train_step % args.eval_step == 0):
410
+ # Evaluate model
411
+ val_loss, val_acc = self.evaluate()
412
+ elapsed_total = time.time() - self.init_time
413
+ hours_elapsed = elapsed_total / 3600.0
414
+ hours_total = self.init_hours + hours_elapsed
415
+ lr = self.optimizer.param_groups[0]['lr']
416
+ self.logging('-' * 120)
417
+ log_str = '| Eval {:3d} step {:>8d} | now: {} | {:>3.1f} h' \
418
+ '| valid loss {:7.4f} | ppl {:5.3f}'.format(
419
+ self.train_step // args.eval_step, self.train_step,
420
+ time.strftime("%d-%m - %H:%M"), hours_total,
421
+ val_loss, math.exp(val_loss))
422
+ if args.regression:
423
+ log_str += " | l1_v: {:5.3f} | l1_a: {:5.3f}".format(
424
+ val_acc["l1_v"], val_acc["l1_a"])
425
+
426
+ self.csv_writer.update({"epoch": self.epoch, "step": self.train_step, "hour": hours_total,
427
+ "lr": lr, "trn_loss": np.nan, "val_loss": val_loss})
428
+
429
+ self.logging(log_str)
430
+ self.logging('-' * 120)
431
+
432
+ # dev-performance based learning rate annealing
433
+ if args.scheduler == 'dev_perf':
434
+ self.scheduler.step(val_loss)
435
+
436
+ if self.train_step >= args.max_step:
437
+ break
438
+ self.train_step += 1
439
+ self.epoch += 1
440
+ if self.train_step >= args.max_step:
441
+ break
442
+
443
+ def run(self):
444
+
445
+ # Loop over epochs.
446
+ # At any point you can hit Ctrl + C to break out of training early.
447
+ try:
448
+ if args.exhaustive_eval or args.regression_dir is not None:
449
+ self.logging("Exhaustive evaluation")
450
+ if args.regression_dir is not None:
451
+ self.logging(f"For regression on folder {args.regression_dir}")
452
+ loss, accuracies = self.evaluate()
453
+ perplexity = math.exp(loss)
454
+ elapsed_total = time.time() - self.init_time
455
+ hours_elapsed = elapsed_total / 3600.0
456
+ msg = f"Loss: {loss:7.4f}, ppl: {perplexity:5.2f}"
457
+ for k, v in accuracies.items():
458
+ if args.regression:
459
+ msg += f", {k}: {v:7.4f}"
460
+ else:
461
+ msg += f", top{k:1.0f}: {v:7.4f}"
462
+ msg += f", hours: {hours_elapsed:3.1f}"
463
+ self.logging(msg)
464
+ else:
465
+ while True:
466
+ self.train()
467
+ if self.train_step >= args.max_step:
468
+ self.logging('-' * 120)
469
+ self.logging('End of training')
470
+ break
471
+ except KeyboardInterrupt:
472
+ self.logging('-' * 120)
473
+ self.logging('Exiting from training early')
474
+
475
+ if __name__ == "__main__":
476
+ runner = Runner()
477
+ runner.run()
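
Runner.train() above interleaves mixed-precision loss scaling, gradient accumulation, gradient clipping, and a linear learning-rate warmup inside a single update step. The sketch below is a minimal, self-contained illustration of that update pattern; the toy model, synthetic batches, and hyperparameter values are stand-ins for illustration only, not the project's actual configuration.

```python
import torch
from torch import nn, optim

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Toy stand-ins (assumptions for illustration only).
model = nn.Linear(16, 4).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
criterion = nn.CrossEntropyLoss()
accumulate_step, warmup_step, base_lr, clip = 4, 100, 1e-4, 1.0

step = 1
for _ in range(8):  # stand-in for iterating over a DataLoader
    x = torch.randn(32, 16, device=device)
    y = torch.randint(0, 4, (32,), device=device)

    # Forward/backward under autocast; divide the loss so accumulated gradients average out.
    with torch.cuda.amp.autocast(enabled=use_cuda):
        loss = criterion(model(x), y) / accumulate_step
    scaler.scale(loss).backward()

    if step % accumulate_step == 0:
        scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(optimizer)
        scaler.update()
        model.zero_grad()

    # Linear warmup: ramp the learning rate over the first warmup_step updates.
    if step <= warmup_step:
        optimizer.param_groups[0]["lr"] = base_lr * step / warmup_step

    step += 1
```

The ordering matters: `scaler.unscale_(optimizer)` must precede `clip_grad_norm_` so clipping operates on unscaled gradients, which is the same call sequence used in Runner.train().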
midi_emotion/src/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+ import csv
3
+ import shutil
4
+ import functools
5
+ import os
6
+
7
+
8
+ def split_list(alist, n_parts):
9
+ if n_parts == 0:
10
+ n_parts = 1
11
+ length = len(alist)
12
+ return [ alist[i*length // n_parts: (i+1)*length // n_parts]
13
+ for i in range(n_parts)]
14
+
15
+ def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1, 5), ignore_index=None):
16
+ """
17
+ Computes the accuracy over the k top predictions for the specified values of k
18
+ In top-5 accuracy you give yourself credit for having the right answer
19
+ if the right answer appears in your top five guesses.
20
+
21
+ ref:
22
+ - https://discuss.pytorch.org/t/top-k-error-calculation/48815/3
23
+
24
+ - https://pytorch.org/docs/stable/generated/torch.topk.html
25
+ - https://discuss.pytorch.org/t/imagenet-example-accuracy-calculation/7840
26
+ - https://gist.github.com/weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b
27
+ - https://discuss.pytorch.org/t/top-k-error-calculation/48815/2
28
+ - https://stackoverflow.com/questions/59474987/how-to-get-top-k-accuracy-in-semantic-segmentation-using-pytorch
29
+
30
+ :param output: output is the prediction of the model e.g. scores, logits, raw y_pred before normalization or getting classes
31
+ :param target: target is the truth
32
+ :param topk: tuple of topk's to compute e.g. (1, 2, 5) computes top 1, top 2 and top 5.
33
+ e.g. in top 2 you get a +1 if the model's top 2 predictions include the right label.
34
+ So if your model predicts cat, dog (0, 1) and the true label was bird (3) you get zero
35
+ but if it were either cat or dog you'd accumulate +1 for that example.
36
+ :return: dict mapping each k in topk to its top-k accuracy, e.g. {1: acc@1, 5: acc@5}
37
+ """
38
+ with torch.no_grad():
39
+ # ---- get the topk most likely labels according to your model
40
+ # get the largest k \in [n_classes] (i.e. the number of most likely probabilities we will use)
41
+
42
+ maxk = max(topk) # max number of labels we will consider in the right choices for our model
43
+
44
+ output = output.reshape(-1, output.size(-1))
45
+ target = target.reshape(-1)
46
+
47
+ valid_inds = torch.where(target != ignore_index)[0]
48
+ target = target[valid_inds]
49
+ output = output[valid_inds, :]
50
+
51
+ sample_size = target.size(0)
52
+
53
+ # get top maxk indices that correspond to the most likely probability scores
54
+ # (note _ means we don't care about the actual top maxk scores, just their corresponding indices/labels)
55
+ _, y_pred = output.topk(k=maxk, dim=-1) # _, [B, n_classes] -> [B, maxk]
56
+ y_pred = y_pred.t() # [B, maxk] -> [maxk, B] Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.
57
+
58
+ # - give credit for each example if the model's prediction is in the top maxk values (main crux of code)
59
+ # for any example, the model will get credit if its prediction matches the ground truth
60
+ # for each example we compare if the model's best prediction matches the truth. If yes we get an entry of 1.
61
+ # if the k'th top answer of the model matches the truth we get 1.
62
+ # Note: for any example in the batch we can only ever get 1 match (so we never overestimate accuracy)
63
+ target_reshaped = target.view(1, -1).expand_as(y_pred) # [B] -> [B, 1] -> [maxk, B]
64
+ # compare every topk's model prediction with the ground truth & give credit if any matches the ground truth
65
+ correct = (y_pred == target_reshaped) # [maxk, B] where for each example we know which topk prediction matched the truth
66
+ # original: correct = pred.eq(target.view(1, -1).expand_as(pred))
67
+
68
+ # -- get topk accuracy
69
+ list_topk_accs = {}
70
+ for k in topk:
71
+ # get tensor of which topk answer was right
72
+ ind_which_topk_matched_truth = correct[:k] # [maxk, B] -> [k, B]
73
+ # flatten it to help compute if we got it correct for each example in batch
74
+ flattened_indicator_which_topk_matched_truth = ind_which_topk_matched_truth.reshape(-1).float() # [k, B] -> [kB]
75
+ # get if we got it right for any of our top k prediction for each example in batch
76
+ tot_correct_topk = flattened_indicator_which_topk_matched_truth.float().sum(dim=0, keepdim=True) # [kB] -> [1]
77
+ # compute topk accuracy - the model's ability to get it right within its top k guesses/preds
78
+ topk_acc = tot_correct_topk / sample_size # topk accuracy for entire batch
79
+ list_topk_accs[k] = topk_acc.item()
80
+ return list_topk_accs # dict of topk accuracies for the entire batch, keyed by k
81
+
82
+ class CsvWriter:
83
+ # Save performance as a csv file
84
+ def __init__(self, out_path, fieldnames, in_path=None, debug=False):
85
+
86
+ self.out_path = out_path
87
+ self.fieldnames = fieldnames
88
+ self.debug = debug
89
+
90
+ if not debug:
91
+ if in_path is None:
92
+ with open(out_path, "w") as f:
93
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
94
+ writer.writeheader()
95
+ else:
96
+ try:
97
+ shutil.copy(in_path, out_path)
98
+ except Exception:
99
+ with open(out_path, "w") as f:
100
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
101
+ writer.writeheader()
102
+
103
+
104
+ def update(self, performance_dict):
105
+ if not self.debug:
106
+ with open(self.out_path, "a") as f:
107
+ writer = csv.DictWriter(f, fieldnames=self.fieldnames)
108
+ writer.writerow(performance_dict)
109
+ a = 0
110
+
111
+ def generate_square_subsequent_mask(sz):
112
+ # Triangular mask to avoid looking at future tokens
113
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
114
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
115
+ return mask
116
+
117
+
118
+ def logging(s, log_path, print_=True, log_=True):
119
+ # Prints log
120
+ if print_:
121
+ print(s)
122
+ if log_:
123
+ with open(log_path, 'a+') as f_log:
124
+ f_log.write(s + '\n')
125
+
126
+ def get_logger(log_path, **kwargs):
127
+ return functools.partial(logging, log_path=log_path, **kwargs)
128
+
129
+ def create_exp_dir(dir_path, debug=False):
130
+ # Create experiment directory
131
+ if debug:
132
+ print('Debug Mode : no experiment dir created')
133
+ return functools.partial(logging, log_path=None, log_=False)
134
+ else:
135
+ if not os.path.exists(dir_path):
136
+ os.makedirs(dir_path)
137
+
138
+ print('Experiment dir : {}'.format(dir_path))
139
+
140
+ return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
141
+
142
+
143
+ def get_n_instruments(symbols):
144
+ # Find number of instruments
145
+ symbols_split = [s.split("_") for s in symbols]
146
+ symbols_split = [s[1] for s in symbols_split if len(s) == 3]
147
+ events = list(set(symbols_split))
148
+ return len(events)
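
For reference, a short sketch of how the two most reusable helpers in utils.py might be exercised. The import path, tensor shapes, and the -100 padding index are assumptions for illustration.

```python
import torch
from midi_emotion.src.utils import accuracy, generate_square_subsequent_mask

# Fake logits for a batch of 3 sequences, 7 time steps, 10 classes (illustrative shapes).
logits = torch.randn(3, 7, 10)
targets = torch.randint(0, 10, (3, 7))
targets[:, -2:] = -100  # pretend the last two positions of each sequence are padding

accs = accuracy(logits, targets, topk=(1, 5), ignore_index=-100)
print(accs)  # e.g. {1: 0.13, 5: 0.53} -- padded positions are excluded

# Causal mask for a transformer decoder: position i may only attend to positions <= i.
mask = generate_square_subsequent_mask(5)
print(mask)  # 0.0 on and below the diagonal, -inf above it
```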
packages.txt ADDED
@@ -0,0 +1,2 @@
1
+ fluidsynth
2
+ fluid-soundfont-gm
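
packages.txt lists the system (apt) packages a Hugging Face Space installs at build time: fluidsynth provides the synthesizer and fluid-soundfont-gm a General MIDI soundfont, which together allow generated .mid files to be rendered to audio. A minimal sketch of that rendering step, assuming the usual Debian soundfont path and hypothetical file names:

```python
from midi2audio import FluidSynth

# Assumed locations for illustration; fluid-soundfont-gm normally installs its
# soundfont at this path on Debian/Ubuntu-based images.
SOUNDFONT = "/usr/share/sounds/sf2/FluidR3_GM.sf2"

fs = FluidSynth(sound_font=SOUNDFONT, sample_rate=44100)
fs.midi_to_audio("output/generated.mid", "output/generated.wav")
```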
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ numpy>=1.24.0
4
+ matplotlib>=3.7.0
5
+ Pillow>=10.0.0
6
+ huggingface-hub>=0.19.0
7
+ pretty-midi>=0.2.10
8
+ librosa>=0.10.0
9
+ soundfile>=0.12.0