SteveZerb committed on
Commit f11edaf · verified · 1 Parent(s): 6644593

Delete train.py

Browse files
Files changed (1)
  1. train.py +0 -479
train.py DELETED
@@ -1,479 +0,0 @@
- import argparse
- import os
- import random
- from pathlib import Path
- from typing import Union
-
- import lightning as pl
- import numpy as np
- import torch
- import torch.nn.functional as F
- from lightning import Trainer
- from lightning.fabric.utilities import rank_zero_only
- from lightning.pytorch.callbacks import ModelCheckpoint
- from peft import LoraConfig, TaskType
- from safetensors.torch import save_file as safe_save_file
- from torch import optim
- from torch.optim.lr_scheduler import LambdaLR
- from torch.utils.data import Dataset, DataLoader
-
- import MIDI
- from midi_model import MIDIModel, MIDIModelConfig, config_name_list
- from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2
-
- EXTENSION = [".mid", ".midi"]
-
-
- def file_ext(fname):
-     return os.path.splitext(fname)[1].lower()
-
-
- class MidiDataset(Dataset):
-     def __init__(self, midi_list, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2], max_len=2048, min_file_size=3000,
-                  max_file_size=384000,
-                  aug=True, check_quality=False, rand_start=True):
-
-         self.tokenizer = tokenizer
-         self.midi_list = midi_list
-         self.max_len = max_len
-         self.min_file_size = min_file_size
-         self.max_file_size = max_file_size
-         self.aug = aug
-         self.check_quality = check_quality
-         self.rand_start = rand_start
-
-     def __len__(self):
-         return len(self.midi_list)
-
-     def load_midi(self, index):
-         path = self.midi_list[index]
-         try:
-             with open(path, 'rb') as f:
-                 datas = f.read()
-             if len(datas) > self.max_file_size:  # large midi file will spend too much time to load
-                 raise ValueError("file too large")
-             elif len(datas) < self.min_file_size:
-                 raise ValueError("file too small")
-             mid = MIDI.midi2score(datas)
-             if max([0] + [len(track) for track in mid[1:]]) == 0:
-                 raise ValueError("empty track")
-             mid = self.tokenizer.tokenize(mid)
-             if self.check_quality and not self.tokenizer.check_quality(mid)[0]:
-                 raise ValueError("bad quality")
-             if self.aug:
-                 mid = self.tokenizer.augment(mid)
-         except Exception:
-             mid = self.load_midi(random.randint(0, self.__len__() - 1))
-         return mid
-
-     def __getitem__(self, index):
-         mid = self.load_midi(index)
-         mid = np.asarray(mid, dtype=np.int16)
-         # if mid.shape[0] < self.max_len:
-         #     mid = np.pad(mid, ((0, self.max_len - mid.shape[0]), (0, 0)),
-         #                  mode="constant", constant_values=self.tokenizer.pad_id)
-         if self.rand_start:
-             start_idx = random.randrange(0, max(1, mid.shape[0] - self.max_len))
-             start_idx = random.choice([0, start_idx])
-         else:
-             max_start = max(1, mid.shape[0] - self.max_len)
-             start_idx = (index * (max_start // 8)) % max_start
-         mid = mid[start_idx: start_idx + self.max_len]
-         mid = mid.astype(np.int64)
-         mid = torch.from_numpy(mid)
-         return mid
-
-     def collate_fn(self, batch):
-         max_len = max([len(mid) for mid in batch])
-         batch = [F.pad(mid, (0, 0, 0, max_len - mid.shape[0]), mode="constant", value=self.tokenizer.pad_id) for mid in batch]
-         batch = torch.stack(batch)
-         return batch
-
-
- def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
-     """ Create a schedule with a learning rate that decreases linearly after
-     linearly increasing during a warmup period.
-     """
-
-     def lr_lambda(current_step):
-         if current_step < num_warmup_steps:
-             return float(current_step) / float(max(1, num_warmup_steps))
-         return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
-
-     return LambdaLR(optimizer, lr_lambda, last_epoch)
-
-
- class TrainMIDIModel(MIDIModel, pl.LightningModule):
-     def __init__(self, config: MIDIModelConfig,
-                  lr=2e-4, weight_decay=0.01, warmup=1e3, max_step=1e6, sample_seq=False,
-                  gen_example_interval=1, example_batch=8):
-         super(TrainMIDIModel, self).__init__(config)
-         self.lr = lr
-         self.weight_decay = weight_decay
-         self.warmup = warmup
-         self.max_step = max_step
-         self.sample_seq = sample_seq
-         self.gen_example_interval = gen_example_interval
-         self.example_batch = example_batch
-         self.last_save_step = 0
-         self.gen_example_count = 0
-
-     def configure_optimizers(self):
-         param_optimizer = list(self.named_parameters())
-         no_decay = ['bias', 'norm']  # no decay for bias and Norm
-         optimizer_grouped_parameters = [
-             {
-                 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
-                 'weight_decay': self.weight_decay},
-             {
-                 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
-                 'weight_decay': 0.0
-             }
-         ]
-         optimizer = optim.AdamW(
-             optimizer_grouped_parameters,
-             lr=self.lr,
-             betas=(0.9, 0.99),
-             eps=1e-08,
-         )
-         lr_scheduler = get_linear_schedule_with_warmup(
-             optimizer=optimizer,
-             num_warmup_steps=self.warmup,
-             num_training_steps=self.max_step,
-         )
-         return {
-             "optimizer": optimizer,
-             "lr_scheduler": {
-                 "scheduler": lr_scheduler,
-                 "interval": "step",
-                 "frequency": 1
-             }
-         }
-
-     def compute_accuracy(self, logits, labels):
-         out = torch.argmax(logits, dim=-1)
-         out = out.flatten()
-         labels = labels.flatten()
-
-         mask = (labels != self.tokenizer.pad_id)
-         out = out[mask]
-         labels = labels[mask]
-
-         num_right = (out == labels)
-         num_right = torch.sum(num_right).type(torch.float32)
-         acc = num_right / len(labels)
-
-         return acc
-
-     def training_step(self, batch, batch_idx):
-         x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
-         y = batch[:, 1:].contiguous()
-         hidden = self.forward(x)
-         if self.sample_seq:  # to reduce vram
-             rand_idx = [-1] + random.sample(list(range(y.shape[1] - 2)), min(127, (y.shape[1] - 2) // 2))
-             hidden = hidden[:, rand_idx]
-             y = y[:, rand_idx]
-         hidden = hidden.reshape(-1, hidden.shape[-1])
-         y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
-         x = y[:, :-1]
-         logits = self.forward_token(hidden, x)
-         loss = F.cross_entropy(
-             logits.view(-1, self.tokenizer.vocab_size),
-             y.view(-1),
-             reduction="mean",
-             ignore_index=self.tokenizer.pad_id
-         )
-         self.log("train/loss", loss)
-         self.log("train/lr", self.lr_schedulers().get_last_lr()[0])
-         return loss
-
-     def validation_step(self, batch, batch_idx):
-         x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
-         y = batch[:, 1:].contiguous()
-         hidden = self.forward(x)
-         hidden = hidden.reshape(-1, hidden.shape[-1])
-         y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
-         x = y[:, :-1]
-         logits = self.forward_token(hidden, x)
-         loss = F.cross_entropy(
-             logits.view(-1, self.tokenizer.vocab_size),
-             y.view(-1),
-             reduction="mean",
-             ignore_index=self.tokenizer.pad_id
-         )
-         acc = self.compute_accuracy(logits, y)
-         self.log_dict({"val/loss": loss, "val/acc": acc}, sync_dist=True)
-         return loss
-
-     @rank_zero_only
-     def gen_example(self, save_dir):
-         base_dir = f"{save_dir}/sample/{self.global_step}"
-         if not os.path.exists(base_dir):
-             Path(base_dir).mkdir(parents=True)
-         midis = self.generate(batch_size=self.example_batch)
-         midis = [self.tokenizer.detokenize(midi) for midi in midis]
-         imgs = [self.tokenizer.midi2img(midi) for midi in midis]
-         for i, (img, midi) in enumerate(zip(imgs, midis)):
-             img.save(f"{base_dir}/0_{i}.png")
-             with open(f"{base_dir}/0_{i}.mid", 'wb') as f:
-                 f.write(MIDI.score2midi(midi))
-         prompt = val_dataset.load_midi(random.randint(0, len(val_dataset) - 1))
-         prompt = np.asarray(prompt, dtype=np.int16)
-         ori = prompt[:512]
-         img = self.tokenizer.midi2img(self.tokenizer.detokenize(ori))
-         img.save(f"{base_dir}/1_ori.png")
-         prompt = prompt[:256].astype(np.int64)
-         midis = self.generate(prompt, batch_size=self.example_batch)
-         midis = [self.tokenizer.detokenize(midi) for midi in midis]
-         imgs = [self.tokenizer.midi2img(midi) for midi in midis]
-         for i, (img, midi) in enumerate(zip(imgs, midis)):
-             img.save(f"{base_dir}/1_{i}.png")
-             with open(f"{base_dir}/1_{i}.mid", 'wb') as f:
-                 f.write(MIDI.score2midi(midi))
-
-     @rank_zero_only
-     def save_peft(self, save_dir):
-         adapter_name = self.active_adapters()[0]
-         adapter_config = self.peft_config[adapter_name]
-         if not os.path.exists(save_dir):
-             os.makedirs(save_dir, exist_ok=True)
-         adapter_config.save_pretrained(save_dir)
-         adapter_state_dict = self.get_adapter_state_dict(adapter_name)
-         safe_save_file(adapter_state_dict,
-                        os.path.join(save_dir, "adapter_model.safetensors"),
-                        metadata={"format": "pt"})
-
-     def on_save_checkpoint(self, checkpoint):
-         if self.global_step == self.last_save_step:
-             return
-         self.last_save_step = self.global_step
-         trainer = self.trainer
-         if len(trainer.loggers) > 0:
-             if trainer.loggers[0].save_dir is not None:
-                 save_dir = trainer.loggers[0].save_dir
-             else:
-                 save_dir = trainer.default_root_dir
-             name = trainer.loggers[0].name
-             version = trainer.loggers[0].version
-             version = version if isinstance(version, str) else f"version_{version}"
-             save_dir = os.path.join(save_dir, str(name), version)
-         else:
-             save_dir = trainer.default_root_dir
-         self.config.save_pretrained(os.path.join(save_dir, "checkpoints"))
-         if self._hf_peft_config_loaded:
-             self.save_peft(os.path.join(save_dir, "lora"))
-         self.gen_example_count += 1
-         if self.gen_example_interval > 0 and self.gen_example_count % self.gen_example_interval == 0:
-             try:
-                 self.gen_example(save_dir)
-             except Exception as e:
-                 print(e)
-
-
- def get_midi_list(path):
-     all_files = {
-         os.path.join(root, fname)
-         for root, _dirs, files in os.walk(path)
-         for fname in files
-     }
-     all_midis = sorted(
-         fname for fname in all_files if file_ext(fname) in EXTENSION
-     )
-     return all_midis
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     # model args
-     parser.add_argument(
-         "--resume", type=str, default="", help="resume training from ckpt"
-     )
-     parser.add_argument(
-         "--ckpt", type=str, default="", help="load ckpt"
-     )
-     parser.add_argument(
-         "--config", type=str, default="tv2o-medium", help="model config name or file"
-     )
-     parser.add_argument(
-         "--task", type=str, default="train", choices=["train", "lora"], help="Full train or lora"
-     )
-
-     # dataset args
-     parser.add_argument(
-         "--data", type=str, default="data", help="dataset path"
-     )
-     parser.add_argument(
-         "--data-val-split",
-         type=int,
-         default=128,
-         help="the number of midi files divided into the validation set",
-     )
-     parser.add_argument(
-         "--max-len",
-         type=int,
-         default=2048,
-         help="max seq length for training",
-     )
-     parser.add_argument(
-         "--quality", action="store_true", default=False, help="check dataset quality"
-     )
-
-     # training args
-     parser.add_argument("--seed", type=int, default=0, help="seed")
-     parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
-     parser.add_argument("--weight-decay", type=float, default=0.01, help="weight decay")
-     parser.add_argument("--warmup-step", type=int, default=1e2, help="warmup step")
-     parser.add_argument("--max-step", type=int, default=1e6, help="max training step")
-     parser.add_argument("--grad-clip", type=float, default=1.0, help="gradient clip val")
-     parser.add_argument(
-         "--sample-seq", action="store_true", default=False, help="sample midi seq to reduce vram"
-     )
-     parser.add_argument(
-         "--gen-example-interval", type=int, default=1, help="generate example interval. set 0 to disable"
-     )
-     parser.add_argument(
-         "--batch-size-train", type=int, default=2, help="batch size for training"
-     )
-     parser.add_argument(
-         "--batch-size-val", type=int, default=2, help="batch size for val"
-     )
-     parser.add_argument(
-         "--batch-size-gen-example", type=int, default=8, help="batch size for generate example"
-     )
-     parser.add_argument(
-         "--workers-train",
-         type=int,
-         default=4,
-         help="workers num for training dataloader",
-     )
-     parser.add_argument(
-         "--workers-val",
-         type=int,
-         default=4,
-         help="workers num for validation dataloader",
-     )
-     parser.add_argument(
-         "--acc-grad", type=int, default=2, help="gradient accumulation"
-     )
-     parser.add_argument(
-         "--accelerator",
-         type=str,
-         default="gpu",
-         choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
-         help="accelerator",
-     )
-     parser.add_argument(
-         "--precision",
-         type=str,
-         default="bf16-true",
-         choices=["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true", "64", "32", "16", "bf16"],
-         help="precision",
-     )
-     parser.add_argument("--devices", type=int, default=-1, help="devices num")
-     parser.add_argument("--nodes", type=int, default=1, help="nodes num")
-     parser.add_argument(
-         "--disable-benchmark", action="store_true", default=False, help="disable cudnn benchmark"
-     )
-     parser.add_argument(
-         "--log-step", type=int, default=1, help="log training loss every n steps"
-     )
-     parser.add_argument(
-         "--val-step", type=int, default=1600, help="valid and save every n steps, set 0 to valid and save every epoch"
-     )
-
-     opt = parser.parse_args()
-     print(opt)
-
-     if not os.path.exists("lightning_logs"):
-         os.mkdir("lightning_logs")
-     if not os.path.exists("sample"):
-         os.mkdir("sample")
-     pl.seed_everything(opt.seed)
-     print("---load dataset---")
-     if opt.config in config_name_list:
-         config = MIDIModelConfig.from_name(opt.config)
-     else:
-         config = MIDIModelConfig.from_json_file(opt.config)
-     tokenizer = config.tokenizer
-     midi_list = get_midi_list(opt.data)
-     random.shuffle(midi_list)
-     full_dataset_len = len(midi_list)
-     train_dataset_len = full_dataset_len - opt.data_val_split
-     train_midi_list = midi_list[:train_dataset_len]
-     val_midi_list = midi_list[train_dataset_len:]
-     train_dataset = MidiDataset(train_midi_list, tokenizer, max_len=opt.max_len, aug=True, check_quality=opt.quality,
-                                 rand_start=True)
-     val_dataset = MidiDataset(val_midi_list, tokenizer, max_len=opt.max_len, aug=False, check_quality=opt.quality,
-                               rand_start=False)
-     train_dataloader = DataLoader(
-         train_dataset,
-         batch_size=opt.batch_size_train,
-         shuffle=True,
-         persistent_workers=True,
-         num_workers=opt.workers_train,
-         pin_memory=True,
-         collate_fn=train_dataset.collate_fn
-     )
-     val_dataloader = DataLoader(
-         val_dataset,
-         batch_size=opt.batch_size_val,
-         shuffle=False,
-         persistent_workers=True,
-         num_workers=opt.workers_val,
-         pin_memory=True,
-         collate_fn=val_dataset.collate_fn
-     )
-     print(f"train: {len(train_dataset)} val: {len(val_dataset)}")
-     torch.backends.cuda.enable_mem_efficient_sdp(True)
-     torch.backends.cuda.enable_flash_sdp(True)
-     model = TrainMIDIModel(config, lr=opt.lr, weight_decay=opt.weight_decay,
-                            warmup=opt.warmup_step, max_step=opt.max_step,
-                            sample_seq=opt.sample_seq, gen_example_interval=opt.gen_example_interval,
-                            example_batch=opt.batch_size_gen_example)
-     if opt.ckpt:
-         ckpt = torch.load(opt.ckpt, map_location="cpu")
-         state_dict = ckpt.get("state_dict", ckpt)
-         model.load_state_dict(state_dict, strict=False)
-     elif opt.task == "lora":
-         raise ValueError("--ckpt must be set to train lora")
-     if opt.task == "lora":
-         model.requires_grad_(False)
-         lora_config = LoraConfig(
-             r=64,
-             target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
-             task_type=TaskType.CAUSAL_LM,
-             bias="none",
-             lora_alpha=128,
-             lora_dropout=0
-         )
-         model.add_adapter(lora_config)
-     print("---start train---")
-     checkpoint_callback = ModelCheckpoint(
-         monitor="val/loss",
-         mode="min",
-         save_top_k=1,
-         save_last=True,
-         auto_insert_metric_name=False,
-         filename="epoch={epoch},loss={val/loss:.4f}",
-     )
-     callbacks = [checkpoint_callback]
-
-     trainer = Trainer(
-         precision=opt.precision,
-         accumulate_grad_batches=opt.acc_grad,
-         gradient_clip_val=opt.grad_clip,
-         accelerator=opt.accelerator,
-         devices=opt.devices,
-         num_nodes=opt.nodes,
-         max_steps=opt.max_step,
-         benchmark=not opt.disable_benchmark,
-         val_check_interval=opt.val_step or None,
-         log_every_n_steps=1,
-         strategy="auto",
-         callbacks=callbacks,
-     )
-     ckpt_path = opt.resume
-     if ckpt_path == "":
-         ckpt_path = None
-     print("---start train---")
-     trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)