SteveZerb committed on
Commit 180d53d · verified · 1 Parent(s): bf288c7

Upload app_onnx.py

Files changed (1)
  1. app_onnx.py +625 -0
app_onnx.py ADDED
@@ -0,0 +1,625 @@
+ import spaces
+ import random
+ import argparse
+ import glob
+ import json
+ import os
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+
+ import gradio as gr
+ import numpy as np
+ import onnxruntime as rt
+ import tqdm
+ from huggingface_hub import hf_hub_download
+
+ import MIDI
+ from midi_synthesizer import MidiSynthesizer
+ from midi_tokenizer import MIDITokenizer
+
+ MAX_SEED = np.iinfo(np.int32).max
+ in_space = os.getenv("SYSTEM") == "spaces"
+
+
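+ # Numerically stable softmax: subtracting the per-axis max before exponentiating
+ # avoids overflow without changing the result.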
+ def softmax(x, axis):
+     x_max = np.amax(x, axis=axis, keepdims=True)
+     exp_x_shifted = np.exp(x - x_max)
+     return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
+
+
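+ # Combined nucleus (top-p) and top-k sampling on numpy arrays: probabilities are
+ # sorted, tokens outside the top-p cumulative mass or beyond the k most likely
+ # entries are zeroed, the remainder is renormalized, and one token is drawn per row.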
+ def sample_top_p_k(probs, p, k, generator=None):
+     if generator is None:
+         generator = np.random
+     probs_idx = np.argsort(-probs, axis=-1)
+     probs_sort = np.take_along_axis(probs, probs_idx, -1)
+     probs_sum = np.cumsum(probs_sort, axis=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort[mask] = 0.0
+     mask = np.zeros(probs_sort.shape[-1])
+     mask[:k] = 1
+     probs_sort = probs_sort * mask
+     probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
+     shape = probs_sort.shape
+     probs_sort_flat = probs_sort.reshape(-1, shape[-1])
+     probs_idx_flat = probs_idx.reshape(-1, shape[-1])
+     next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
+     next_token = next_token.reshape(*shape[:-1])
+     return next_token
+
+
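+ # Binds input/output OrtValues for an onnxruntime session so tensors stay on the
+ # target device; each step's "present" key/value cache is fed back in as the next
+ # step's "past_key_values", avoiding recomputation over the whole prefix.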
+ def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
+     io_binding = model.io_binding()
+     for input_ in model.get_inputs():
+         name = input_.name
+         if name.startswith("past_key_values"):
+             present_name = name.replace("past_key_values", "present")
+             if present_name in outputs:
+                 v = outputs[present_name]
+             else:
+                 v = rt.OrtValue.ortvalue_from_shape_and_type(
+                     (batch_size, input_.shape[1], past_len, input_.shape[3]),
+                     element_type=np.float32,
+                     device_type=device)
+             inputs[name] = v
+         else:
+             v = inputs[name]
+         io_binding.bind_ortvalue_input(name, v)
+
+     for output in model.get_outputs():
+         name = output.name
+         if name.startswith("present"):
+             v = rt.OrtValue.ortvalue_from_shape_and_type(
+                 (batch_size, output.shape[1], cur_len, output.shape[3]),
+                 element_type=np.float32,
+                 device_type=device)
+             outputs[name] = v
+         else:
+             v = outputs[name]
+         io_binding.bind_ortvalue_output(name, v)
+     return io_binding
+
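+ # Autoregressive generation: model[0] encodes the event sequence into hidden
+ # states, and model[1] decodes one event (a short token sequence) from the last
+ # hidden state. Yields one batch of event token sequences per generated event.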
+ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
+              disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
+     tokenizer = model[2]
+     if disable_channels is not None:
+         disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
+     else:
+         disable_channels = []
+     if generator is None:
+         generator = np.random
+     max_token_seq = tokenizer.max_token_seq
+     if prompt is None:
+         input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
+         input_tensor[0, 0] = tokenizer.bos_id  # bos
+         input_tensor = input_tensor[None, :, :]
+         input_tensor = np.repeat(input_tensor, repeats=batch_size, axis=0)
+     else:
+         if len(prompt.shape) == 2:
+             prompt = prompt[None, :]
+             prompt = np.repeat(prompt, repeats=batch_size, axis=0)
+         elif prompt.shape[0] == 1:
+             prompt = np.repeat(prompt, repeats=batch_size, axis=0)
+         elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
+             raise ValueError(f"invalid shape for prompt, {prompt.shape}")
+         prompt = prompt[..., :max_token_seq]
+         if prompt.shape[-1] < max_token_seq:
+             prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
+                             mode="constant", constant_values=tokenizer.pad_id)
+         input_tensor = prompt
+     cur_len = input_tensor.shape[1]
+     bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
+     model0_inputs = {}
+     model0_outputs = {}
+     emb_size = 1024
+     for output in model[0].get_outputs():
+         if output.name == "hidden":
+             emb_size = output.shape[2]
+     past_len = 0
+     with bar:
+         while cur_len < max_len:
+             end = [False] * batch_size
+             model0_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(input_tensor[:, past_len:], device_type=device)
+             model0_outputs["hidden"] = rt.OrtValue.ortvalue_from_shape_and_type(
+                 (batch_size, cur_len - past_len, emb_size),
+                 element_type=np.float32,
+                 device_type=device)
+             io_binding = apply_io_binding(model[0], model0_inputs, model0_outputs, batch_size, past_len, cur_len)
+             io_binding.synchronize_inputs()
+             model[0].run_with_iobinding(io_binding)
+             io_binding.synchronize_outputs()
+
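+             # Decode the next event token by token with model[1]; the masks built
+             # below restrict sampling to ids that are valid under the event grammar.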
+             hidden = model0_outputs["hidden"].numpy()[:, -1:]
+             next_token_seq = np.zeros((batch_size, 0), dtype=np.int64)
+             event_names = [""] * batch_size
+             model1_inputs = {"hidden": rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)}
+             model1_outputs = {}
+             for i in range(max_token_seq):
+                 mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
+                 for b in range(batch_size):
+                     if end[b]:
+                         mask[b, tokenizer.pad_id] = 1
+                         continue
+                     if i == 0:
+                         mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
+                         if disable_patch_change:
+                             mask_ids.remove(tokenizer.event_ids["patch_change"])
+                         if disable_control_change:
+                             mask_ids.remove(tokenizer.event_ids["control_change"])
+                         mask[b, mask_ids] = 1
+                     else:
+                         param_names = tokenizer.events[event_names[b]]
+                         if i > len(param_names):
+                             mask[b, tokenizer.pad_id] = 1
+                             continue
+                         param_name = param_names[i - 1]
+                         mask_ids = tokenizer.parameter_ids[param_name]
+                         if param_name == "channel":
+                             mask_ids = [i for i in mask_ids if i not in disable_channels]
+                         mask[b, mask_ids] = 1
+                 mask = mask[:, None, :]
+                 x = next_token_seq
+                 if i != 0:
+                     # cached
+                     if i == 1:
+                         hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
+                         model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
+                     x = x[:, -1:]
+                 model1_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(x, device_type=device)
+                 model1_outputs["y"] = rt.OrtValue.ortvalue_from_shape_and_type(
+                     (batch_size, 1, tokenizer.vocab_size),
+                     element_type=np.float32,
+                     device_type=device
+                 )
+                 io_binding = apply_io_binding(model[1], model1_inputs, model1_outputs, batch_size, i, i + 1)
+                 io_binding.synchronize_inputs()
+                 model[1].run_with_iobinding(io_binding)
+                 io_binding.synchronize_outputs()
+                 logits = model1_outputs["y"].numpy()
+                 scores = softmax(logits / temp, -1) * mask
+                 samples = sample_top_p_k(scores, top_p, top_k, generator)
+                 if i == 0:
+                     next_token_seq = samples
+                     for b in range(batch_size):
+                         if end[b]:
+                             continue
+                         eid = samples[b].item()
+                         if eid == tokenizer.eos_id:
+                             end[b] = True
+                         else:
+                             event_names[b] = tokenizer.id_events[eid]
+                 else:
+                     next_token_seq = np.concatenate([next_token_seq, samples], axis=1)
+                     if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
+                         break
+             if next_token_seq.shape[1] < max_token_seq:
+                 next_token_seq = np.pad(next_token_seq,
+                                         ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
+                                         mode="constant", constant_values=tokenizer.pad_id)
+             next_token_seq = next_token_seq[:, None, :]
+             input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
+             past_len = cur_len
+             cur_len += 1
+             bar.update(1)
+             yield next_token_seq[:, 0]
+             if all(end):
+                 break
+
+
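+ # Helpers for the JS bridge: messages are serialized to JSON and delivered to the
+ # browser through a hidden textbox (js_msg, defined in the UI below).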
+ def create_msg(name, data):
+     return {"name": name, "data": data}
+
+
+ def send_msgs(msgs):
+     return json.dumps(msgs)
+
+
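+ # Estimates a GPU time budget (seconds) for the spaces.GPU decorator from the
+ # requested event count; the divisors look like empirical events-per-second rates.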
+ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
+                  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
+                  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
+     t = gen_events // 28
+     if "large" in model_name:
+         t = gen_events // 20
+     return t + 10
+
+
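+ # Main generation entry point, streamed to the UI. Builds the prompt according to
+ # the active tab, then forwards generated events to the visualizer while
+ # reporting progress.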
+ @spaces.GPU(duration=get_duration)
+ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
+         key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
+         seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
+     model = models[model_name]
+     model_base = rt.InferenceSession(model[0], providers=providers)
+     model_token = rt.InferenceSession(model[1], providers=providers)
+     tokenizer = model[2]
+     model = [model_base, model_token, tokenizer]
+     bpm = int(bpm)
+     if time_sig == "auto":
+         time_sig = None
+         time_sig_nn = 4
+         time_sig_dd = 2
+     else:
+         time_sig_nn, time_sig_dd = time_sig.split('/')
+         time_sig_nn = int(time_sig_nn)
+         time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
+     if key_sig == 0:
+         key_sig = None
+         key_sig_sf = 0
+         key_sig_mi = 0
+     else:
+         key_sig = (key_sig - 1)
+         key_sig_sf = key_sig // 2 - 7
+         key_sig_mi = key_sig % 2
+     gen_events = int(gen_events)
+     max_len = gen_events
+     if seed_rand:
+         seed = random.randint(0, MAX_SEED)
+     generator = np.random.RandomState(seed)
+     disable_patch_change = False
+     disable_channels = None
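+     # tab 0: build a fresh prompt from instruments/drum kit/BPM/signatures;
+     # tab 1: tokenize the uploaded MIDI file; tab 2: continue a previous output.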
+     if tab == 0:
+         i = 0
+         mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
+         if tokenizer.version == "v2":
+             if time_sig is not None:
+                 mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
+             if key_sig is not None:
+                 mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
+         if bpm != 0:
+             mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
+         patches = {}
+         if instruments is None:
+             instruments = []
+         for instr in instruments:
+             patches[i] = patch2number[instr]
+             i = (i + 1) if i != 8 else 10
+         if drum_kit != "None":
+             patches[9] = drum_kits2number[drum_kit]
+         for i, (c, p) in enumerate(patches.items()):
+             mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
+         mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
+         mid_seq = mid.tolist()
+         if len(instruments) > 0:
+             disable_patch_change = True
+             disable_channels = [i for i in range(16) if i not in patches]
+     elif tab == 1 and mid is not None:
+         eps = 4 if reduce_cc_st else 0
+         mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
+                                  remap_track_channel=remap_track_channel,
+                                  add_default_instr=add_default_instr,
+                                  remove_empty_channels=remove_empty_channels)
+         mid = mid[:int(midi_events)]
+         mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
+         mid_seq = mid.tolist()
+     elif tab == 2 and mid_seq is not None:
+         mid = np.asarray(mid_seq, dtype=np.int64)
+         if continuation_select > 0:
+             continuation_state.append(mid_seq)
+             mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
+             mid_seq = mid.tolist()
+         else:
+             continuation_state.append(mid.shape[1])
+     else:
+         continuation_state = [0]
+         mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
+         mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
+         mid_seq = mid.tolist()
+
+     if mid is not None:
+         max_len += mid.shape[1]
+
+     init_msgs = [create_msg("progress", [0, gen_events])]
+     if not (tab == 2 and continuation_select == 0):
+         for i in range(OUTPUT_BATCH_SIZE):
+             events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
+             init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
+                           create_msg("visualizer_append", [i, events])]
+     yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
+     midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
+                               top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
+                               disable_control_change=not allow_cc, disable_channels=disable_channels,
+                               generator=generator)
+     events = [list() for _ in range(OUTPUT_BATCH_SIZE)]
+     t = time.time() + 1
+     for i, token_seqs in enumerate(midi_generator):
+         token_seqs = token_seqs.tolist()
+         for j in range(OUTPUT_BATCH_SIZE):
+             token_seq = token_seqs[j]
+             mid_seq[j].append(token_seq)
+             events[j].append(tokenizer.tokens2event(token_seq))
+         if time.time() - t > 0.5:
+             msgs = [create_msg("progress", [i + 1, gen_events])]
+             for j in range(OUTPUT_BATCH_SIZE):
+                 msgs += [create_msg("visualizer_append", [j, events[j]])]
+                 events[j] = list()
+             yield mid_seq, continuation_state, seed, send_msgs(msgs)
+             t = time.time()
+     yield mid_seq, continuation_state, seed, send_msgs([])
+
+
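+ # Detokenizes each batch item, writes it to outputs/output{n}.mid, and tells the
+ # visualizer to render the final state.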
+ def finish_run(model_name, mid_seq):
+     if mid_seq is None:
+         outputs = [None] * OUTPUT_BATCH_SIZE
+         return *outputs, []
+     tokenizer = models[model_name][2]
+     outputs = []
+     end_msgs = [create_msg("progress", [0, 0])]
+     if not os.path.exists("outputs"):
+         os.mkdir("outputs")
+     for i in range(OUTPUT_BATCH_SIZE):
+         events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
+         mid = tokenizer.detokenize(mid_seq[i])
+         with open(f"outputs/output{i + 1}.mid", 'wb') as f:
+             f.write(MIDI.score2midi(mid))
+         outputs.append(f"outputs/output{i + 1}.mid")
+         end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
+                      create_msg("visualizer_append", [i, events]),
+                      create_msg("visualizer_end", i)]
+     return *outputs, send_msgs(end_msgs)
+
+
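+ # Runs inside the thread pool so several outputs can be synthesized concurrently.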
+ def synthesis_task(mid):
+     return synthesizer.synthesis(MIDI.score2opus(mid))
+
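+ # Renders each generated sequence to audio with the SoundFont synthesizer and
+ # returns (sample_rate, waveform) tuples at 44.1 kHz for the gr.Audio outputs.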
+ def render_audio(model_name, mid_seq, should_render_audio):
+     if (not should_render_audio) or mid_seq is None:
+         outputs = [None] * OUTPUT_BATCH_SIZE
+         return tuple(outputs)
+     tokenizer = models[model_name][2]
+     outputs = []
+     if not os.path.exists("outputs"):
+         os.mkdir("outputs")
+     audio_futures = []
+     for i in range(OUTPUT_BATCH_SIZE):
+         mid = tokenizer.detokenize(mid_seq[i])
+         audio_future = thread_pool.submit(synthesis_task, mid)
+         audio_futures.append(audio_future)
+     for future in audio_futures:
+         outputs.append((44100, future.result()))
+     if OUTPUT_BATCH_SIZE == 1:
+         return outputs[0]
+     return tuple(outputs)
+
+
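+ # Reverts the last continuation. The state stack holds either a full mid_seq
+ # snapshot (when a single output was continued) or the previous sequence length.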
+ def undo_continuation(model_name, mid_seq, continuation_state):
+     if mid_seq is None or len(continuation_state) < 2:
+         return mid_seq, continuation_state, send_msgs([])
+     tokenizer = models[model_name][2]
+     if isinstance(continuation_state[-1], list):
+         mid_seq = continuation_state[-1]
+     else:
+         mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
+     continuation_state = continuation_state[:-1]
+     end_msgs = [create_msg("progress", [0, 0])]
+     for i in range(OUTPUT_BATCH_SIZE):
+         events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
+         end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
+                      create_msg("visualizer_append", [i, events]),
+                      create_msg("visualizer_end", i)]
+     return mid_seq, continuation_state, send_msgs(end_msgs)
+
+
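+ # Injects the project's JS (visualizer and message callbacks) into every Gradio
+ # page by patching gr.routes.templates.TemplateResponse; this relies on a Gradio
+ # internal and may need updating on newer Gradio versions.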
+ def load_javascript(dir="javascript"):
+     scripts_list = glob.glob(f"{dir}/*.js")
+     javascript = ""
+     for path in scripts_list:
+         with open(path, "r", encoding="utf8") as jsfile:
+             js_content = jsfile.read()
+         js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
+                                         f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
+         javascript += f"\n<!-- {path} --><script>{js_content}</script>"
+     template_response_ori = gr.routes.templates.TemplateResponse
+
+     def template_response(*args, **kwargs):
+         res = template_response_ori(*args, **kwargs)
+         res.body = res.body.replace(
+             b'</head>', f'{javascript}</head>'.encode("utf8"))
+         res.init_headers()
+         return res
+
+     gr.routes.templates.TemplateResponse = template_response
+
+
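+ # Retries hf_hub_download up to 30 times, then re-raises the last error.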
+ def hf_hub_download_retry(repo_id, filename):
+     print(f"downloading {repo_id} {filename}")
+     retry = 0
+     err = None
+     while retry < 30:
+         try:
+             return hf_hub_download(repo_id=repo_id, filename=filename)
+         except Exception as e:
+             err = e
+             retry += 1
+     if err:
+         raise err
+
+
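+ # Builds a MIDITokenizer matching the repo's config.json (version and
+ # optimise_midi setting).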
+ def get_tokenizer(repo_id):
+     config_path = hf_hub_download_retry(repo_id=repo_id, filename="config.json")
+     with open(config_path, "r") as f:
+         config = json.load(f)
+     tokenizer = MIDITokenizer(config["tokenizer"]["version"])
+     tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
+     return tokenizer
+
+
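+ # Lookup tables: General MIDI drum-kit numbers, patch name <-> program number
+ # maps, and the key-signature labels offered in the UI.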
+ number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
+                     40: "Brush", 48: "Orchestra"}
+ patch2number = {v: k for k, v in MIDI.Number2patch.items()}
+ drum_kits2number = {v: k for k, v in number2drum_kits.items()}
+ key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
+                   'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
+
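+ # Entry point: parse CLI options, fetch the SoundFont and models, then build and
+ # launch the Gradio UI.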
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
+     parser.add_argument("--port", type=int, default=7860, help="gradio server port")
+     parser.add_argument("--device", type=str, default="cuda", help="device to run model")
+     parser.add_argument("--batch", type=int, default=8, help="batch size")
+     parser.add_argument("--max-gen", type=int, default=1024, help="maximum number of midi events to generate")
+     opt = parser.parse_args()
+     OUTPUT_BATCH_SIZE = opt.batch
+     soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
+     thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
+     synthesizer = MidiSynthesizer(soundfont_path)
+     models_info = {
+         "generic pretrain model (tv2o-medium) by skytnt": [
+             "skytnt/midi-model-tv2o-medium", "", {
+                 "jpop": "skytnt/midi-model-tv2om-jpop-lora",
+                 "touhou": "skytnt/midi-model-tv2om-touhou-lora"
+             }
+         ],
+         "generic pretrain model (tv2o-large) by asigalov61": [
+             "asigalov61/Music-Llama", "", {}
+         ],
+         "generic pretrain model (tv2o-medium) by asigalov61": [
+             "asigalov61/Music-Llama-Medium", "", {}
+         ],
+         "generic pretrain model (tv1-medium) by skytnt": [
+             "skytnt/midi-model", "", {}
+         ]
+     }
+     models = {}
+     providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+     device = "cuda"
+
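+     # Note: --device is parsed above but not applied here; sessions request the
+     # CUDA provider (with CPU fallback) and OrtValues are placed on "cuda".
+     # Download the base/token ONNX graphs (and any LoRA variants) for each entry;
+     # every models[name] holds [base_path, token_path, tokenizer].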
+     for name, (repo_id, path, loras) in models_info.items():
+         model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
+         model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
+         tokenizer = get_tokenizer(repo_id)
+         models[name] = [model_base_path, model_token_path, tokenizer]
+         for lora_name, lora_repo in loras.items():
+             model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename="onnx/model_base.onnx")
+             model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename="onnx/model_token.onnx")
+             models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
+
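+     # Inject the custom JS, then build the Gradio UI: prompt tabs, sampling
+     # options, and one output tab (visualizer, audio, midi file) per batch item.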
+     load_javascript()
+     app = gr.Blocks(theme=gr.themes.Soft())
+     with app:
+         gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
+         gr.Markdown("\n\n"
+                     "A modified version of the Midi-Generator for the IAT-360 course, by Ethan Lum\n\n"
+                     "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
+                     "[Open In Colab]"
+                     "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
+                     " or [download the Windows app](https://github.com/SkyTNT/midi-model/releases)"
+                     " for unlimited generation\n\n"
+                     "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
+                     "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
+                     )
+         js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
+         js_msg.change(None, [js_msg], [], js="""
+             (msg_json) => {
+                 let msgs = JSON.parse(msg_json);
+                 executeCallbacks(msgReceiveCallbacks, msgs);
+                 return [];
+             }
+             """)
+         input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
+                                   type="value", value=list(models.keys())[0])
+         tab_select = gr.State(value=0)
+         with gr.Tabs():
+             with gr.TabItem("custom prompt") as tab1:
+                 input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)",
+                                                 choices=list(patch2number.keys()),
+                                                 multiselect=True, max_choices=15, type="value")
+                 input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
+                                              value="None")
+                 input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
+                                       step=1, value=0)
+                 input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
+                                           value="auto",
+                                           choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
+                                                    "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
+                                           )
+                 input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
+                                          value="auto",
+                                          choices=["auto"] + key_signatures,
+                                          type="index"
+                                          )
+                 example1 = gr.Examples([
+                     [[], "None"],
+                     [["Acoustic Grand"], "None"],
+                     [['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
+                       'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
+                     [['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
+                       'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
+                     [['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
+                       'Oboe', 'Pizzicato Strings'], "Orchestra"],
+                     [['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
+                       'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
+                     [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
+                       "Electric Bass(finger)"], "Standard"]
+                 ], [input_instruments, input_drum_kit])
+             with gr.TabItem("midi prompt") as tab2:
+                 input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
+                 input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
+                                               step=1, value=128)
+                 input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
+                 input_remap_track_channel = gr.Checkbox(
+                     label="remap tracks and channels so each track has only one channel, in order", value=True)
+                 input_add_default_instr = gr.Checkbox(
+                     label="add a default instrument to channels that don't have an instrument", value=True)
+                 input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
+                 example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
+                                        [input_midi, input_midi_events])
+             with gr.TabItem("last output prompt") as tab3:
+                 gr.Markdown("Continue generating from the last output.")
+                 input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
+                                                      choices=["all"] + [f"output{i + 1}" for i in
+                                                                         range(OUTPUT_BATCH_SIZE)],
+                                                      type="index"
+                                                      )
+                 undo_btn = gr.Button("undo the last continuation")
+
+         tab1.select(lambda: 0, None, tab_select, queue=False)
+         tab2.select(lambda: 1, None, tab_select, queue=False)
+         tab3.select(lambda: 2, None, tab_select, queue=False)
+         input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
+                                step=1, value=0)
+         input_seed_rand = gr.Checkbox(label="random seed", value=True)
+         input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
+                                      step=1, value=opt.max_gen // 2)
+         with gr.Accordion("options", open=False):
+             input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
+             input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
+             input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
+             input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
+             input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
+             example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
+                                    [input_temp, input_top_p, input_top_k])
+         run_btn = gr.Button("generate", variant="primary")
+         # stop_btn = gr.Button("stop and output")
+         output_midi_seq = gr.State()
+         output_continuation_state = gr.State([0])
+         midi_outputs = []
+         audio_outputs = []
+         with gr.Tabs(elem_id="output_tabs"):
+             for i in range(OUTPUT_BATCH_SIZE):
+                 with gr.TabItem(f"output {i + 1}") as output_tab:
+                     output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
+                     output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
+                     output_midi = gr.File(label="output midi", file_types=[".mid"])
+                     midi_outputs.append(output_midi)
+                     audio_outputs.append(output_audio)
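+         # Event wiring: generate (streams tokens to the visualizer), then write
+         # the .mid files, then optionally render audio so synthesis never blocks
+         # the MIDI outputs.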
+         run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
+                                         input_continuation_select, input_instruments, input_drum_kit, input_bpm,
+                                         input_time_sig, input_key_sig, input_midi, input_midi_events,
+                                         input_reduce_cc_st, input_remap_track_channel,
+                                         input_add_default_instr, input_remove_empty_channels,
+                                         input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
+                                         input_top_k, input_allow_cc],
+                                   [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
+         finish_run_event = run_event.then(fn=finish_run,
+                                           inputs=[input_model, output_midi_seq],
+                                           outputs=midi_outputs + [js_msg],
+                                           queue=False)
+         finish_run_event.then(fn=render_audio,
+                               inputs=[input_model, output_midi_seq, input_render_audio],
+                               outputs=audio_outputs,
+                               queue=False)
+         # stop_btn.click(None, [], [], cancels=run_event,
+         #                queue=False)
+         undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
+                        [output_midi_seq, output_continuation_state, js_msg], queue=False)
+     app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
+     thread_pool.shutdown()