{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "VGrGd6__l5ch" }, "source": [ "# Orpheus Music Transformer Maker (ver. 1.0)\n", "\n", "***\n", "\n", "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n", "\n", "***\n", "\n", "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n", "\n", "***\n", "\n", "#### Project Los Angeles\n", "\n", "#### Tegridy Code 2025\n", "\n", "***" ] }, { "cell_type": "markdown", "metadata": { "id": "shLrgoXdl5cj" }, "source": [ "# GPU check" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X3rABEpKCO02" }, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "id": "0RcVC4btl5ck" }, "source": [ "# Setup environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "viHgEaNACPTs" }, "outputs": [], "source": [ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vK40g6V_BTNj" }, "outputs": [], "source": [ "!pip install huggingface_hub\n", "!pip install hf-transfer\n", "!pip install ipywidgets\n", "!pip install -U tqdm\n", "\n", "!pip install einx\n", "!pip install einops\n", "!pip install torch-summary" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DzCOZU_gBiQV" }, "outputs": [], "source": [ "# Load modules and make data dir\n", "\n", "print('Loading modules...')\n", "\n", "import os\n", "\n", "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n", "\n", "import pickle\n", "import random\n", "import secrets\n", "import tqdm\n", "import math\n", "\n", "import gc\n", "\n", "!set USE_FLASH_ATTENTION=1\n", "os.environ['USE_FLASH_ATTENTION'] = '1'\n", "\n", "import torch\n", "import torch.optim as optim\n", "\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from torchsummary import summary\n", "from sklearn import metrics\n", "\n", "%cd /home/ubuntu/tegridy-tools/tegridy-tools/\n", "\n", "import TMIDIX\n", "\n", "%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer\n", "\n", "from x_transformer_2_3_1 import *\n", "\n", "torch.set_float32_matmul_precision('medium')\n", "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n", "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n", "torch.backends.cuda.enable_flash_sdp(True)\n", "torch.backends.cuda.enable_cudnn_sdp(False)\n", "\n", "!set USE_FLASH_ATTENTION=1\n", "\n", "%cd /home/ubuntu/\n", "\n", "if not os.path.exists('/home/ubuntu/DATA'):\n", " os.makedirs('/home/ubuntu/DATA')\n", "\n", "import random\n", "\n", "print('Done')\n", "\n", "print('Torch version:', torch.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "Sbhzy8FGl5cm" }, "source": [ "# Prep training data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Put all training dataset pickle files into ./DATA folder" ] }, { "cell_type": "markdown", "metadata": { "id": "DdNpMqtEvs3G" }, "source": [ "## Data files List" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IdBpL-HUHLBW" }, "outputs": [], "source": [ "dataset_addr = \"/home/ubuntu/DATA\"\n", "\n", "#==========================================================================\n", "\n", "filez = list()\n", "for (dirpath, dirnames, filenames) in os.walk(dataset_addr):\n", " filez += [os.path.join(dirpath, file) for file in 
{ "cell_type": "markdown", "metadata": { "id": "DdNpMqtEvs3G" }, "source": [ "## Data files List" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "IdBpL-HUHLBW" }, "outputs": [], "source": [
 "dataset_addr = \"/home/ubuntu/DATA\"\n",
 "\n",
 "#==========================================================================\n",
 "\n",
 "filez = []\n",
 "for (dirpath, dirnames, filenames) in os.walk(dataset_addr):\n",
 "    filez += [os.path.join(dirpath, file) for file in filenames if file.endswith('.pickle')]\n",
 "print('=' * 70)\n",
 "\n",
 "random.shuffle(filez)\n",
 "\n",
 "print('Found', len(filez), 'data files')\n",
 "print('=' * 70)"
] },
{ "cell_type": "markdown", "metadata": { "id": "cd-51e9wooMs" }, "source": [ "## Load training data files" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "SEQ_LEN = 8192\n",
 "PAD_IDX = 18819 # Model pad index\n",
 "\n",
 "#==========================================================================\n",
 "\n",
 "print('=' * 70)\n",
 "print('Loading data files...')\n",
 "print('Please wait...')\n",
 "print('=' * 70)\n",
 "\n",
 "train_data = set()\n",
 "\n",
 "chunks_counter = 0\n",
 "\n",
 "gc.disable()\n",
 "\n",
 "for lfa in tqdm.tqdm(filez):\n",
 "\n",
 "    train_d = pickle.load(open(lfa, 'rb'))\n",
 "\n",
 "    for t in train_d:\n",
 "\n",
 "        if 0 <= min(t) and max(t) < PAD_IDX: # final data integrity check\n",
 "            train_data.add(tuple(t))\n",
 "            chunks_counter += 1\n",
 "\n",
 "        else:\n",
 "            print('Bad data!!!')\n",
 "\n",
 "gc.enable()\n",
 "gc.collect()\n",
 "\n",
 "train_data = list(train_data)\n",
 "\n",
 "#==========================================================================\n",
 "\n",
 "print('Done!')\n",
 "print('=' * 70)\n",
 "print('Total number of main chunks:', chunks_counter)\n",
 "print('All data is good:', len(max(train_data, key=len)) == len(min(train_data, key=len)))\n",
 "print('=' * 70)\n",
 "print('Randomizing train data...')\n",
 "random.shuffle(train_data)\n",
 "print('Done!')\n",
 "print('=' * 70)\n",
 "print('Total length of train data:', len(train_data))\n",
 "print('=' * 70)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "nAz4jyEaWslK" }, "outputs": [], "source": [ "len(train_data[0])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "KrlnuPGyWslK" }, "outputs": [], "source": [ "train_data[0][:15]" ] },
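{ "cell_type": "markdown", "metadata": {}, "source": [ "Optional sanity check (a sketch, assuming train_data is loaded by the cells above): report basic chunk-length statistics and confirm that every token fits under the model's pad index." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Hedged sanity check over the loaded chunks (illustrative, not part of the original pipeline)\n",
 "\n",
 "lens = [len(t) for t in train_data]\n",
 "\n",
 "print('Chunks:', len(lens))\n",
 "print('Min/Max chunk length:', min(lens), '/', max(lens))\n",
 "print('Max token id:', max(max(t) for t in train_data), '(must be <', PAD_IDX, ')')"
] },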
{ "cell_type": "markdown", "metadata": { "id": "VhZqBvqVl5cn" }, "source": [ "# Setup model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mfwp06xzzPZ5" }, "outputs": [], "source": [
 "# Setup model\n",
 "\n",
 "# constants\n",
 "\n",
 "VALIDATE_EVERY = 100\n",
 "SAVE_EVERY = 500\n",
 "GENERATE_EVERY = 250\n",
 "GENERATE_LENGTH = 512\n",
 "PRINT_STATS_EVERY = 10\n",
 "\n",
 "NUM_EPOCHS = 5\n",
 "\n",
 "BATCH_SIZE = 9\n",
 "GRADIENT_ACCUMULATE_EVERY = 8\n",
 "\n",
 "LEARNING_RATE = 1e-4\n",
 "GRAD_CLIP = 1.0\n",
 "\n",
 "# instantiate the model\n",
 "\n",
 "model = TransformerWrapper(\n",
 "    num_tokens = PAD_IDX+1,\n",
 "    max_seq_len = SEQ_LEN,\n",
 "    attn_layers = Decoder(dim = 2048,\n",
 "                          depth = 8,\n",
 "                          heads = 32,\n",
 "                          rotary_pos_emb = True,\n",
 "                          attn_flash = True\n",
 "                          )\n",
 "    )\n",
 "\n",
 "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
 "\n",
 "model.cuda()\n",
 "\n",
 "print('Done!')\n",
 "\n",
 "summary(model)\n",
 "\n",
 "# Dataloader\n",
 "\n",
 "def get_train_data_batch(tdata, index, seq_len, batch_size, pad_idx):\n",
 "\n",
 "    batch = tdata[(index*batch_size):(index*batch_size)+batch_size]\n",
 "\n",
 "    padded_batch = []\n",
 "\n",
 "    for ba in batch:\n",
 "\n",
 "        # Truncate to seq_len+1 tokens, then right-pad with pad_idx\n",
 "        ba = list(ba)[:(seq_len+1)]\n",
 "        ba += [pad_idx] * ((seq_len+1) - len(ba))\n",
 "\n",
 "        padded_batch.append(ba)\n",
 "\n",
 "    return torch.LongTensor(padded_batch).cuda()\n",
 "\n",
 "# precision/optimizer/scaler\n",
 "\n",
 "dtype = torch.bfloat16\n",
 "\n",
 "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)\n",
 "\n",
 "optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
 "\n",
 "# GradScaler is only required for float16; with bfloat16 it can be disabled\n",
 "scaler = torch.amp.GradScaler('cuda', enabled=(dtype == torch.float16))"
] },
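{ "cell_type": "markdown", "metadata": {}, "source": [ "The effective batch size per optimizer step is BATCH_SIZE x GRADIENT_ACCUMULATE_EVERY = 9 x 8 = 72 sequences. The cell below is a minimal smoke test (a sketch, assuming train_data is loaded): it runs one forward pass on a single batch to surface shape or memory problems before committing to the full run." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Hedged smoke test: one forward pass on one batch (illustrative only)\n",
 "\n",
 "model.eval()\n",
 "\n",
 "with torch.no_grad():\n",
 "    with ctx:\n",
 "        test_loss, test_acc = model(get_train_data_batch(train_data, 0, SEQ_LEN, BATCH_SIZE, PAD_IDX))\n",
 "\n",
 "print('Smoke-test loss/acc:', test_loss.item(), '/', test_acc.item())\n",
 "\n",
 "model.train()"
] },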
= []\n", "\n", " time = 0\n", " dur = 1\n", " vel = 90\n", " pitch = 60\n", " channel = 0\n", " patch = 0\n", " \n", " patches = [-1] * 16\n", " \n", " channels = [0] * 16\n", " channels[9] = 1\n", " \n", " for ss in song:\n", " \n", " if 0 <= ss < 256:\n", " \n", " time += ss * 16\n", " \n", " if 256 <= ss < 16768:\n", " \n", " patch = (ss-256) // 128\n", " \n", " if patch < 128:\n", " \n", " if patch not in patches:\n", " if 0 in channels:\n", " cha = channels.index(0)\n", " channels[cha] = 1\n", " else:\n", " cha = 15\n", " \n", " patches[cha] = patch\n", " channel = patches.index(patch)\n", " else:\n", " channel = patches.index(patch)\n", " \n", " if patch == 128:\n", " channel = 9\n", " \n", " pitch = (ss-256) % 128\n", " \n", " \n", " if 16768 <= ss < 18816:\n", " \n", " dur = ((ss-16768) // 8) * 16\n", " vel = (((ss-16768) % 8)+1) * 15\n", " \n", " song_f.append(['note', time, dur, channel, pitch, vel ])\n", "\n", " patches = [0 if x==-1 else x for x in patches]\n", "\n", " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n", " output_signature = 'Orpheus Music Transformer',\n", " output_file_name = '/home/ubuntu/Orpheus-Music-Transformer-Composition',\n", " track_name='Project Los Angeles',\n", " list_of_MIDI_patches=patches\n", " )\n", "\n", " print('Done!')\n", "\n", " model.train()\n", "\n", " if i % SAVE_EVERY == 0:\n", "\n", " print('Saving model progress. Please wait...')\n", " print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')\n", "\n", " fname = '/home/ubuntu/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n", "\n", " torch.save(model.state_dict(), fname)\n", "\n", " torch.save(optim.state_dict(), '/home/ubuntu/optimizer.pth')\n", "\n", " torch.save(scaler.state_dict(), '/home/ubuntu/scaler.pth')\n", "\n", " data = [train_losses, train_accs, val_losses, val_accs]\n", "\n", " TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accs')\n", "\n", " print('Done!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Resume training from checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_path = 'checkpoint.pth'\n", "optim_path = 'optimizer.pth'\n", "scaler_path = 'scaler.pth'\n", "\n", "print('Restoring optimizer...')\n", "optim.load_state_dict(torch.load(optim_path))\n", "\n", "print('Restoring scaler...')\n", "scaler.load_state_dict(torch.load(scaler_path))\n", "\n", "print('Restoring model...')\n", "model.load_state_dict(torch.load(model_path))\n", "\n", "print('Done!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "START_EPOCHS = 0\n", "\n", "train_losses, train_accs, val_losses, val_accs = TMIDIX.Tegridy_Any_Pickle_File_Reader('losses_accs.pickle')\n", "\n", "nsteps = len(train_losses)\n", "\n", "print(nsteps)\n", "print(train_losses[-1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Train the model\n", "\n", "for ep in range(START_EPOCHS, NUM_EPOCHS):\n", "\n", " print('=' * 70)\n", " print('Randomizing train data...')\n", " random.shuffle(train_data)\n", " print('=' * 70)\n", "\n", " print('=' * 70)\n", " print('Epoch #', ep+1)\n", " print('=' * 70)\n", "\n", " NUM_BATCHES = len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY\n", "\n", " 
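{ "cell_type": "markdown", "metadata": {}, "source": [ "The checkpoint paths above are placeholders. A sketch (assuming the checkpoints were written to /home/ubuntu by the training cell above) that picks the most recent model_checkpoint_*_steps_*.pth automatically:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Hedged helper: pick the newest checkpoint by modification time (illustrative)\n",
 "\n",
 "import glob\n",
 "\n",
 "ckpts = glob.glob('/home/ubuntu/model_checkpoint_*_steps_*.pth')\n",
 "\n",
 "if ckpts:\n",
 "    model_path = max(ckpts, key=os.path.getmtime)\n",
 "    print('Latest checkpoint:', model_path)\n",
 "else:\n",
 "    print('No checkpoints found')"
] },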
{ "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [
 "# Train the model\n",
 "\n",
 "for ep in range(START_EPOCHS, NUM_EPOCHS):\n",
 "\n",
 "    print('=' * 70)\n",
 "    print('Randomizing train data...')\n",
 "    random.shuffle(train_data)\n",
 "    print('=' * 70)\n",
 "    print('Epoch #', ep+1)\n",
 "    print('=' * 70)\n",
 "\n",
 "    NUM_BATCHES = len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY\n",
 "\n",
 "    model.train()\n",
 "\n",
 "    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):\n",
 "\n",
 "        optim.zero_grad()\n",
 "\n",
 "        for j in range(GRADIENT_ACCUMULATE_EVERY):\n",
 "            with ctx:\n",
 "                loss, acc = model(get_train_data_batch(train_data, (i*GRADIENT_ACCUMULATE_EVERY)+j, SEQ_LEN, BATCH_SIZE, PAD_IDX))\n",
 "                loss = loss / GRADIENT_ACCUMULATE_EVERY\n",
 "            scaler.scale(loss).backward()\n",
 "\n",
 "        if i % PRINT_STATS_EVERY == 0:\n",
 "            print(f'Training loss: {loss.item() * GRADIENT_ACCUMULATE_EVERY}')\n",
 "            print(f'Training acc: {acc.item()}')\n",
 "\n",
 "        train_losses.append(loss.item() * GRADIENT_ACCUMULATE_EVERY)\n",
 "        train_accs.append(acc.item())\n",
 "\n",
 "        scaler.unscale_(optim)\n",
 "        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
 "        scaler.step(optim)\n",
 "        scaler.update()\n",
 "\n",
 "        nsteps += 1\n",
 "\n",
 "        if i % VALIDATE_EVERY == 0:\n",
 "            model.eval()\n",
 "            with torch.no_grad():\n",
 "                with ctx:\n",
 "                    # note: the 'validation' batch is drawn from the training data\n",
 "                    val_loss, val_acc = model(get_train_data_batch(train_data, i, SEQ_LEN, BATCH_SIZE, PAD_IDX))\n",
 "\n",
 "            print(f'Validation loss: {val_loss.item()}')\n",
 "            print(f'Validation acc: {val_acc.item()}')\n",
 "\n",
 "            val_losses.append(val_loss.item())\n",
 "            val_accs.append(val_acc.item())\n",
 "\n",
 "            print('Plotting training loss graph...')\n",
 "            plt.plot(train_losses, 'b')\n",
 "            plt.show()\n",
 "            plt.close()\n",
 "            print('Done!')\n",
 "\n",
 "            print('Plotting training acc graph...')\n",
 "            plt.plot(train_accs, 'b')\n",
 "            plt.show()\n",
 "            plt.close()\n",
 "            print('Done!')\n",
 "\n",
 "            print('Plotting validation loss graph...')\n",
 "            plt.plot(val_losses, 'b')\n",
 "            plt.show()\n",
 "            plt.close()\n",
 "            print('Done!')\n",
 "\n",
 "            print('Plotting validation acc graph...')\n",
 "            plt.plot(val_accs, 'b')\n",
 "            plt.show()\n",
 "            plt.close()\n",
 "            print('Done!')\n",
 "\n",
 "            model.train()\n",
 "\n",
 "        if i % GENERATE_EVERY == 0:\n",
 "            model.eval()\n",
 "\n",
 "            inp = random.choice(get_train_data_batch(train_data, i, SEQ_LEN, BATCH_SIZE, PAD_IDX))[:GENERATE_LENGTH]\n",
 "\n",
 "            print(inp)\n",
 "\n",
 "            with ctx:\n",
 "                sample = model.generate(inp[None, ...], GENERATE_LENGTH)\n",
 "\n",
 "            print(sample)\n",
 "\n",
 "            data = sample.tolist()[0]\n",
 "\n",
 "            print('Sample INTs', data[:15])\n",
 "\n",
 "            if len(data) != 0:\n",
 "\n",
 "                song = data\n",
 "                song_f = []\n",
 "\n",
 "                time = 0\n",
 "                dur = 1\n",
 "                vel = 90\n",
 "                pitch = 60\n",
 "                channel = 0\n",
 "                patch = 0\n",
 "\n",
 "                patches = [-1] * 16\n",
 "\n",
 "                channels = [0] * 16\n",
 "                channels[9] = 1\n",
 "\n",
 "                for ss in song:\n",
 "\n",
 "                    if 0 <= ss < 256:\n",
 "                        time += ss * 16\n",
 "\n",
 "                    if 256 <= ss < 16768:\n",
 "\n",
 "                        patch = (ss-256) // 128\n",
 "\n",
 "                        if patch < 128:\n",
 "\n",
 "                            if patch not in patches:\n",
 "                                if 0 in channels:\n",
 "                                    cha = channels.index(0)\n",
 "                                    channels[cha] = 1\n",
 "                                else:\n",
 "                                    cha = 15\n",
 "\n",
 "                                patches[cha] = patch\n",
 "                                channel = patches.index(patch)\n",
 "                            else:\n",
 "                                channel = patches.index(patch)\n",
 "\n",
 "                        if patch == 128:\n",
 "                            channel = 9\n",
 "\n",
 "                        pitch = (ss-256) % 128\n",
 "\n",
 "                    if 16768 <= ss < 18816:\n",
 "\n",
 "                        dur = ((ss-16768) // 8) * 16\n",
 "                        vel = (((ss-16768) % 8)+1) * 15\n",
 "\n",
 "                        song_f.append(['note', time, dur, channel, pitch, vel])\n",
 "\n",
 "                patches = [0 if x==-1 else x for x in patches]\n",
 "\n",
 "                detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n",
 "                                                                          output_signature = 'Orpheus Music Transformer',\n",
 "                                                                          output_file_name = '/home/ubuntu/Orpheus-Music-Transformer-Composition',\n",
 "                                                                          track_name='Project Los Angeles',\n",
 "                                                                          list_of_MIDI_patches=patches\n",
 "                                                                          )\n",
 "\n",
 "            print('Done!')\n",
 "\n",
 "            model.train()\n",
 "\n",
 "        if i % SAVE_EVERY == 0:\n",
 "\n",
 "            print('Saving model progress. Please wait...')\n",
 "\n",
 "            fname = '/home/ubuntu/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n",
 "\n",
 "            print(fname)\n",
 "\n",
 "            torch.save(model.state_dict(), fname)\n",
 "            torch.save(optim.state_dict(), '/home/ubuntu/optimizer.pth')\n",
 "            torch.save(scaler.state_dict(), '/home/ubuntu/scaler.pth')\n",
 "\n",
 "            data = [train_losses, train_accs, val_losses, val_accs]\n",
 "\n",
 "            TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accs')\n",
 "\n",
 "            print('Done!')"
] },
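{ "cell_type": "markdown", "metadata": {}, "source": [ "Note that the checkpoints above restore model, optimizer, and scaler state but not the random number generators, so a resumed run reshuffles the data differently. A minimal sketch (the save path is illustrative) for capturing and restoring RNG state alongside the other checkpoints:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Hedged sketch: save/restore RNG state for reproducible resumption (illustrative)\n",
 "\n",
 "rng_state = {'python': random.getstate(),\n",
 "             'torch': torch.get_rng_state(),\n",
 "             'cuda': torch.cuda.get_rng_state()}\n",
 "\n",
 "torch.save(rng_state, '/home/ubuntu/rng_state.pth')\n",
 "\n",
 "# ...and after a restart:\n",
 "rng_state = torch.load('/home/ubuntu/rng_state.pth')\n",
 "random.setstate(rng_state['python'])\n",
 "torch.set_rng_state(rng_state['torch'])\n",
 "torch.cuda.set_rng_state(rng_state['cuda'])"
] },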
{ "cell_type": "markdown", "metadata": { "id": "wBkMH2gWl5co" }, "source": [ "# Final Save" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "gjBJnzZxWslL" }, "outputs": [], "source": [
 "print('Saving model progress. Please wait...')\n",
 "\n",
 "fname = '/home/ubuntu/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n",
 "\n",
 "print(fname)\n",
 "\n",
 "torch.save(model.state_dict(), fname)\n",
 "\n",
 "print('Done!')\n",
 "\n",
 "data = [train_losses, train_accs, val_losses, val_accs]\n",
 "\n",
 "# Same file name as the periodic saves, so the resume cells above can read it back\n",
 "TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accs')\n",
 "\n",
 "# Save training loss graph\n",
 "\n",
 "plt.plot(train_losses, 'b')\n",
 "plt.savefig('/home/ubuntu/training_loss_graph.png')\n",
 "plt.close()\n",
 "print('Done!')\n",
 "\n",
 "# Save training acc graph\n",
 "\n",
 "plt.plot(train_accs, 'b')\n",
 "plt.savefig('/home/ubuntu/training_acc_graph.png')\n",
 "plt.close()\n",
 "print('Done!')\n",
 "\n",
 "# Save validation loss graph\n",
 "\n",
 "plt.plot(val_losses, 'b')\n",
 "plt.savefig('/home/ubuntu/validation_loss_graph.png')\n",
 "plt.close()\n",
 "print('Done!')\n",
 "\n",
 "# Save validation acc graph\n",
 "\n",
 "plt.plot(val_accs, 'b')\n",
 "plt.savefig('/home/ubuntu/validation_acc_graph.png')\n",
 "plt.close()\n",
 "print('Done!')"
] },
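{ "cell_type": "markdown", "metadata": {}, "source": [ "A sketch (assuming train_losses is populated) for a moving-average view of the loss curve, which is easier to read than the raw per-step values; the window size is illustrative:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Hedged example: moving-average smoothing of the training loss (illustrative)\n",
 "\n",
 "import numpy as np\n",
 "\n",
 "window = 100\n",
 "\n",
 "if len(train_losses) >= window:\n",
 "    smoothed = np.convolve(train_losses, np.ones(window)/window, mode='valid')\n",
 "    plt.plot(smoothed, 'b')\n",
 "    plt.title('Smoothed training loss (window=' + str(window) + ')')\n",
 "    plt.show()\n",
 "    plt.close()"
] },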
"gSvqSRLaWslM" }, "outputs": [], "source": [ "SEQ_LEN = 8192\n", "PAD_IDX = 18819\n", "\n", "model = TransformerWrapper(\n", " num_tokens = PAD_IDX+1,\n", " max_seq_len = SEQ_LEN,\n", " attn_layers = Decoder(dim = 2048,\n", " depth = 8,\n", " heads = 32,\n", " rotary_pos_emb = True,\n", " attn_flash = True\n", " )\n", " )\n", "\n", "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n", "\n", "print('=' * 70)\n", "print('Loading model checkpoint...')\n", "\n", "model_path = 'Models/Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth'\n", "\n", "model.load_state_dict(torch.load(model_path))\n", "\n", "print('=' * 70)\n", "\n", "model.cuda()\n", "model.eval()\n", "\n", "print('Done!')\n", "\n", "summary(model)\n", "\n", "dtype = torch.bfloat16\n", "\n", "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "enHpaHxaWslM" }, "outputs": [], "source": [ "midi_file = 'Orpheus-Music-Transformer-Piano-Seed-1.mid'\n", "\n", "print('=' * 70)\n", "print('MIDI File:', midi_file)\n", "print('=' * 70)\n", "\n", "raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n", "\n", "escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)\n", "\n", "escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True)\n", "\n", "dscore = TMIDIX.delta_score_notes(escore_notes)\n", "\n", "dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])\n", "\n", "bad_chords_counts = TMIDIX.count_bad_chords_in_chordified_score(dcscore, pitches_index=3, patches_index=5)\n", "\n", "melody_chords = [18816]\n", "\n", "#=======================================================\n", "# MAIN PROCESSING CYCLE\n", "#=======================================================\n", "\n", "for i, c in enumerate(dcscore):\n", "\n", " # Outro seq\n", " # if len(dcscore)-i == 64 and len(dcscore) > 191:\n", " # melody_chords.extend([18817])\n", " \n", " # Delta start-times\n", "\n", " delta_time = c[0][0]\n", "\n", " melody_chords.append(delta_time)\n", "\n", " for e in c:\n", " \n", " #=======================================================\n", " \n", " # Durations\n", " dur = max(1, min(255, e[1]))\n", "\n", " # Patches\n", " pat = max(0, min(128, e[5]))\n", " \n", " # Pitches\n", " ptc = max(1, min(127, e[3]))\n", " \n", " # Velocities\n", " # Calculating octo-velocity\n", " \n", " vel = max(8, min(127, e[4]))\n", " velocity = round(vel / 15)-1\n", " \n", " #=======================================================\n", " # FINAL NOTE SEQ\n", " #=======================================================\n", " \n", " # Writing final note\n", " pat_ptc = (128 * pat) + ptc \n", " dur_vel = (8 * dur) + velocity\n", "\n", " melody_chords.extend([pat_ptc+256, dur_vel+16768]) # 18816\n", "\n", "print('Done!')\n", "print('=' * 70)\n", "print(len(melody_chords))\n", "print('=' * 70)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "naf65RxUXwDg" }, "outputs": [], "source": [ "model.eval()\n", "\n", "x = torch.LongTensor([0]).cuda()\n", "# x = torch.LongTensor(melody_chords).cuda()\n", "\n", "with ctx:\n", " out = model.generate(x,\n", " 700,\n", " temperature=0.9,\n", " #filter_logits_fn=top_k,\n", " #filter_kwargs={'k': 15},\n", " return_prime=True,\n", " verbose=True)\n", "\n", "y = out.tolist()\n", "\n", "print('---------------')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tlBzqWpAnZna" }, 
"outputs": [], "source": [ "# Save to MIDI\n", "\n", "data = y[0]\n", "\n", "print('Sample INTs', data[:15])\n", "\n", "if len(data) != 0:\n", "\n", " song = data\n", " song_f = []\n", "\n", " time = 0\n", " dur = 1\n", " vel = 90\n", " pitch = 60\n", " channel = 0\n", " patch = 0\n", "\n", " patches = [-1] * 16\n", "\n", " channels = [0] * 16\n", " channels[9] = 1\n", "\n", " for ss in song:\n", "\n", " if 0 <= ss < 256:\n", "\n", " time += ss * 16\n", "\n", " if 256 <= ss < 16768:\n", "\n", " patch = (ss-256) // 128\n", "\n", " if patch < 128:\n", "\n", " if patch not in patches:\n", " if 0 in channels:\n", " cha = channels.index(0)\n", " channels[cha] = 1\n", " else:\n", " cha = 15\n", "\n", " patches[cha] = patch\n", " channel = patches.index(patch)\n", " else:\n", " channel = patches.index(patch)\n", "\n", " if patch == 128:\n", " channel = 9\n", "\n", " pitch = (ss-256) % 128\n", "\n", "\n", " if 16768 <= ss < 18816:\n", "\n", " dur = ((ss-16768) // 8) * 16\n", " vel = (((ss-16768) % 8)+1) * 15\n", "\n", " song_f.append(['note', time, dur, channel, pitch, vel ])\n", "\n", "patches = [0 if x==-1 else x for x in patches]\n", "\n", "detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n", " output_signature = 'Orpheus Music Transformer',\n", " output_file_name = '/home/ubuntu/Orpheus-Music-Transformer-Composition',\n", " track_name='Project Los Angeles',\n", " list_of_MIDI_patches=patches\n", " )\n", "\n", "print('Done!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "al3TDlH7T8m7" }, "outputs": [], "source": [ "tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n", "\n", "cos_sim = metrics.pairwise_distances(\n", " tok_emb, metric='cosine'\n", ")\n", "plt.figure(figsize=(7, 7))\n", "plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n", "im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n", "plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n", "plt.xlabel(\"Position\")\n", "plt.ylabel(\"Position\")\n", "plt.tight_layout()\n", "plt.plot()\n", "plt.savefig(\"/home/ubuntu/Orpheus-Music-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")" ] }, { "cell_type": "markdown", "metadata": { "id": "z87TlDTVl5cp" }, "source": [ "# Congrats! You did it! :)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuClass": "premium", "gpuType": "T4", "private_outputs": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }