Upload 3 files
- push_to_hub.py +59 -0
- requirements.txt +2 -3
- train.py +479 -0

push_to_hub.py (ADDED)
import argparse
from pathlib import Path

import torch
from safetensors.torch import load_file as safe_load_file

from midi_model import config_name_list, MIDIModelConfig, MIDIModel

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt", type=str, default="", help="load ckpt"
    )
    parser.add_argument(
        "--config", type=str, default="auto",
        help="model config name, file or automatically find config.json"
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["bf16", "fp16", "fp32"],
        help="convert precision",
    )
    parser.add_argument(
        "--repo-id", type=str, default="midi-model-test",
        help="repo id"
    )
    parser.add_argument(
        "--private", action="store_true", default=False, help="private repo"
    )

    opt = parser.parse_args()
    print(opt)

    # resolve the model config: a known config name, a config.json found
    # next to the checkpoint ("auto"), or an explicit config file path
    if opt.config in config_name_list:
        config = MIDIModelConfig.from_name(opt.config)
    elif opt.config == "auto":
        config_path = Path(opt.ckpt).parent / "config.json"
        if config_path.exists():
            config = MIDIModelConfig.from_json_file(config_path)
        else:
            raise ValueError("can not find config.json, please specify config")
    else:
        config = MIDIModelConfig.from_json_file(opt.config)

    model = MIDIModel(config=config)
    # load weights from either a safetensors file or a torch/Lightning checkpoint
    if opt.ckpt.endswith(".safetensors"):
        state_dict = safe_load_file(opt.ckpt)
    else:
        ckpt = torch.load(opt.ckpt, map_location="cpu")
        state_dict = ckpt.get("state_dict", ckpt)
    model.load_state_dict(state_dict, strict=False)
    # convert to the requested precision before uploading
    precision_dict = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    model.to(dtype=precision_dict[opt.precision]).eval()
    model.push_to_hub(repo_id=opt.repo_id, private=opt.private)
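
A typical invocation, as a sketch (the checkpoint path and repo id below are hypothetical). With --config auto the script looks for a config.json next to the checkpoint, which train.py writes on every save:

    python push_to_hub.py --ckpt lightning_logs/version_0/checkpoints/last.ckpt \
        --precision bf16 --repo-id your-name/midi-model-test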
requirements.txt (CHANGED)
@@ -1,11 +1,10 @@
---extra-index-url https://download.pytorch.org/whl/cu124
 Pillow
 numpy
 torch
-
+safetensors
 peft>=0.13.0
 transformers>=4.36
+lightning==2.4.0
 gradio==5.3.0
 pyfluidsynth
 tqdm
-huggingface_hub
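
With the CUDA-specific extra index URL removed, the dependencies now install from the default package index:

    pip install -r requirements.txt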

train.py (ADDED)
import argparse
import os
import random
from pathlib import Path
from typing import Union

import lightning as pl
import numpy as np
import torch
import torch.nn.functional as F
from lightning import Trainer
from lightning.fabric.utilities import rank_zero_only
from lightning.pytorch.callbacks import ModelCheckpoint
from peft import LoraConfig, TaskType
from safetensors.torch import save_file as safe_save_file
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

import MIDI
from midi_model import MIDIModel, MIDIModelConfig, config_name_list
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2

EXTENSION = [".mid", ".midi"]


def file_ext(fname):
    return os.path.splitext(fname)[1].lower()


class MidiDataset(Dataset):
    def __init__(self, midi_list, tokenizer: Union[MIDITokenizerV1, MIDITokenizerV2], max_len=2048,
                 min_file_size=3000, max_file_size=384000,
                 aug=True, check_quality=False, rand_start=True):
        self.tokenizer = tokenizer
        self.midi_list = midi_list
        self.max_len = max_len
        self.min_file_size = min_file_size
        self.max_file_size = max_file_size
        self.aug = aug
        self.check_quality = check_quality
        self.rand_start = rand_start

    def __len__(self):
        return len(self.midi_list)

    def load_midi(self, index):
        path = self.midi_list[index]
        try:
            with open(path, 'rb') as f:
                datas = f.read()
            if len(datas) > self.max_file_size:  # a large midi file takes too long to load
                raise ValueError("file too large")
            elif len(datas) < self.min_file_size:
                raise ValueError("file too small")
            mid = MIDI.midi2score(datas)
            if max([0] + [len(track) for track in mid[1:]]) == 0:
                raise ValueError("empty track")
            mid = self.tokenizer.tokenize(mid)
            if self.check_quality and not self.tokenizer.check_quality(mid)[0]:
                raise ValueError("bad quality")
            if self.aug:
                mid = self.tokenizer.augment(mid)
        except Exception:
            # unusable file: retry with a randomly chosen one instead
            mid = self.load_midi(random.randint(0, self.__len__() - 1))
        return mid

    def __getitem__(self, index):
        mid = self.load_midi(index)
        mid = np.asarray(mid, dtype=np.int16)
        # if mid.shape[0] < self.max_len:
        #     mid = np.pad(mid, ((0, self.max_len - mid.shape[0]), (0, 0)),
        #                  mode="constant", constant_values=self.tokenizer.pad_id)
        if self.rand_start:
            start_idx = random.randrange(0, max(1, mid.shape[0] - self.max_len))
            start_idx = random.choice([0, start_idx])
        else:
            max_start = max(1, mid.shape[0] - self.max_len)
            start_idx = (index * (max_start // 8)) % max_start
        mid = mid[start_idx: start_idx + self.max_len]
        mid = mid.astype(np.int64)
        mid = torch.from_numpy(mid)
        return mid

    def collate_fn(self, batch):
        # pad every sequence in the batch to the longest one
        max_len = max([len(mid) for mid in batch])
        batch = [F.pad(mid, (0, 0, 0, max_len - mid.shape[0]), mode="constant", value=self.tokenizer.pad_id)
                 for mid in batch]
        batch = torch.stack(batch)
        return batch


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """Create a schedule with a learning rate that increases linearly during a
    warmup period and then decreases linearly.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


class TrainMIDIModel(MIDIModel, pl.LightningModule):
    def __init__(self, config: MIDIModelConfig,
                 lr=2e-4, weight_decay=0.01, warmup=1e3, max_step=1e6, sample_seq=False,
                 gen_example_interval=1, example_batch=8):
        super(TrainMIDIModel, self).__init__(config)
        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup = warmup
        self.max_step = max_step
        self.sample_seq = sample_seq
        self.gen_example_interval = gen_example_interval
        self.example_batch = example_batch
        self.last_save_step = 0
        self.gen_example_count = 0

    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'norm']  # no weight decay for bias and Norm layers
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': self.weight_decay
            },
            {
                'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        optimizer = optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
            betas=(0.9, 0.99),
            eps=1e-08,
        )
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=self.warmup,
            num_training_steps=self.max_step,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def compute_accuracy(self, logits, labels):
        out = torch.argmax(logits, dim=-1)
        out = out.flatten()
        labels = labels.flatten()

        # ignore padding positions
        mask = (labels != self.tokenizer.pad_id)
        out = out[mask]
        labels = labels[mask]

        num_right = (out == labels)
        num_right = torch.sum(num_right).type(torch.float32)
        acc = num_right / len(labels)

        return acc

    def training_step(self, batch, batch_idx):
        x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
        y = batch[:, 1:].contiguous()
        hidden = self.forward(x)
        if self.sample_seq:  # train the token head on a subsample of positions to reduce vram
            rand_idx = [-1] + random.sample(list(range(y.shape[1] - 2)), min(127, (y.shape[1] - 2) // 2))
            hidden = hidden[:, rand_idx]
            y = y[:, rand_idx]
        hidden = hidden.reshape(-1, hidden.shape[-1])
        y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
        x = y[:, :-1]
        logits = self.forward_token(hidden, x)
        loss = F.cross_entropy(
            logits.view(-1, self.tokenizer.vocab_size),
            y.view(-1),
            reduction="mean",
            ignore_index=self.tokenizer.pad_id
        )
        self.log("train/loss", loss)
        self.log("train/lr", self.lr_schedulers().get_last_lr()[0])
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
        y = batch[:, 1:].contiguous()
        hidden = self.forward(x)
        hidden = hidden.reshape(-1, hidden.shape[-1])
        y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
        x = y[:, :-1]
        logits = self.forward_token(hidden, x)
        loss = F.cross_entropy(
            logits.view(-1, self.tokenizer.vocab_size),
            y.view(-1),
            reduction="mean",
            ignore_index=self.tokenizer.pad_id
        )
        acc = self.compute_accuracy(logits, y)
        self.log_dict({"val/loss": loss, "val/acc": acc}, sync_dist=True)
        return loss

    @rank_zero_only
    def gen_example(self, save_dir):
        # generate unconditional and prompted samples; uses the module-level
        # val_dataset defined in the __main__ block below
        base_dir = f"{save_dir}/sample/{self.global_step}"
        if not os.path.exists(base_dir):
            Path(base_dir).mkdir(parents=True)
        midis = self.generate(batch_size=self.example_batch)
        midis = [self.tokenizer.detokenize(midi) for midi in midis]
        imgs = [self.tokenizer.midi2img(midi) for midi in midis]
        for i, (img, midi) in enumerate(zip(imgs, midis)):
            img.save(f"{base_dir}/0_{i}.png")
            with open(f"{base_dir}/0_{i}.mid", 'wb') as f:
                f.write(MIDI.score2midi(midi))
        prompt = val_dataset.load_midi(random.randint(0, len(val_dataset) - 1))
        prompt = np.asarray(prompt, dtype=np.int16)
        ori = prompt[:512]
        img = self.tokenizer.midi2img(self.tokenizer.detokenize(ori))
        img.save(f"{base_dir}/1_ori.png")
        prompt = prompt[:256].astype(np.int64)
        midis = self.generate(prompt, batch_size=self.example_batch)
        midis = [self.tokenizer.detokenize(midi) for midi in midis]
        imgs = [self.tokenizer.midi2img(midi) for midi in midis]
        for i, (img, midi) in enumerate(zip(imgs, midis)):
            img.save(f"{base_dir}/1_{i}.png")
            with open(f"{base_dir}/1_{i}.mid", 'wb') as f:
                f.write(MIDI.score2midi(midi))

    @rank_zero_only
    def save_peft(self, save_dir):
        adapter_name = self.active_adapters()[0]
        adapter_config = self.peft_config[adapter_name]
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        adapter_config.save_pretrained(save_dir)
        adapter_state_dict = self.get_adapter_state_dict(adapter_name)
        safe_save_file(adapter_state_dict,
                       os.path.join(save_dir, "adapter_model.safetensors"),
                       metadata={"format": "pt"})

    def on_save_checkpoint(self, checkpoint):
        if self.global_step == self.last_save_step:
            return
        self.last_save_step = self.global_step
        trainer = self.trainer
        if len(trainer.loggers) > 0:
            if trainer.loggers[0].save_dir is not None:
                save_dir = trainer.loggers[0].save_dir
            else:
                save_dir = trainer.default_root_dir
            name = trainer.loggers[0].name
            version = trainer.loggers[0].version
            version = version if isinstance(version, str) else f"version_{version}"
            save_dir = os.path.join(save_dir, str(name), version)
        else:
            save_dir = trainer.default_root_dir
        self.config.save_pretrained(os.path.join(save_dir, "checkpoints"))
        if self._hf_peft_config_loaded:
            self.save_peft(os.path.join(save_dir, "lora"))
        self.gen_example_count += 1
        if self.gen_example_interval > 0 and self.gen_example_count % self.gen_example_interval == 0:
            try:
                self.gen_example(save_dir)
            except Exception as e:
                print(e)


def get_midi_list(path):
    all_files = {
        os.path.join(root, fname)
        for root, _dirs, files in os.walk(path)
        for fname in files
    }
    all_midis = sorted(
        fname for fname in all_files if file_ext(fname) in EXTENSION
    )
    return all_midis


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model args
    parser.add_argument(
        "--resume", type=str, default="", help="resume training from ckpt"
    )
    parser.add_argument(
        "--ckpt", type=str, default="", help="load ckpt"
    )
    parser.add_argument(
        "--config", type=str, default="tv2o-medium", help="model config name or file"
    )
    parser.add_argument(
        "--task", type=str, default="train", choices=["train", "lora"], help="full train or lora"
    )

    # dataset args
    parser.add_argument(
        "--data", type=str, default="data", help="dataset path"
    )
    parser.add_argument(
        "--data-val-split",
        type=int,
        default=128,
        help="the number of midi files divided into the validation set",
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=2048,
        help="max seq length for training",
    )
    parser.add_argument(
        "--quality", action="store_true", default=False, help="check dataset quality"
    )

    # training args
    parser.add_argument("--seed", type=int, default=0, help="seed")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.01, help="weight decay")
    parser.add_argument("--warmup-step", type=int, default=100, help="warmup step")
    parser.add_argument("--max-step", type=int, default=1000000, help="max training step")
    parser.add_argument("--grad-clip", type=float, default=1.0, help="gradient clip val")
    parser.add_argument(
        "--sample-seq", action="store_true", default=False, help="sample midi seq to reduce vram"
    )
    parser.add_argument(
        "--gen-example-interval", type=int, default=1, help="generate example interval. set 0 to disable"
    )
    parser.add_argument(
        "--batch-size-train", type=int, default=2, help="batch size for training"
    )
    parser.add_argument(
        "--batch-size-val", type=int, default=2, help="batch size for val"
    )
    parser.add_argument(
        "--batch-size-gen-example", type=int, default=8, help="batch size for generate example"
    )
    parser.add_argument(
        "--workers-train",
        type=int,
        default=4,
        help="workers num for training dataloader",
    )
    parser.add_argument(
        "--workers-val",
        type=int,
        default=4,
        help="workers num for validation dataloader",
    )
    parser.add_argument(
        "--acc-grad", type=int, default=2, help="gradient accumulation"
    )
    parser.add_argument(
        "--accelerator",
        type=str,
        default="gpu",
        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
        help="accelerator",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16-true",
        choices=["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true", "64", "32", "16", "bf16"],
        help="precision",
    )
    parser.add_argument("--devices", type=int, default=-1, help="devices num")
    parser.add_argument("--nodes", type=int, default=1, help="nodes num")
    parser.add_argument(
        "--disable-benchmark", action="store_true", default=False, help="disable cudnn benchmark"
    )
    parser.add_argument(
        "--log-step", type=int, default=1, help="log training loss every n steps"
    )
    parser.add_argument(
        "--val-step", type=int, default=1600, help="valid and save every n steps, set 0 to valid and save every epoch"
    )

    opt = parser.parse_args()
    print(opt)

    if not os.path.exists("lightning_logs"):
        os.mkdir("lightning_logs")
    if not os.path.exists("sample"):
        os.mkdir("sample")
    pl.seed_everything(opt.seed)
    print("---load dataset---")
    if opt.config in config_name_list:
        config = MIDIModelConfig.from_name(opt.config)
    else:
        config = MIDIModelConfig.from_json_file(opt.config)
    tokenizer = config.tokenizer
    midi_list = get_midi_list(opt.data)
    random.shuffle(midi_list)
    full_dataset_len = len(midi_list)
    train_dataset_len = full_dataset_len - opt.data_val_split
    train_midi_list = midi_list[:train_dataset_len]
    val_midi_list = midi_list[train_dataset_len:]
    train_dataset = MidiDataset(train_midi_list, tokenizer, max_len=opt.max_len, aug=True, check_quality=opt.quality,
                                rand_start=True)
    val_dataset = MidiDataset(val_midi_list, tokenizer, max_len=opt.max_len, aug=False, check_quality=opt.quality,
                              rand_start=False)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=opt.batch_size_train,
        shuffle=True,
        persistent_workers=True,
        num_workers=opt.workers_train,
        pin_memory=True,
        collate_fn=train_dataset.collate_fn
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=opt.batch_size_val,
        shuffle=False,
        persistent_workers=True,
        num_workers=opt.workers_val,
        pin_memory=True,
        collate_fn=val_dataset.collate_fn
    )
    print(f"train: {len(train_dataset)} val: {len(val_dataset)}")
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    model = TrainMIDIModel(config, lr=opt.lr, weight_decay=opt.weight_decay,
                           warmup=opt.warmup_step, max_step=opt.max_step,
                           sample_seq=opt.sample_seq, gen_example_interval=opt.gen_example_interval,
                           example_batch=opt.batch_size_gen_example)
    if opt.ckpt:
        ckpt = torch.load(opt.ckpt, map_location="cpu")
        state_dict = ckpt.get("state_dict", ckpt)
        model.load_state_dict(state_dict, strict=False)
    elif opt.task == "lora":
        raise ValueError("--ckpt must be set to train lora")
    if opt.task == "lora":
        # freeze the base model and train only the LoRA adapter
        model.requires_grad_(False)
        lora_config = LoraConfig(
            r=64,
            target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
            task_type=TaskType.CAUSAL_LM,
            bias="none",
            lora_alpha=128,
            lora_dropout=0
        )
        model.add_adapter(lora_config)
    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        mode="min",
        save_top_k=1,
        save_last=True,
        auto_insert_metric_name=False,
        filename="epoch={epoch},loss={val/loss:.4f}",
    )
    callbacks = [checkpoint_callback]

    trainer = Trainer(
        precision=opt.precision,
        accumulate_grad_batches=opt.acc_grad,
        gradient_clip_val=opt.grad_clip,
        accelerator=opt.accelerator,
        devices=opt.devices,
        num_nodes=opt.nodes,
        max_steps=opt.max_step,
        benchmark=not opt.disable_benchmark,
        val_check_interval=opt.val_step or None,
        log_every_n_steps=opt.log_step,
        strategy="auto",
        callbacks=callbacks,
    )
    ckpt_path = opt.resume
    if ckpt_path == "":
        ckpt_path = None
    print("---start train---")
    trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)
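
Sketches of typical invocations under the argument defaults above (the dataset path and checkpoint name are hypothetical):

    # full training on the MIDI files under ./data
    python train.py --task train --config tv2o-medium --data data

    # LoRA fine-tuning; --ckpt is required for this task
    python train.py --task lora --ckpt base_model.ckpt --config tv2o-medium --data data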