asigalov61 commited on
Commit
12a4852
·
verified ·
1 Parent(s): 51677ef

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +330 -517
app.py CHANGED
@@ -1,613 +1,426 @@
1
- #==================================================================================
2
- # https://huggingface.co/spaces/asigalov61/Godzilla-Piano-Transformer
3
- #==================================================================================
4
-
5
- print('=' * 70)
6
- print('Godzilla Piano Transformer Gradio App')
7
-
8
- print('=' * 70)
9
- print('Loading core Godzilla Piano Transformer modules...')
10
 
11
  import os
12
-
13
  import time as reqtime
14
  import datetime
15
  from pytz import timezone
16
 
17
- print('=' * 70)
18
- print('Loading main Godzilla Piano Transformer modules...')
19
-
20
- os.environ['USE_FLASH_ATTENTION'] = '1'
21
-
22
  import torch
23
-
24
- torch.set_float32_matmul_precision('high')
25
- torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
26
- torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
27
- torch.backends.cuda.enable_mem_efficient_sdp(True)
28
- torch.backends.cuda.enable_math_sdp(True)
29
- torch.backends.cuda.enable_flash_sdp(True)
30
- torch.backends.cuda.enable_cudnn_sdp(True)
31
-
32
- from huggingface_hub import hf_hub_download
33
-
34
- import TMIDIX
35
-
36
- from midi_to_colab_audio import midi_to_colab_audio
37
-
38
- from x_transformer_2_3_1 import *
39
-
40
- import random
41
-
42
- print('=' * 70)
43
- print('Loading aux Godzilla Piano Transformer modules...')
44
-
45
  import matplotlib.pyplot as plt
46
-
47
  import gradio as gr
48
  import spaces
49
 
50
- print('=' * 70)
51
- print('PyTorch version:', torch.__version__)
52
- print('=' * 70)
53
- print('Done!')
54
- print('Enjoy! :)')
55
- print('=' * 70)
56
-
57
- #==================================================================================
58
 
59
- MODEL_CHECKPOINTS = {
60
- 'without velocity - 3 epochs': 'Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14075_steps_0.4534_loss_0.8687_acc.pth'
61
- }
 
 
62
 
 
63
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
64
-
65
  NUM_OUT_BATCHES = 12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- PREVIEW_LENGTH = 120 # in tokens
 
 
 
 
 
 
68
 
69
- #==================================================================================
 
 
 
70
 
71
- print('=' * 70)
72
- print('Instantiating model...')
 
 
 
73
 
74
  device_type = 'cuda'
75
  dtype = 'bfloat16'
76
-
77
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
78
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
79
 
80
  SEQ_LEN = 4096
81
-
82
  PAD_IDX = 384
83
 
84
  model = TransformerWrapper(
85
- num_tokens = PAD_IDX+1,
86
- max_seq_len = SEQ_LEN,
87
- attn_layers = Decoder(dim = 2048,
88
- depth = 16,
89
- heads = 32,
90
- rotary_pos_emb = True,
91
- attn_flash = True
92
- )
 
93
  )
94
-
95
  model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
96
 
97
- print('=' * 70)
98
- print('Loading model checkpoint...')
99
-
100
- model_checkpoint = hf_hub_download(repo_id='asigalov61/Godzilla-Piano-Transformer',
101
- filename='Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14075_steps_0.4534_loss_0.8687_acc.pth')
102
-
103
- model.load_state_dict(torch.load(model_checkpoint, map_location='cuda', weights_only=True))
104
-
105
  model = torch.compile(model, mode='max-autotune')
106
-
107
- print('=' * 70)
108
- print('Done!')
109
- print('=' * 70)
110
- print('Model will use', dtype, 'precision...')
111
- print('=' * 70)
112
 
113
  model.cuda()
114
  model.eval()
115
 
116
- #==================================================================================
117
-
118
- def load_model(model_selector):
119
-
120
- return [[], []]
121
-
122
- #==================================================================================
123
-
124
- def load_midi(input_midi, model_selector=''):
125
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
127
-
128
- escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)[0]
129
-
130
  sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
131
  zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
132
-
133
  zscore = TMIDIX.augment_enhanced_score_notes(zscore, timings_divider=32)
134
-
135
  fscore = TMIDIX.fix_escore_notes_durations(zscore)
136
-
137
  cscore = TMIDIX.chordify_score([1000, fscore])
138
 
139
  score = []
140
-
141
- pc = cscore[0]
142
-
143
- for c in cscore:
144
- score.append(max(0, min(127, c[0][1]-pc[0][1])))
145
-
146
- for n in c:
147
- if model_selector == 'with velocity - 3 epochs':
148
- score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256, max(1, min(127, n[5]))+384])
149
-
150
- else:
151
- score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
152
-
153
- pc = c
154
-
155
  return score
