SteveZerb committed
Commit 6ae945d · verified · 1 Parent(s): 8bde534

Upload 3 files

Files changed (3):
  1. push_to_hub.py +59 -0
  2. requirements.txt +2 -3
  3. train.py +479 -0
push_to_hub.py ADDED
@@ -0,0 +1,59 @@
+ import argparse
+ from pathlib import Path
+
+ import torch
+ from safetensors.torch import load_file as safe_load_file
+ from midi_model import config_name_list, MIDIModelConfig, MIDIModel
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--ckpt", type=str, default="", help="load ckpt"
+     )
+     parser.add_argument(
+         "--config", type=str, default="auto",
+         help="model config name, file or automatically find config.json"
+     )
+     parser.add_argument(
+         "--precision",
+         type=str,
+         default="bf16",
+         choices=["bf16", "fp16", "fp32"],
+         help="convert precision",
+     )
+     parser.add_argument(
+         "--repo-id", type=str, default="midi-model-test",
+         help="repo id"
+     )
+     parser.add_argument(
+         "--private", action="store_true", default=False, help="private repo"
+     )
+
+     opt = parser.parse_args()
+     print(opt)
+
+     if opt.config in config_name_list:
+         config = MIDIModelConfig.from_name(opt.config)
+     elif opt.config == "auto":
+         config_path = Path(opt.ckpt).parent / "config.json"
+         if config_path.exists():
+             config = MIDIModelConfig.from_json_file(config_path)
+         else:
+             raise ValueError("can not find config.json, please specify config")
+     else:
+         config = MIDIModelConfig.from_json_file(opt.config)
+
+     model = MIDIModel(config=config)
+     if opt.ckpt.endswith(".safetensors"):
+         state_dict = safe_load_file(opt.ckpt)
+     else:
+         ckpt = torch.load(opt.ckpt, map_location="cpu")
+         state_dict = ckpt.get("state_dict", ckpt)
+     model.load_state_dict(state_dict, strict=False)
+     precision_dict = {
+         "fp16": torch.float16,
+         "bf16": torch.bfloat16,
+         "fp32": torch.float32,
+     }
+     model.to(dtype=precision_dict[opt.precision]).eval()
+     model.push_to_hub(repo_id=opt.repo_id, private=opt.private)
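
A possible invocation of the script above (the checkpoint path and repo id are illustrative, not part of this commit):

    python push_to_hub.py --ckpt path/to/last.ckpt --config auto --precision bf16 --repo-id your-username/midi-model-test

With --config auto the script expects a config.json next to the checkpoint; otherwise pass one of the named configs or a config file, and add --private to create the repo as private.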
requirements.txt CHANGED
@@ -1,11 +1,10 @@
- --extra-index-url https://download.pytorch.org/whl/cu124
  Pillow
  numpy
  torch
- onnxruntime-gpu
+ safetensors
  peft>=0.13.0
  transformers>=4.36
+ lightning==2.4.0
  gradio==5.3.0
  pyfluidsynth
  tqdm
- huggingface_hub
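
Installation stays pip install -r requirements.txt. Since the cu124 extra index URL was dropped from the file, a CUDA-specific torch wheel can still be selected at install time if needed (illustrative, not part of the diff):

    pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124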
train.py ADDED
@@ -0,0 +1,479 @@
+ import argparse
+ import os
+ import random
+ from pathlib import Path
+ from typing import Union
+
+ import lightning as pl
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from lightning import Trainer
+ from lightning.fabric.utilities import rank_zero_only
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from peft import LoraConfig, TaskType
+ from safetensors.torch import save_file as safe_save_file
+ from torch import optim
+ from torch.optim.lr_scheduler import LambdaLR
+ from torch.utils.data import Dataset, DataLoader
+
+ import MIDI
+ from midi_model import MIDIModel, MIDIModelConfig, config_name_list
+ from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2
+
+ EXTENSION = [".mid", ".midi"]
+
+
+ def file_ext(fname):
+     return os.path.splitext(fname)[1].lower()
+
+
+ class MidiDataset(Dataset):
+     def __init__(self, midi_list, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2], max_len=2048, min_file_size=3000,
+                  max_file_size=384000,
+                  aug=True, check_quality=False, rand_start=True):
+
+         self.tokenizer = tokenizer
+         self.midi_list = midi_list
+         self.max_len = max_len
+         self.min_file_size = min_file_size
+         self.max_file_size = max_file_size
+         self.aug = aug
+         self.check_quality = check_quality
+         self.rand_start = rand_start
+
+     def __len__(self):
+         return len(self.midi_list)
+
+     def load_midi(self, index):
+         path = self.midi_list[index]
+         try:
+             with open(path, 'rb') as f:
+                 datas = f.read()
+             if len(datas) > self.max_file_size:  # large midi file will spend too much time to load
+                 raise ValueError("file too large")
+             elif len(datas) < self.min_file_size:
+                 raise ValueError("file too small")
+             mid = MIDI.midi2score(datas)
+             if max([0] + [len(track) for track in mid[1:]]) == 0:
+                 raise ValueError("empty track")
+             mid = self.tokenizer.tokenize(mid)
+             if self.check_quality and not self.tokenizer.check_quality(mid)[0]:
+                 raise ValueError("bad quality")
+             if self.aug:
+                 mid = self.tokenizer.augment(mid)
+         except Exception:
+             mid = self.load_midi(random.randint(0, self.__len__() - 1))
+         return mid
+
+     def __getitem__(self, index):
+         mid = self.load_midi(index)
+         mid = np.asarray(mid, dtype=np.int16)
+         # if mid.shape[0] < self.max_len:
+         #     mid = np.pad(mid, ((0, self.max_len - mid.shape[0]), (0, 0)),
+         #                  mode="constant", constant_values=self.tokenizer.pad_id)
+         if self.rand_start:
+             start_idx = random.randrange(0, max(1, mid.shape[0] - self.max_len))
+             start_idx = random.choice([0, start_idx])
+         else:
+             max_start = max(1, mid.shape[0] - self.max_len)
+             start_idx = (index * (max_start // 8)) % max_start
+         mid = mid[start_idx: start_idx + self.max_len]
+         mid = mid.astype(np.int64)
+         mid = torch.from_numpy(mid)
+         return mid
+
+     def collate_fn(self, batch):
+         max_len = max([len(mid) for mid in batch])
+         batch = [F.pad(mid, (0, 0, 0, max_len - mid.shape[0]), mode="constant", value=self.tokenizer.pad_id) for mid in batch]
+         batch = torch.stack(batch)
+         return batch
+
+
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+     """ Create a schedule with a learning rate that decreases linearly after
+     linearly increasing during a warmup period.
+     """
+
+     def lr_lambda(current_step):
+         if current_step < num_warmup_steps:
+             return float(current_step) / float(max(1, num_warmup_steps))
+         return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
+
+     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+ class TrainMIDIModel(MIDIModel, pl.LightningModule):
+     def __init__(self, config: MIDIModelConfig,
+                  lr=2e-4, weight_decay=0.01, warmup=1e3, max_step=1e6, sample_seq=False,
+                  gen_example_interval=1, example_batch=8):
+         super(TrainMIDIModel, self).__init__(config)
+         self.lr = lr
+         self.weight_decay = weight_decay
+         self.warmup = warmup
+         self.max_step = max_step
+         self.sample_seq = sample_seq
+         self.gen_example_interval = gen_example_interval
+         self.example_batch = example_batch
+         self.last_save_step = 0
+         self.gen_example_count = 0
+
+     def configure_optimizers(self):
+         param_optimizer = list(self.named_parameters())
+         no_decay = ['bias', 'norm']  # no decay for bias and Norm
+         optimizer_grouped_parameters = [
+             {
+                 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
+                 'weight_decay': self.weight_decay},
+             {
+                 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
+                 'weight_decay': 0.0
+             }
+         ]
+         optimizer = optim.AdamW(
+             optimizer_grouped_parameters,
+             lr=self.lr,
+             betas=(0.9, 0.99),
+             eps=1e-08,
+         )
+         lr_scheduler = get_linear_schedule_with_warmup(
+             optimizer=optimizer,
+             num_warmup_steps=self.warmup,
+             num_training_steps=self.max_step,
+         )
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": lr_scheduler,
+                 "interval": "step",
+                 "frequency": 1
+             }
+         }
+
+     def compute_accuracy(self, logits, labels):
+         out = torch.argmax(logits, dim=-1)
+         out = out.flatten()
+         labels = labels.flatten()
+
+         mask = (labels != self.tokenizer.pad_id)
+         out = out[mask]
+         labels = labels[mask]
+
+         num_right = (out == labels)
+         num_right = torch.sum(num_right).type(torch.float32)
+         acc = num_right / len(labels)
+
+         return acc
+
+     def training_step(self, batch, batch_idx):
+         x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
+         y = batch[:, 1:].contiguous()
+         hidden = self.forward(x)
+         if self.sample_seq:  # to reduce vram
+             rand_idx = [-1] + random.sample(list(range(y.shape[1] - 2)), min(127, (y.shape[1] - 2) // 2))
+             hidden = hidden[:, rand_idx]
+             y = y[:, rand_idx]
+         hidden = hidden.reshape(-1, hidden.shape[-1])
+         y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
+         x = y[:, :-1]
+         logits = self.forward_token(hidden, x)
+         loss = F.cross_entropy(
+             logits.view(-1, self.tokenizer.vocab_size),
+             y.view(-1),
+             reduction="mean",
+             ignore_index=self.tokenizer.pad_id
+         )
+         self.log("train/loss", loss)
+         self.log("train/lr", self.lr_schedulers().get_last_lr()[0])
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
+         y = batch[:, 1:].contiguous()
+         hidden = self.forward(x)
+         hidden = hidden.reshape(-1, hidden.shape[-1])
+         y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
+         x = y[:, :-1]
+         logits = self.forward_token(hidden, x)
+         loss = F.cross_entropy(
+             logits.view(-1, self.tokenizer.vocab_size),
+             y.view(-1),
+             reduction="mean",
+             ignore_index=self.tokenizer.pad_id
+         )
+         acc = self.compute_accuracy(logits, y)
+         self.log_dict({"val/loss": loss, "val/acc": acc}, sync_dist=True)
+         return loss
+
+     @rank_zero_only
+     def gen_example(self, save_dir):
+         base_dir = f"{save_dir}/sample/{self.global_step}"
+         if not os.path.exists(base_dir):
+             Path(base_dir).mkdir(parents=True)
+         midis = self.generate(batch_size=self.example_batch)
+         midis = [self.tokenizer.detokenize(midi) for midi in midis]
+         imgs = [self.tokenizer.midi2img(midi) for midi in midis]
+         for i, (img, midi) in enumerate(zip(imgs, midis)):
+             img.save(f"{base_dir}/0_{i}.png")
+             with open(f"{base_dir}/0_{i}.mid", 'wb') as f:
+                 f.write(MIDI.score2midi(midi))
+         prompt = val_dataset.load_midi(random.randint(0, len(val_dataset) - 1))
+         prompt = np.asarray(prompt, dtype=np.int16)
+         ori = prompt[:512]
+         img = self.tokenizer.midi2img(self.tokenizer.detokenize(ori))
+         img.save(f"{base_dir}/1_ori.png")
+         prompt = prompt[:256].astype(np.int64)
+         midis = self.generate(prompt, batch_size=self.example_batch)
+         midis = [self.tokenizer.detokenize(midi) for midi in midis]
+         imgs = [self.tokenizer.midi2img(midi) for midi in midis]
+         for i, (img, midi) in enumerate(zip(imgs, midis)):
+             img.save(f"{base_dir}/1_{i}.png")
+             with open(f"{base_dir}/1_{i}.mid", 'wb') as f:
+                 f.write(MIDI.score2midi(midi))
+
+     @rank_zero_only
+     def save_peft(self, save_dir):
+         adapter_name = self.active_adapters()[0]
+         adapter_config = self.peft_config[adapter_name]
+         if not os.path.exists(save_dir):
+             os.makedirs(save_dir, exist_ok=True)
+         adapter_config.save_pretrained(save_dir)
+         adapter_state_dict = self.get_adapter_state_dict(adapter_name)
+         safe_save_file(adapter_state_dict,
+                        os.path.join(save_dir, "adapter_model.safetensors"),
+                        metadata={"format": "pt"})
+
+     def on_save_checkpoint(self, checkpoint):
+         if self.global_step == self.last_save_step:
+             return
+         self.last_save_step = self.global_step
+         trainer = self.trainer
+         if len(trainer.loggers) > 0:
+             if trainer.loggers[0].save_dir is not None:
+                 save_dir = trainer.loggers[0].save_dir
+             else:
+                 save_dir = trainer.default_root_dir
+             name = trainer.loggers[0].name
+             version = trainer.loggers[0].version
+             version = version if isinstance(version, str) else f"version_{version}"
+             save_dir = os.path.join(save_dir, str(name), version)
+         else:
+             save_dir = trainer.default_root_dir
+         self.config.save_pretrained(os.path.join(save_dir, "checkpoints"))
+         if self._hf_peft_config_loaded:
+             self.save_peft(os.path.join(save_dir, "lora"))
+         self.gen_example_count += 1
+         if self.gen_example_interval > 0 and self.gen_example_count % self.gen_example_interval == 0:
+             try:
+                 self.gen_example(save_dir)
+             except Exception as e:
+                 print(e)
+
+
+ def get_midi_list(path):
+     all_files = {
+         os.path.join(root, fname)
+         for root, _dirs, files in os.walk(path)
+         for fname in files
+     }
+     all_midis = sorted(
+         fname for fname in all_files if file_ext(fname) in EXTENSION
+     )
+     return all_midis
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     # model args
+     parser.add_argument(
+         "--resume", type=str, default="", help="resume training from ckpt"
+     )
+     parser.add_argument(
+         "--ckpt", type=str, default="", help="load ckpt"
+     )
+     parser.add_argument(
+         "--config", type=str, default="tv2o-medium", help="model config name or file"
+     )
+     parser.add_argument(
+         "--task", type=str, default="train", choices=["train", "lora"], help="Full train or lora"
+     )
+
+     # dataset args
+     parser.add_argument(
+         "--data", type=str, default="data", help="dataset path"
+     )
+     parser.add_argument(
+         "--data-val-split",
+         type=int,
+         default=128,
+         help="the number of midi files divided into the validation set",
+     )
+     parser.add_argument(
+         "--max-len",
+         type=int,
+         default=2048,
+         help="max seq length for training",
+     )
+     parser.add_argument(
+         "--quality", action="store_true", default=False, help="check dataset quality"
+     )
+
+     # training args
+     parser.add_argument("--seed", type=int, default=0, help="seed")
+     parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
+     parser.add_argument("--weight-decay", type=float, default=0.01, help="weight decay")
+     parser.add_argument("--warmup-step", type=int, default=1e2, help="warmup step")
+     parser.add_argument("--max-step", type=int, default=1e6, help="max training step")
+     parser.add_argument("--grad-clip", type=float, default=1.0, help="gradient clip val")
+     parser.add_argument(
+         "--sample-seq", action="store_true", default=False, help="sample midi seq to reduce vram"
+     )
+     parser.add_argument(
+         "--gen-example-interval", type=int, default=1, help="generate example interval. set 0 to disable"
+     )
+     parser.add_argument(
+         "--batch-size-train", type=int, default=2, help="batch size for training"
+     )
+     parser.add_argument(
+         "--batch-size-val", type=int, default=2, help="batch size for val"
+     )
+     parser.add_argument(
+         "--batch-size-gen-example", type=int, default=8, help="batch size for generate example"
+     )
+     parser.add_argument(
+         "--workers-train",
+         type=int,
+         default=4,
+         help="workers num for training dataloader",
+     )
+     parser.add_argument(
+         "--workers-val",
+         type=int,
+         default=4,
+         help="workers num for validation dataloader",
+     )
+     parser.add_argument(
+         "--acc-grad", type=int, default=2, help="gradient accumulation"
+     )
+     parser.add_argument(
+         "--accelerator",
+         type=str,
+         default="gpu",
+         choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
+         help="accelerator",
+     )
+     parser.add_argument(
+         "--precision",
+         type=str,
+         default="bf16-true",
+         choices=["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true", "64", "32", "16", "bf16"],
+         help="precision",
+     )
+     parser.add_argument("--devices", type=int, default=-1, help="devices num")
+     parser.add_argument("--nodes", type=int, default=1, help="nodes num")
+     parser.add_argument(
+         "--disable-benchmark", action="store_true", default=False, help="disable cudnn benchmark"
+     )
+     parser.add_argument(
+         "--log-step", type=int, default=1, help="log training loss every n steps"
+     )
+     parser.add_argument(
+         "--val-step", type=int, default=1600, help="valid and save every n steps, set 0 to valid and save every epoch"
+     )
+
+     opt = parser.parse_args()
+     print(opt)
+
+     if not os.path.exists("lightning_logs"):
+         os.mkdir("lightning_logs")
+     if not os.path.exists("sample"):
+         os.mkdir("sample")
+     pl.seed_everything(opt.seed)
+     print("---load dataset---")
+     if opt.config in config_name_list:
+         config = MIDIModelConfig.from_name(opt.config)
+     else:
+         config = MIDIModelConfig.from_json_file(opt.config)
+     tokenizer = config.tokenizer
+     midi_list = get_midi_list(opt.data)
+     random.shuffle(midi_list)
+     full_dataset_len = len(midi_list)
+     train_dataset_len = full_dataset_len - opt.data_val_split
+     train_midi_list = midi_list[:train_dataset_len]
+     val_midi_list = midi_list[train_dataset_len:]
+     train_dataset = MidiDataset(train_midi_list, tokenizer, max_len=opt.max_len, aug=True, check_quality=opt.quality,
+                                 rand_start=True)
+     val_dataset = MidiDataset(val_midi_list, tokenizer, max_len=opt.max_len, aug=False, check_quality=opt.quality,
+                               rand_start=False)
+     train_dataloader = DataLoader(
+         train_dataset,
+         batch_size=opt.batch_size_train,
+         shuffle=True,
+         persistent_workers=True,
+         num_workers=opt.workers_train,
+         pin_memory=True,
+         collate_fn=train_dataset.collate_fn
+     )
+     val_dataloader = DataLoader(
+         val_dataset,
+         batch_size=opt.batch_size_val,
+         shuffle=False,
+         persistent_workers=True,
+         num_workers=opt.workers_val,
+         pin_memory=True,
+         collate_fn=val_dataset.collate_fn
+     )
+     print(f"train: {len(train_dataset)} val: {len(val_dataset)}")
+     torch.backends.cuda.enable_mem_efficient_sdp(True)
+     torch.backends.cuda.enable_flash_sdp(True)
+     model = TrainMIDIModel(config, lr=opt.lr, weight_decay=opt.weight_decay,
+                            warmup=opt.warmup_step, max_step=opt.max_step,
+                            sample_seq=opt.sample_seq, gen_example_interval=opt.gen_example_interval,
+                            example_batch=opt.batch_size_gen_example)
+     if opt.ckpt:
+         ckpt = torch.load(opt.ckpt, map_location="cpu")
+         state_dict = ckpt.get("state_dict", ckpt)
+         model.load_state_dict(state_dict, strict=False)
+     elif opt.task == "lora":
+         raise ValueError("--ckpt must be set to train lora")
+     if opt.task == "lora":
+         model.requires_grad_(False)
+         lora_config = LoraConfig(
+             r=64,
+             target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+             task_type=TaskType.CAUSAL_LM,
+             bias="none",
+             lora_alpha=128,
+             lora_dropout=0
+         )
+         model.add_adapter(lora_config)
+     print("---start train---")
+     checkpoint_callback = ModelCheckpoint(
+         monitor="val/loss",
+         mode="min",
+         save_top_k=1,
+         save_last=True,
+         auto_insert_metric_name=False,
+         filename="epoch={epoch},loss={val/loss:.4f}",
+     )
+     callbacks = [checkpoint_callback]
+
+     trainer = Trainer(
+         precision=opt.precision,
+         accumulate_grad_batches=opt.acc_grad,
+         gradient_clip_val=opt.grad_clip,
+         accelerator=opt.accelerator,
+         devices=opt.devices,
+         num_nodes=opt.nodes,
+         max_steps=opt.max_step,
+         benchmark=not opt.disable_benchmark,
+         val_check_interval=opt.val_step or None,
+         log_every_n_steps=1,
+         strategy="auto",
+         callbacks=callbacks,
+     )
+     ckpt_path = opt.resume
+     if ckpt_path == "":
+         ckpt_path = None
+     print("---start train---")
+     trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)
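
Possible invocations based on the arguments defined above (the dataset folder and checkpoint path are illustrative, not part of this commit):

    # full training on a folder of .mid/.midi files
    python train.py --task train --data data --config tv2o-medium

    # LoRA fine-tuning, which requires a base checkpoint via --ckpt
    python train.py --task lora --ckpt path/to/base.ckpt --data data --config tv2o-medium

Per on_save_checkpoint and gen_example above, checkpoints, the model config.json, LoRA adapters and generated samples are written under the first logger's save directory (lightning_logs/... by default).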