projectlosangeles committed
Commit f6d08a7 · verified · 1 Parent(s): 9bade1b

Upload app.py

Files changed (1)
  1. app.py +486 -0
app.py ADDED
@@ -0,0 +1,486 @@
#==================================================================================
# https://huggingface.co/spaces/projectlosangeles/MuseCraft-Piano-Chords-Texturing
#==================================================================================

print('=' * 70)
print('MuseCraft Piano Chords Texturing Gradio App')

print('=' * 70)
print('Loading core MuseCraft Piano Chords Texturing modules...')

import os
import copy

import time as reqtime
import datetime
from pytz import timezone

print('=' * 70)
print('Loading main MuseCraft Piano Chords Texturing modules...')

os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

torch.set_float32_matmul_precision('medium')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)
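
# The settings above relax matmul precision to TF32 and enable every scaled-dot-product
# attention backend (memory-efficient, math, flash and cuDNN), letting PyTorch pick the
# fastest available kernel at run time.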

from huggingface_hub import hf_hub_download

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

from x_transformer_2_3_1 import *

import random

import tqdm

print('=' * 70)
print('Loading aux MuseCraft Piano Chords Texturing modules...')

import matplotlib.pyplot as plt

import gradio as gr
import spaces

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)

#==================================================================================

MODEL_CHECKPOINT = 'Godzilla_Piano_Chords_Texturing_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'

SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'

MAX_MELODY_NOTES = 64

MAX_GEN_TOKS = 3072
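
# MAX_MELODY_NOTES caps how many chordified melody events of the input MIDI are used
# as the seed (see load_midi below); MAX_GEN_TOKS caps the number of tokens generated
# per request (see generate_full_seq below).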

#==================================================================================

print('=' * 70)
print('Loading popular hook melodies dataset...')

popular_hook_melodies_pickle = hf_hub_download(repo_id='projectlosangeles/MuseCraft-Piano-Chords-Texturing',
                                               filename='popular_hook_melodies_24_64_CC_BY_NC_SA.pickle'
                                               )

popular_hook_melodies = TMIDIX.Tegridy_Any_Pickle_File_Reader(popular_hook_melodies_pickle)

print('=' * 70)
print('Done!')
print('=' * 70)

#==================================================================================

print('=' * 70)
print('Instantiating model...')

device_type = 'cuda'
dtype = 'bfloat16'

ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 4096
PAD_IDX = 1794

model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = 2048,
                              depth = 4,
                              heads = 32,
                              rotary_pos_emb = True,
                              attn_flash = True
                              )
        )

model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

model_checkpoint = hf_hub_download(repo_id='projectlosangeles/MuseCraft-Piano-Chords-Texturing', filename=MODEL_CHECKPOINT)

model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

model = torch.compile(model, mode='max-autotune')

print('=' * 70)
print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)

#==================================================================================

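# Apparent token layout, inferred from load_midi below and the decoding loop in
# Generate_Accompaniment (a sketch of the vocabulary, not an official spec):
#
#   0    .. 127   delta start-times (in 32-unit steps of the ms score, see timings_divider=32)
#   128  .. 255   note duration + 128
#   256  .. 1791  channel * 128 + pitch + 256 (the seed melody is encoded on channel 0 only)
#   1792          seed/melody start marker
#   1793          melody-to-accompaniment separator
#   1794          PAD_IDX (padding / ignore index)
#
# A seed melody is therefore a sequence of [dtime, dur+128, pitch+256] triplets
# framed by 1792 ... 1793.
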
def load_midi(input_midi, melody_patch=-1, use_nth_note=1):

    raw_score = TMIDIX.midi2single_track_ms_score(input_midi)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)

    sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)

    if melody_patch == -1:
        zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)

    else:
        mel_score = [e for e in sp_escore_notes if e[6] == melody_patch]

        if mel_score:
            zscore = TMIDIX.recalculate_score_timings(mel_score)

        else:
            zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)

    cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES:use_nth_note]

    score = []

    score_list = []

    pc = cscore[0]

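    # For every chordified event, the loop below emits one [dtime, dur+128, pitch+256]
    # triplet built from that event's first note (c[0]); dtime is the start-time delta
    # from the previous event, clipped to 0..127.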
    for c in cscore:
        score.append(max(0, min(127, c[0][1]-pc[0][1])))

        scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]

        n = c[0]

        score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
        scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])

        score_list.append(scl)

        pc = c

    score_list.append(scl)

    return score, score_list

#==================================================================================

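# @spaces.GPU requests a GPU worker for the duration of each call on Hugging Face
# ZeroGPU Spaces (assumption: this Space runs on ZeroGPU hardware; on other hardware
# the decorator is effectively a no-op).
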
@spaces.GPU
def Generate_Accompaniment(input_midi,
                           input_melody,
                           melody_patch,
                           use_nth_note,
                           model_temperature,
                           model_sampling_top_k
                           ):

    #===============================================================================

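    # generate_full_seq autoregressively extends the seed one token at a time and keeps
    # a running total of the time tokens (< 128) it emits; it stops once that total
    # exceeds the seed melody's absolute run time (plus a 32-step margin) or once
    # max_toks tokens have been generated.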
    def generate_full_seq(input_seq,
                          max_toks=3072,
                          temperature=0.9,
                          top_k_value=15,
                          verbose=True
                          ):

        seq_abs_run_time = sum([t for t in input_seq if t < 128])

        cur_time = 0

        full_seq = copy.deepcopy(input_seq)

        toks_counter = 0

        while cur_time <= seq_abs_run_time+32:

            if verbose:
                if toks_counter % 128 == 0:
                    print('Generated', toks_counter, 'tokens')

            x = torch.LongTensor(full_seq).cuda()

            with ctx:
                out = model.generate(x,
                                     1,
                                     filter_logits_fn=top_k,
                                     filter_kwargs={'k': top_k_value},
                                     temperature=temperature,
                                     return_prime=False,
                                     verbose=False)

            y = out.tolist()[0][0]

            if y < 128:
                cur_time += y

            full_seq.append(y)

            toks_counter += 1

            if toks_counter == max_toks:
                return full_seq

        return full_seq

    #===============================================================================

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    if input_midi:
        fn = os.path.basename(input_midi)
        fn1 = fn.split('.')[0]
        print('Input MIDI file name:', fn)

    else:
        print('Input sample melody:', input_melody)
    print('Source melody patch:', melody_patch)
    print('Use nth melody note:', use_nth_note)
    print('Model temperature:', model_temperature)
    print('Model top k:', model_sampling_top_k)

    print('=' * 70)

    #==================================================================

    print('Prepping melody...')

    if input_midi:
        inp_mel = 'Custom MIDI'
        score, score_list = load_midi(input_midi.name, melody_patch, use_nth_note)

    else:
        mel_list = [m[0].lower() for m in popular_hook_melodies]

        inp_mel = random.choice(mel_list).title()

        for m in mel_list:
            if input_melody.lower().strip() in m:
                inp_mel = m.title()
                break

        score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
        score_list = [[[score[i]], score[i+1:i+3]] for i in range(0, len(score)-3, 3)]

    print('Selected melody:', inp_mel)

    print('Sample score events', score[:12])

    #==================================================================

    print('=' * 70)
    print('Generating...')

    model.to(device_type)
    model.eval()

    #==================================================================

    start_score_seq = [1792] + score + [1793]
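
    # 1792 and 1793 appear to act as seed-start and seed-end/separator tokens: the model
    # is primed with the wrapped melody, and everything generated after the prime is kept
    # as the textured accompaniment (the prime is stripped off below via final_song).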

    #==================================================================

    input_seq = generate_full_seq(start_score_seq,
                                  max_toks=MAX_GEN_TOKS,
                                  temperature=model_temperature,
                                  top_k_value=model_sampling_top_k,
                                  )

    final_song = input_seq[len(start_score_seq):]

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================

    print('Rendering results...')

    print('=' * 70)
    print('Sample INTs', final_song[:15])
    print('=' * 70)

    song_f = []

    if len(final_song) != 0:

        time = 0
        dur = 0
        vel = 90
        pitch = 0
        channel = 0
        patch = 0

        channels_map = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 9, 12, 13, 14, 15]
        patches_map = [40, 0, 10, 19, 24, 35, 40, 52, 56, 9, 65, 73, 0, 0, 0, 0]
        velocities_map = [125, 80, 100, 80, 90, 100, 100, 80, 110, 110, 110, 110, 80, 80, 80, 80]

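        # The loop below decodes the generated tokens back into note events: time tokens
        # (< 128) advance the running start time, 128..255 set the current duration, and
        # 256..1791 encode channel * 128 + pitch; channels_map/patches_map/velocities_map
        # then assign each decoded channel a MIDI channel, patch and velocity. Times and
        # durations are scaled back by 32 (assuming the 32x timings_divider used in
        # load_midi) before being written out as an ms-resolution song.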
        for m in final_song:

            if 0 <= m < 128:
                time += m * 32

            elif 128 < m < 256:
                dur = (m-128) * 32

            elif 256 < m < 1792:
                cha = (m-256) // 128
                pitch = (m-256) % 128

                channel = channels_map[cha]
                patch = patches_map[channel]
                vel = velocities_map[channel]

                song_f.append(['note', time, dur, channel, pitch, vel, patch])

    fn1 = "MuseCraft-Piano-Chords-Texturing-Composition"

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'MuseCraft Piano Chords Texturing',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches_map
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=SOUNDFONT_PATH,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_title = str(inp_mel)
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI melody title:', output_title)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_title, output_audio, output_plot, output_midi

#==================================================================================

PDT = timezone('US/Pacific')

print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
print('=' * 70)

#==================================================================================

with gr.Blocks() as demo:

    #==================================================================================

    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MuseCraft Piano Chords Texturing</h1>")
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Solo Piano chords texturing model for MuseCraft project</h1>")
    gr.HTML("""
            <p>
                <a href="https://huggingface.co/spaces/projectlosangeles/MuseCraft-Piano-Chords-Texturing?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
                </a>
            </p>

            for faster execution and endless generation!
            """)

    #==================================================================================

    gr.Markdown("## Upload source melody MIDI or enter a search query for a sample melody below")

    input_midi = gr.File(label="Input MIDI",
                         file_types=[".midi", ".mid", ".kar"]
                         )

    input_melody = gr.Textbox(value="Hotel California",
                              label="Popular melodies database search query",
                              info='If the query is not found, a random melody will be selected. A custom MIDI file overrides the search query.'
                              )

    gr.Markdown("## Generation options")

    melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
    use_nth_note = gr.Slider(1, 8, value=1, step=1, label="Use each nth melody note")
    model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
    model_sampling_top_k = gr.Slider(1, 100, value=15, step=1, label="Model sampling top k value")

    generate_btn = gr.Button("Generate", variant="primary")

    gr.Markdown("## Generation results")

    output_title = gr.Textbox(label="MIDI melody title")
    output_audio = gr.Audio(label="MIDI audio", format="wav", elem_id="midi_audio")
    output_plot = gr.Plot(label="MIDI score plot")
    output_midi = gr.File(label="MIDI file", file_types=[".mid"])

    generate_btn.click(Generate_Accompaniment,
                       [input_midi,
                        input_melody,
                        melody_patch,
                        use_nth_note,
                        model_temperature,
                        model_sampling_top_k
                        ],
                       [output_title,
                        output_audio,
                        output_plot,
                        output_midi
                        ]
                       )

    gr.Examples(
        [["Sharing The Night Together.kar", "Custom MIDI", -1, 1, 0.9, 15]
         ],
        [input_midi,
         input_melody,
         melody_patch,
         use_nth_note,
         model_temperature,
         model_sampling_top_k
         ],
        [output_title,
         output_audio,
         output_plot,
         output_midi
         ],
        Generate_Accompaniment
        )

#==================================================================================

demo.launch()

#==================================================================================