156
 
157
- #==================================================================================
158
-
159
- def save_midi(tokens, batch_number=None, model_selector=''):
160
-
161
- song = tokens
162
- song_f = []
163
-
164
- time = 0
165
- dur = 0
166
- vel = 90
167
  pitch = 0
168
- channel = 0
169
- patch = 0
170
-
171
  patches = [0] * 16
172
 
173
- for m in song:
174
-
175
- if 0 <= m < 128:
176
- time += m * 32
177
-
178
- elif 128 < m < 256:
179
- dur = (m-128) * 32
180
-
181
- elif 256 < m < 384:
182
- pitch = (m-256)
183
-
184
- if model_selector == 'without velocity - 3 epochs' or model_selector == 'without velocity - 7 epochs':
185
- song_f.append(['note', time, dur, 0, pitch, max(40, pitch), 0])
186
-
187
- elif 384 < m < 512:
188
- vel = (m-384)
189
-
190
- if model_selector == 'with velocity - 3 epochs':
191
- song_f.append(['note', time, dur, 0, pitch, vel, 0])
192
-
193
- if batch_number == None:
194
- fname = 'Godzilla-Piano-Transformer-Music-Composition'
195
-
196
  else:
197
- fname = 'Godzilla-Piano-Transformer-Music-Composition_'+str(batch_number)
198
-
199
- data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
200
- output_signature = 'Godzilla Piano Transformer',
201
- output_file_name = fname,
202
- track_name='Project Los Angeles',
203
- list_of_MIDI_patches=patches,
204
- verbose=False
205
- )
206
-
207
- return song_f
208
-
209
- #==================================================================================
210
-
 
211
  @spaces.GPU
212
- def generate_music(prime,
213
- num_gen_tokens,
214
- num_mem_tokens,
215
- num_gen_batches,
216
- model_temperature,
217
- # model_sampling_top_p,
218
- model_state
219
- ):
220
-
221
- if not prime:
222
- inputs = [0]
223
-
224
- else:
225
- inputs = prime[-num_mem_tokens:]
226
-
227
- print('Generating...')
228
-
229
- inp = [inputs] * num_gen_batches
230
-
231
- inp = torch.LongTensor(inp).cuda()
232
-
233
  with ctx:
234
- out = model.generate(inp,
235
- num_gen_tokens,
236
- #filter_logits_fn=top_p,
237
- #filter_kwargs={'thres': 0.96},
238
- temperature=model_temperature,
239
- return_prime=False,
240
- verbose=False)
241
-
242
- output = out.tolist()
243
-
244
- print('Done!')
245
- print('=' * 70)
246
-
247
- return output
248
-
249
- #==================================================================================
250
-
251
- def generate_callback(input_midi,
252
- num_prime_tokens,
253
- num_gen_tokens,
254
- num_mem_tokens,
255
- model_temperature,
256
- # model_sampling_top_p,
257
- final_composition,
258
- generated_batches,
259
- block_lines,
260
- model_state
261
- ):
262
-
263
- generated_batches = []
264
-
265
  if not final_composition and input_midi is not None:
266
- final_composition = load_midi(input_midi, model_selector=model_state[2])[:num_prime_tokens]
267
- midi_score = save_midi(final_composition, model_selector=model_state[2])
268
- block_lines.append(midi_score[-1][1] / 1000)
269
-
270
- batched_gen_tokens = generate_music(final_composition,
271
- num_gen_tokens,
272
- num_mem_tokens,
273
- NUM_OUT_BATCHES,
274
- model_temperature,
275
- # model_sampling_top_p,
276
- model_state
277
- )
278
-
279
- outputs = []
280
-
281
- for i in range(len(batched_gen_tokens)):
282
-
283
- tokens = batched_gen_tokens[i]
284
-
285
- # Preview
286
- tokens_preview = final_composition[-PREVIEW_LENGTH:]
287
-
288
- # Save MIDI to a temporary file
289
- midi_score = save_midi(tokens_preview + tokens, i, model_selector=model_state[2])
290
-
291
- # MIDI plot
292
-
293
  if len(final_composition) > PREVIEW_LENGTH:
294
- midi_plot = TMIDIX.plot_ms_SONG(midi_score,
295
- plot_title='Batch # ' + str(i),
296
- preview_length_in_notes=int(PREVIEW_LENGTH / 3),
297
- return_plt=True
298
- )
299
-
300
- else:
301
- midi_plot = TMIDIX.plot_ms_SONG(midi_score,
302
- plot_title='Batch # ' + str(i),
303
- return_plt=True
304
- )
305
-
306
- # File name
307
- fname = 'Godzilla-Piano-Transformer-Music-Composition_'+str(i)
308
-
309
- # Save audio to a temporary file
310
- midi_audio = midi_to_colab_audio(fname + '.mid',
311
- soundfont_path=SOUDFONT_PATH,
312
- sample_rate=16000,
313
- output_for_gradio=True
314
- )
315
-
316
- outputs.append([(16000, midi_audio), midi_plot, tokens])
317
-
318
- return outputs, final_composition, generated_batches, block_lines
319
-
320
- #==================================================================================
321
-
322
- def generate_callback_wrapper(input_midi,
323
- num_prime_tokens,
324
- num_gen_tokens,
325
- num_mem_tokens,
326
- model_temperature,
327
- # model_sampling_top_p,
328
- final_composition,
329
- generated_batches,
330
- block_lines,
331
- model_selector,
332
- model_state
333
- ):
334
-
335
- print('=' * 70)
336
- print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
337
- start_time = reqtime.time()
338
-
339
- print('=' * 70)
340
- if input_midi is not None:
341
- fn = os.path.basename(input_midi.name)
342
- fn1 = fn.split('.')[0]
343
- print('Input file name:', fn)
344
-
345
- print('Selected model type:', model_selector)
346
-
347
- if not model_state:
348
- model_state = load_model(model_selector)
349
- model_state.append(model_selector)
350
-
351
- else:
352
- if model_selector != model_state[2]:
353
- print('=' * 70)
354
- print('Switching model...')
355
- model_state = load_model(model_selector)
356
- model_state.append(model_selector)
357
- print('=' * 70)
358
-
359
- print('Num prime tokens:', num_prime_tokens)
360
- print('Num gen tokens:', num_gen_tokens)
361
- print('Num mem tokens:', num_mem_tokens)
362
-
363
- print('Model temp:', model_temperature)
364
- # print('Model top_p:', model_sampling_top_p)
365
- print('=' * 70)
366
-
367
- result = generate_callback(input_midi,
368
- num_prime_tokens,
369
- num_gen_tokens,
370
- num_mem_tokens,
371
- model_temperature,
372
- # model_sampling_top_p,
373
- final_composition,
374
- generated_batches,
375
- block_lines,
376
- model_state
377
- )
378
-
379
- generated_batches = [sublist[-1] for sublist in result[0]]
380
-
381
- print('=' * 70)
382
- print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
383
- print('=' * 70)
384
- print('Req execution time:', (reqtime.time() - start_time), 'sec')
385
- print('*' * 70)
386
-
387
- return tuple([result[1], generated_batches, result[3]] + [item for sublist in result[0] for item in sublist[:-1]] + [model_state])
388
-
389
- #==================================================================================
390
-
391
- def add_batch(batch_number, final_composition, generated_batches, block_lines, model_state=[]):
392
-
393
  if generated_batches:
394
  final_composition.extend(generated_batches[batch_number])
395
-
396
- # Save MIDI to a temporary file
397
- midi_score = save_midi(final_composition, model_selector=model_state[2])
398
-
399
- block_lines.append(midi_score[-1][1] / 1000)
400
-
401
- # MIDI plot
402
- midi_plot = TMIDIX.plot_ms_SONG(midi_score,
403
- plot_title='Godzilla Piano Transformer Composition',
404
- block_lines_times_list=block_lines[:-1],
405
- return_plt=True)
406
-
407
- # File name
408
- fname = 'Godzilla-Piano-Transformer-Music-Composition'
409
-
410
- # Save audio to a temporary file
411
- midi_audio = midi_to_colab_audio(fname + '.mid',
412
- soundfont_path=SOUDFONT_PATH,
413
- sample_rate=16000,
414
- output_for_gradio=True
415
- )
416
-
417
- print('Added batch #', batch_number)
418
- print('=' * 70)
419
-
420
- return (16000, midi_audio), midi_plot, fname+'.mid', final_composition, generated_batches, block_lines
421
-
422
  else:
423
  return None, None, None, [], [], []
424
 
425
- #==================================================================================
426
-
427
- def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines, model_state=[]):
428
-
429
- if final_composition:
430
-
431
- if len(final_composition) > num_tokens:
432
- final_composition = final_composition[:-num_tokens]
433
  block_lines.pop()
434
-
435
- # Save MIDI to a temporary file
436
- midi_score = save_midi(final_composition, model_selector=model_state[2])
437
-
438
- # MIDI plot
439
- midi_plot = TMIDIX.plot_ms_SONG(midi_score,
440
- plot_title='Godzilla Piano Transformer Composition',
441
- block_lines_times_list=block_lines[:-1],
442
- return_plt=True)
443
-
444
- # File name
445
- fname = 'Godzilla-Piano-Transformer-Music-Composition'
446
-
447
- # Save audio to a temporary file
448
- midi_audio = midi_to_colab_audio(fname + '.mid',
449
- soundfont_path=SOUDFONT_PATH,
450
- sample_rate=16000,
451
- output_for_gradio=True
452
- )
453
-
454
- print('Removed batch #', batch_number)
455
- print('=' * 70)
456
-
457
- return (16000, midi_audio), midi_plot, fname+'.mid', final_composition, generated_batches, block_lines
458
-
459
  else:
460
  return None, None, None, [], [], []
461
-
462
- def clear():
463
- return None, None, None, [], []
464
-
465
- #==================================================================================
466
 
467
- def reset(final_composition=[], generated_batches=[], block_lines=[], model_state=[]):
468
-
469
- final_composition = []
470
- generated_batches = []
471
- block_lines = []
472
- model_state = []
473
-
474
- return final_composition, generated_batches, block_lines
475
-
476
- #==================================================================================
477
-
478
- def reset_demo(final_composition=[], generated_batches=[], block_lines=[], model_state=[]):
479
-
480
- final_composition = []
481
- generated_batches = []
482
- block_lines = []
483
- model_state = []
484
-
485
- #==================================================================================
486
-
487
- PDT = timezone('US/Pacific')
488
 
489
- print('=' * 70)
490
- print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
491
- print('=' * 70)
492
 
493
- #==================================================================================
 
 
494
 
 
 
 
495
  with gr.Blocks() as demo:
496
-
497
- #==================================================================================
498
-
499
  demo.load(reset_demo)
500
 
501
- #==================================================================================
502
-
503
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Godzilla Piano Transformer</h1>")
504
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Fast 807M 4k solo Piano music transformer trained on 1.14M+ MIDIs (2.7M+ samples)</h1>")
505
  gr.HTML("""
506
- Check out <a href="https://huggingface.co/datasets/asigalov61/Godzilla-Piano">Godzilla Piano dataset</a> on Hugging Face
507
-
508
- <p>
509
- <a href="https://huggingface.co/spaces/asigalov61/Godzilla-Piano-Transformer?duplicate=true">
510
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
511
- </a>
512
- </p>
513
-
514
- for faster execution and endless generation!
515
- """)
516
-
517
- #==================================================================================
518
-
519
  final_composition = gr.State([])
520
  generated_batches = gr.State([])
521
  block_lines = gr.State([])
522
- model_state = gr.State([])
523
-
524
- #==================================================================================
525
-
526
- gr.Markdown("## Upload seed MIDI or click 'Generate' button for random output")
527
-
528
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
529
- input_midi.upload(reset, [final_composition, generated_batches, block_lines],
530
- [final_composition, generated_batches, block_lines])
531
-
532
- gr.Markdown("## Generate")
533
 
534
- model_selector = gr.Dropdown(["without velocity - 3 epochs"],
535
- label="Select model",
536
- )
537
-
538
  num_prime_tokens = gr.Slider(15, 3072, value=3072, step=1, label="Number of prime tokens")
539
  num_gen_tokens = gr.Slider(15, 1024, value=512, step=1, label="Number of tokens to generate")
540
  num_mem_tokens = gr.Slider(15, 4096, value=4096, step=1, label="Number of memory tokens")
541
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
542
- # model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
543
-
544
  generate_btn = gr.Button("Generate", variant="primary")
545
 
546
- gr.Markdown("## Select batch")
547
-
548
  outputs = [final_composition, generated_batches, block_lines]
549
-
550
  for i in range(NUM_OUT_BATCHES):
551
- with gr.Tab(f"Batch # {i}") as tab:
552
-
553
- audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3", elem_id="midi_audio")
554
  plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
555
-
556
  outputs.extend([audio_output, plot_output])
557
-
558
- outputs.extend([model_state])
559
-
560
- generate_btn.click(generate_callback_wrapper,
561
- [input_midi,
562
- num_prime_tokens,
563
- num_gen_tokens,
564
- num_mem_tokens,
565
- model_temperature,
566
- # model_sampling_top_p,
567
- final_composition,
568
- generated_batches,
569
- block_lines,
570
- model_selector,
571
- model_state
572
- ],
573
- outputs
574
- )
575
-
576
- gr.Markdown("## Add/Remove batch")
577
-
578
- batch_number = gr.Slider(0, NUM_OUT_BATCHES-1, value=0, step=1, label="Batch number to add/remove")
579
-
580
  add_btn = gr.Button("Add batch", variant="primary")
581
  remove_btn = gr.Button("Remove batch", variant="stop")
582
  clear_btn = gr.ClearButton()
583
-
584
- final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3", elem_id="midi_audio")
585
  final_plot_output = gr.Plot(label="Final MIDI plot")
586
  final_file_output = gr.File(label="Final MIDI file")
587
 
588
- #==================================================================================
589
-
590
- add_btn.click(add_batch, [batch_number, final_composition, generated_batches, block_lines, model_state],
591
- [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
592
-
593
- #==================================================================================
594
-
595
- remove_btn.click(remove_batch, [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines, model_state],
596
- [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines])
597
-
598
- #==================================================================================
599
-
600
- clear_btn.click(clear,
601
- inputs=None,
602
- outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines]
603
- )
604
-
605
- #==================================================================================
606
 
607
  demo.unload(reset_demo)
608
 
609
- #==================================================================================
610
-
611
- demo.launch()
612
-
613
- #==================================================================================
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Godzilla Piano Transformer Gradio App - Single Model, Simplified Version
4
+ Fast 807M 4k solo Piano music transformer trained on 1.14M+ MIDIs (2.7M+ samples)
5
+ Using only one model: "without velocity - 3 epochs"
6
+ """
 
 
 
7
 
8
  import os
 
9
  import time as reqtime
10
  import datetime
11
  from pytz import timezone
12
 
 
 
 
 
 
13
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  import matplotlib.pyplot as plt
 
15
  import gradio as gr
16
  import spaces
17
 
18
+ from huggingface_hub import hf_hub_download
19
+ import TMIDIX
20
+ from midi_to_colab_audio import midi_to_colab_audio
21
+ from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decoder
 
 
 
 
22
 
23
+ # -----------------------------
24
+ # CONFIGURATION & GLOBALS
25
+ # -----------------------------
26
+ SEP = '=' * 70
27
+ PDT = timezone('US/Pacific')
28
 
29
+ MODEL_CHECKPOINT = 'Godzilla_Piano_Transformer_No_Velocity_Trained_Model_14075_steps_0.4534_loss_0.8687_acc.pth'
30
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
 
31
  NUM_OUT_BATCHES = 12
32
+ PREVIEW_LENGTH = 120 # in tokens
33
+
34
+ # -----------------------------
35
+ # PRINT START-UP INFO
36
+ # -----------------------------
37
+ def print_sep():
38
+ print(SEP)
39
+
40
+ print_sep()
41
+ print("Godzilla Piano Transformer Gradio App")
42
+ print_sep()
43
+ print("Loading modules...")
44
+
45
+ # -----------------------------
46
+ # ENVIRONMENT & PyTorch Settings
47
+ # -----------------------------
48
+ os.environ['USE_FLASH_ATTENTION'] = '1'
49
 
50
+ torch.set_float32_matmul_precision('high')
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
54
+ torch.backends.cuda.enable_math_sdp(True)
55
+ torch.backends.cuda.enable_flash_sdp(True)
56
+ torch.backends.cuda.enable_cudnn_sdp(True)
57
 
58
+ print_sep()
59
+ print("PyTorch version:", torch.__version__)
60
+ print("Done loading modules!")
61
+ print_sep()
62
 
63
+ # -----------------------------
64
+ # MODEL INITIALIZATION
65
+ # -----------------------------
66
+ print_sep()
67
+ print("Instantiating model...")
68
 
69
  device_type = 'cuda'
70
  dtype = 'bfloat16'
 
71
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
72
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
73
 
74
  SEQ_LEN = 4096
 
75
  PAD_IDX = 384
76
 
77
  model = TransformerWrapper(
78
+ num_tokens=PAD_IDX + 1,
79
+ max_seq_len=SEQ_LEN,
80
+ attn_layers=Decoder(
81
+ dim=2048,
82
+ depth=16,
83
+ heads=32,
84
+ rotary_pos_emb=True,
85
+ attn_flash=True
86
+ )
87
  )
 
88
  model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
89
 
90
+ print_sep()
91
+ print("Loading model checkpoint...")
92
+ checkpoint = hf_hub_download(
93
+ repo_id='asigalov61/Godzilla-Piano-Transformer',
94
+ filename=MODEL_CHECKPOINT
95
+ )
96
+ model.load_state_dict(torch.load(checkpoint, map_location='cuda', weights_only=True))
 
97
  model = torch.compile(model, mode='max-autotune')
98
+ print_sep()
99
+ print("Done!")
100
+ print("Model will use", dtype, "precision...")
101
+ print_sep()
 
 
102
 
103
  model.cuda()
104
  model.eval()
105
 
106
+ # -----------------------------
107
+ # HELPER FUNCTIONS
108
+ # -----------------------------
109
+ def render_midi_output(final_composition):
110
+ """Generate MIDI score, plot, and audio from final composition."""
111
+ midi_score = save_midi(final_composition)
112
+ time_val = midi_score[-1][1] / 1000 # seconds marker from last note
113
+ midi_plot = TMIDIX.plot_ms_SONG(
114
+ midi_score,
115
+ plot_title='Godzilla Piano Transformer Composition',
116
+ block_lines_times_list=[],
117
+ return_plt=True
118
+ )
119
+ fname = save_midi(final_composition) # The file name is embedded in the saved MIDI.
120
+ midi_audio = midi_to_colab_audio(
121
+ fname + '.mid',
122
+ soundfont_path=SOUDFONT_PATH,
123
+ sample_rate=16000,
124
+ output_for_gradio=True
125
+ )
126
+ return (16000, midi_audio), midi_plot, fname + '.mid', time_val
127
+
128
+ # -----------------------------
129
+ # MIDI PROCESSING FUNCTIONS
130
+ # -----------------------------
131
+ def load_midi(input_midi):
132
+ """Process the input MIDI file and create a token sequence using without velocity logic."""
133
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
134
+ escore_notes = TMIDIX.advanced_score_processor(
135
+ raw_score, return_enhanced_score_notes=True, apply_sustain=True
136
+ )[0]
137
  sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
138
  zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
 
139
  zscore = TMIDIX.augment_enhanced_score_notes(zscore, timings_divider=32)
 
140
  fscore = TMIDIX.fix_escore_notes_durations(zscore)
 
141
  cscore = TMIDIX.chordify_score([1000, fscore])
142
 
143
  score = []
144
+ prev_chord = cscore[0]
145
+ for chord in cscore:
146
+ # Time difference token.
147
+ score.append(max(0, min(127, chord[0][1] - prev_chord[0][1])))
148
+ for note in chord:
149
+ score.extend([
150
+ max(1, min(127, note[2])) + 128,
151
+ max(1, min(127, note[4])) + 256
152
+ ])
153
+ prev_chord = chord
 
 
 
 
 
154
  return score
155
 
156
+ def save_midi(tokens, batch_number=None):
157
+ """Convert token sequence back to a MIDI score and write it using TMIDIX (without velocity).
158
+ The output MIDI file name incorporates a date-time stamp.
159
+ """
160
+ song_events = []
161
+ time_marker = 0
162
+ duration = 0
 
 
 
163
  pitch = 0
 
 
 
164
  patches = [0] * 16
165
 
166
+ for token in tokens:
167
+ if 0 <= token < 128:
168
+ time_marker += token * 32
169
+ elif 128 <= token < 256:
170
+ duration = (token - 128) * 32
171
+ elif 256 <= token < 384:
172
+ pitch = token - 256
173
+ song_events.append(['note', time_marker, duration, 0, pitch, max(40, pitch), 0])
174
+ # No velocity tokens are used.
175
+
176
+ # Generate a time stamp using the PDT timezone.
177
+ timestamp = datetime.datetime.now(PDT).strftime("%Y%m%d_%H%M%S")
178
+ if batch_number is None:
179
+ fname = f"Godzilla-Piano-Transformer-Music-Composition_{timestamp}"
 
 
 
 
 
 
 
 
 
180
  else:
181
+ fname = f"Godzilla-Piano-Transformer-Music-Composition_{timestamp}_Batch_{batch_number}"
182
+
183
+ TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
184
+ song_events,
185
+ output_signature='Godzilla Piano Transformer',
186
+ output_file_name=fname,
187
+ track_name='Project Los Angeles',
188
+ list_of_MIDI_patches=patches,
189
+ verbose=False
190
+ )
191
+ return fname
192
+
193
+ # -----------------------------
194
+ # MUSIC GENERATION FUNCTION (Combined)
195
+ # -----------------------------
196
  @spaces.GPU
197
+ def generate_music(prime, num_gen_tokens, num_mem_tokens, num_gen_batches, model_temperature):
198
+ """Generate music tokens given prime tokens and parameters."""
199
+ inputs = prime[-num_mem_tokens:] if prime else [0]
200
+ print("Generating...")
201
+ inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  with ctx:
203
+ out = model.generate(
204
+ inp,
205
+ num_gen_tokens,
206
+ temperature=model_temperature,
207
+ return_prime=False,
208
+ verbose=False
209
+ )
210
+ print("Done!")
211
+ print_sep()
212
+ return out.tolist()
213
+
214
+ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens,
215
+ model_temperature, final_composition, generated_batches, block_lines):
216
+ """
217
+ Generate tokens using the model, update the composition state, and prepare outputs.
218
+ This function combines seed loading, token generation, and UI output packaging.
219
+ """
220
+ print_sep()
221
+ print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
222
+
223
+ # Load seed from MIDI if there is no existing composition.
 
 
 
 
 
 
 
 
 
 
224
  if not final_composition and input_midi is not None:
225
+ final_composition = load_midi(input_midi)[:num_prime_tokens]
226
+ midi_fname = save_midi(final_composition)
227
+ # Use the last note's time as a marker.
228
+ midi_score = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
229
+ final_composition,
230
+ output_signature='Godzilla Piano Transformer',
231
+ output_file_name=midi_fname,
232
+ track_name='Project Los Angeles',
233
+ list_of_MIDI_patches=[0]*16,
234
+ verbose=False
235
+ )
236
+ block_lines.append(final_composition[-1] if final_composition else 0)
237
+
238
+ batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
239
+ NUM_OUT_BATCHES, model_temperature)
240
+
241
+ output_batches = []
242
+ for i, tokens in enumerate(batched_gen_tokens):
243
+ preview_tokens = final_composition[-PREVIEW_LENGTH:]
244
+ midi_fname = save_midi(preview_tokens + tokens, batch_number=i)
245
+ plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
 
 
 
 
 
 
246
  if len(final_composition) > PREVIEW_LENGTH:
247
+ plot_kwargs['preview_length_in_notes'] = int(PREVIEW_LENGTH / 3)
248
+ midi_score = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
249
+ preview_tokens + tokens,
250
+ output_signature='Godzilla Piano Transformer',
251
+ output_file_name=midi_fname,
252
+ track_name='Project Los Angeles',
253
+ list_of_MIDI_patches=[0]*16,
254
+ verbose=False
255
+ )
256
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
257
+ midi_audio = midi_to_colab_audio(midi_fname + '.mid',
258
+ soundfont_path=SOUDFONT_PATH,
259
+ sample_rate=16000,
260
+ output_for_gradio=True)
261
+ output_batches.append([(16000, midi_audio), midi_plot, tokens])
262
+
263
+ # Update generated_batches (for use by add/remove functions)
264
+ generated_batches = batched_gen_tokens
265
+
266
+ print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
267
+ print_sep()
268
+
269
+ # Flatten outputs: states then audio and plots for each batch.
270
+ outputs_flat = []
271
+ for batch in output_batches:
272
+ outputs_flat.extend([batch[0], batch[1]])
273
+ return [final_composition, generated_batches, block_lines] + outputs_flat
274
+
275
+ # -----------------------------
276
+ # BATCH HANDLING FUNCTIONS
277
+ # -----------------------------
278
+ def add_batch(batch_number, final_composition, generated_batches, block_lines):
279
+ """Add tokens from the specified batch to the final composition and update outputs."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  if generated_batches:
281
  final_composition.extend(generated_batches[batch_number])
282
+ midi_fname = save_midi(final_composition)
283
+ block_lines.append(final_composition[-1] if final_composition else 0)
284
+ midi_score = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
285
+ final_composition,
286
+ output_signature='Godzilla Piano Transformer',
287
+ output_file_name=midi_fname,
288
+ track_name='Project Los Angeles',
289
+ list_of_MIDI_patches=[0]*16,
290
+ verbose=False
291
+ )
292
+ midi_plot = TMIDIX.plot_ms_SONG(
293
+ midi_score,
294
+ plot_title='Godzilla Piano Transformer Composition',
295
+ block_lines_times_list=block_lines[:-1],
296
+ return_plt=True
297
+ )
298
+ midi_audio = midi_to_colab_audio(midi_fname + '.mid',
299
+ soundfont_path=SOUDFONT_PATH,
300
+ sample_rate=16000,
301
+ output_for_gradio=True)
302
+ print("Added batch #", batch_number)
303
+ print_sep()
304
+ return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
 
 
 
 
305
  else:
306
  return None, None, None, [], [], []
307
 
308
+ def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
309
+ """Remove tokens from the final composition and update outputs."""
310
+ if final_composition and len(final_composition) > num_tokens:
311
+ final_composition = final_composition[:-num_tokens]
312
+ if block_lines:
 
 
 
313
  block_lines.pop()
314
+ midi_fname = save_midi(final_composition)
315
+ midi_score = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
316
+ final_composition,
317
+ output_signature='Godzilla Piano Transformer',
318
+ output_file_name=midi_fname,
319
+ track_name='Project Los Angeles',
320
+ list_of_MIDI_patches=[0]*16,
321
+ verbose=False
322
+ )
323
+ midi_plot = TMIDIX.plot_ms_SONG(
324
+ midi_score,
325
+ plot_title='Godzilla Piano Transformer Composition',
326
+ block_lines_times_list=block_lines[:-1],
327
+ return_plt=True
328
+ )
329
+ midi_audio = midi_to_colab_audio(midi_fname + '.mid',
330
+ soundfont_path=SOUDFONT_PATH,
331
+ sample_rate=16000,
332
+ output_for_gradio=True)
333
+ print("Removed batch #", batch_number)
334
+ print_sep()
335
+ return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
 
 
 
336
  else:
337
  return None, None, None, [], [], []
 
 
 
 
 
338
 
339
+ def clear():
340
+ """Clear outputs and reset state."""
341
+ return None, None, None, [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
+ def reset(final_composition=[], generated_batches=[], block_lines=[]):
344
+ """Reset composition state."""
345
+ return [], [], []
346
 
347
+ def reset_demo(final_composition=[], generated_batches=[], block_lines=[]):
348
+ """Reset state for demo unload."""
349
+ return [], [], []
350
 
351
+ # -----------------------------
352
+ # GRADIO INTERFACE SETUP
353
+ # -----------------------------
354
  with gr.Blocks() as demo:
 
 
 
355
  demo.load(reset_demo)
356
 
 
 
357
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Godzilla Piano Transformer</h1>")
358
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Fast 807M 4k solo Piano music transformer trained on 1.14M+ MIDIs (2.7M+ samples)</h1>")
359
  gr.HTML("""
360
+ Check out <a href="https://huggingface.co/datasets/asigalov61/Godzilla-Piano">Godzilla Piano dataset</a> on Hugging Face
361
+ <p>
362
+ <a href="https://huggingface.co/spaces/asigalov61/Godzilla-Piano-Transformer?duplicate=true">
363
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
364
+ </a>
365
+ </p>
366
+ for faster execution and endless generation!
367
+ """)
368
+
369
+ # Global state variables for composition
 
 
 
370
  final_composition = gr.State([])
371
  generated_batches = gr.State([])
372
  block_lines = gr.State([])
373
+
374
+ gr.Markdown("## Upload seed MIDI or click 'Generate' for a random output")
 
 
 
 
375
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
376
+ input_midi.upload(reset, [final_composition, generated_batches, block_lines],
377
+ [final_composition, generated_batches, block_lines])
 
 
378
 
379
+ gr.Markdown("## Generate")
 
 
 
380
  num_prime_tokens = gr.Slider(15, 3072, value=3072, step=1, label="Number of prime tokens")
381
  num_gen_tokens = gr.Slider(15, 1024, value=512, step=1, label="Number of tokens to generate")
382
  num_mem_tokens = gr.Slider(15, 4096, value=4096, step=1, label="Number of memory tokens")
383
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
 
 
384
  generate_btn = gr.Button("Generate", variant="primary")
385
 
386
+ gr.Markdown("## Batch Previews")
 
387
  outputs = [final_composition, generated_batches, block_lines]
388
+ # Two outputs (audio and plot) for each batch
389
  for i in range(NUM_OUT_BATCHES):
390
+ with gr.Tab(f"Batch # {i}"):
391
+ audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3")
 
392
  plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
 
393
  outputs.extend([audio_output, plot_output])
394
+ generate_btn.click(
395
+ generate_music_and_state,
396
+ [input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens, model_temperature,
397
+ final_composition, generated_batches, block_lines],
398
+ outputs
399
+ )
400
+
401
+ gr.Markdown("## Add/Remove Batch")
402
+ batch_number = gr.Slider(0, NUM_OUT_BATCHES - 1, value=0, step=1, label="Batch number to add/remove")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  add_btn = gr.Button("Add batch", variant="primary")
404
  remove_btn = gr.Button("Remove batch", variant="stop")
405
  clear_btn = gr.ClearButton()
406
+
407
+ final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3")
408
  final_plot_output = gr.Plot(label="Final MIDI plot")
409
  final_file_output = gr.File(label="Final MIDI file")
410
 
411
+ add_btn.click(
412
+ add_batch,
413
+ [batch_number, final_composition, generated_batches, block_lines],
414
+ [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
415
+ )
416
+ remove_btn.click(
417
+ remove_batch,
418
+ [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
419
+ [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
420
+ )
421
+ clear_btn.click(clear, inputs=None,
422
+ outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines])
 
 
 
 
 
 
423
 
424
  demo.unload(reset_demo)
425
 
426
+ demo.launch()