Almusawee committed on
Commit
6bf3323
·
verified ·
1 Parent(s): 9641a3c

Upload Copy of COMPLETE_modular_brain_agent_with_spikes_and_plasticity.ipynb

Browse files
Copy of COMPLETE_modular_brain_agent_with_spikes_and_plasticity.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"cells":[{"cell_type":"code","source":["# MIT License\n","#\n","# Copyright (c) 2025 ALMUSAWIY Halliru\n","#\n","# Permission is hereby granted, free of charge, to any person obtaining a copy\n","# of this software and associated documentation files (the \"Software\"), to deal\n","# in the Software without restriction, including without limitation the rights\n","# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n","# copies of the Software, and to permit persons to whom the Software is\n","# furnished to do so, subject to the following conditions:\n","#\n","# The above copyright notice and this permission notice shall be included in all\n","# copies or substantial portions of the Software.\n","#\n","# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n","# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n","# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n","# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n","# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n","# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n","# SOFTWARE.\n","\n","# === V3 Modular Brain Agent with Plasticity - Block 1 ===\n","\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import numpy as np\n","import random\n","from torch.utils.data import DataLoader, Dataset\n","from collections import deque\n","from torchvision import datasets, transforms\n","\n","# === Plastic Synapse Mechanisms ===\n","class PlasticLinear(nn.Module):\n"," def __init__(self, in_features, out_features, plasticity_type=\"hebbian\", learning_rate=0.01):\n"," super().__init__()\n"," self.in_features = in_features\n"," self.out_features = out_features\n"," self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)\n"," self.bias = nn.Parameter(torch.zeros(out_features))\n"," 
self.plasticity_type = plasticity_type\n"," self.eta = learning_rate\n"," self.trace = torch.zeros(out_features, in_features)\n"," self.register_buffer('prev_y', torch.zeros(out_features))\n","\n"," def forward(self, x):\n"," y = F.linear(x, self.weight, self.bias)\n"," if self.training:\n"," x_detached = x.detach()\n"," y_detached = y.detach()\n"," if self.plasticity_type == \"hebbian\":\n"," hebb = torch.einsum('bi,bj->ij', y_detached, x_detached) / x.size(0)\n"," self.trace = (1 - self.eta) * self.trace + self.eta * hebb\n"," with torch.no_grad():\n"," self.weight += self.trace\n"," elif self.plasticity_type == \"stdp\":\n"," dy = y_detached - self.prev_y\n"," stdp = torch.einsum('bi,bj->ij', dy, x_detached) / x.size(0)\n"," self.trace = (1 - self.eta) * self.trace + self.eta * stdp\n"," with torch.no_grad():\n"," self.weight += self.trace\n"," self.prev_y = y_detached.clone()\n"," return y\n","\n","# === Spiking Surrogate Functions and Base Neurons ===\n","class SpikeFunction(torch.autograd.Function):\n"," @staticmethod\n"," def forward(ctx, input):\n"," ctx.save_for_backward(input)\n"," return (input > 0).float()\n","\n"," @staticmethod\n"," def backward(ctx, grad_output):\n"," input, = ctx.saved_tensors\n"," return grad_output * (abs(input) < 1).float()\n","\n","spike_fn = SpikeFunction.apply\n","\n","class LIFNeuron(nn.Module):\n"," def __init__(self, tau=2.0):\n"," super().__init__()\n"," self.tau = tau\n"," self.mem = 0\n","\n"," def forward(self, x):\n"," decay = torch.exp(torch.tensor(-1.0 / self.tau))\n"," self.mem = self.mem * decay + x\n"," out = spike_fn(self.mem - 1.0)\n"," self.mem = self.mem * (1.0 - out.detach())\n"," return out\n","\n","# === Adaptive LIF Neuron ===\n","class AdaptiveLIF(nn.Module):\n"," def __init__(self, size, tau=2.0, beta=0.2):\n"," super().__init__()\n"," self.size = size\n"," self.tau = tau\n"," self.beta = beta\n"," self.mem = torch.zeros(size)\n"," self.thresh = torch.ones(size)\n","\n"," def forward(self, x):\n"," decay 
= torch.exp(torch.tensor(-1.0 / self.tau))\n"," self.mem = self.mem * decay + x\n"," out = spike_fn(self.mem - self.thresh)\n"," self.thresh = self.thresh + self.beta * out\n"," self.mem = self.mem * (1.0 - out.detach())\n"," return out\n","\n","# === Relay Layer with Attention ===\n","class RelayLayer(nn.Module):\n"," def __init__(self, dim, heads=4):\n"," super().__init__()\n"," self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)\n"," self.lif = LIFNeuron()\n","\n"," def forward(self, x):\n"," attn_out, _ = self.attn(x, x, x)\n"," return self.lif(attn_out)\n","\n","# === Working Memory ===\n","class WorkingMemory(nn.Module):\n"," def __init__(self, input_dim, hidden_dim):\n"," super().__init__()\n"," self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)\n","\n"," def forward(self, x):\n"," out, _ = self.lstm(x)\n"," return out[:, -1]\n","\n","# === Place Cell Grid ===\n","class PlaceGrid(nn.Module):\n"," def __init__(self, grid_size=10, embedding_dim=64):\n"," super().__init__()\n"," self.embedding = nn.Embedding(grid_size**2, embedding_dim)\n","\n"," def forward(self, index):\n"," return self.embedding(index)\n","\n","# === Mirror Comparator ===\n","class MirrorComparator(nn.Module):\n"," def __init__(self, dim):\n"," super().__init__()\n"," self.cos = nn.CosineSimilarity(dim=1)\n","\n"," def forward(self, x1, x2):\n"," return self.cos(x1, x2).unsqueeze(1)\n","\n","# === Neuroendocrine Module ===\n","class NeuroendocrineModulator(nn.Module):\n"," def __init__(self, input_dim, hidden_dim):\n"," super().__init__()\n"," self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)\n","\n"," def forward(self, x):\n"," out, _ = self.lstm(x)\n"," return out[:, -1]\n","\n","# === Autonomic Feedback Module ===\n","class AutonomicFeedback(nn.Module):\n"," def __init__(self, input_dim):\n"," super().__init__()\n"," self.feedback = nn.Linear(input_dim, input_dim)\n","\n"," def forward(self, x):\n"," return 
torch.tanh(self.feedback(x))\n","\n","# === Replay Buffer ===\n","class ReplayBuffer:\n"," def __init__(self, capacity=1000):\n"," self.buffer = deque(maxlen=capacity)\n","\n"," def add(self, inputs, labels, task):\n"," self.buffer.append((inputs, labels, task))\n","\n"," def sample(self, batch_size):\n"," indices = random.sample(range(len(self.buffer)), batch_size)\n"," batch = [self.buffer[i] for i in indices]\n"," inputs, labels, tasks = zip(*batch)\n"," return inputs, labels, tasks\n","\n","# === Full Modular Brain Agent with Plasticity ===\n","class ModularBrainAgent(nn.Module):\n"," def __init__(self, input_dims, hidden_dim, output_dims):\n"," super().__init__()\n"," self.vision_encoder = nn.Linear(input_dims['vision'], hidden_dim)\n"," self.language_encoder = nn.Linear(input_dims['language'], hidden_dim)\n"," self.numeric_encoder = nn.Linear(input_dims['numeric'], hidden_dim)\n","\n"," # Plastic synapses (Hebbian and STDP)\n"," self.connect_sensory_to_relay = PlasticLinear(hidden_dim * 3, hidden_dim, plasticity_type='hebbian')\n"," self.relay_layer = RelayLayer(hidden_dim)\n"," self.connect_relay_to_inter = PlasticLinear(hidden_dim, hidden_dim, plasticity_type='stdp')\n","\n"," self.interneuron = AdaptiveLIF(hidden_dim)\n"," self.memory = WorkingMemory(hidden_dim, hidden_dim)\n"," self.place = PlaceGrid(grid_size=10, embedding_dim=hidden_dim)\n"," self.comparator = MirrorComparator(hidden_dim)\n"," self.emotion = NeuroendocrineModulator(hidden_dim, hidden_dim)\n"," self.feedback = AutonomicFeedback(hidden_dim)\n","\n"," self.task_heads = nn.ModuleDict({\n"," task: nn.Linear(hidden_dim, out_dim)\n"," for task, out_dim in output_dims.items()\n"," })\n","\n"," self.replay = ReplayBuffer()\n","\n"," def forward(self, inputs, task, position_idx=None):\n"," v = self.vision_encoder(inputs['vision'])\n"," l = self.language_encoder(inputs['language'])\n"," n = self.numeric_encoder(inputs['numeric'])\n","\n"," sensory_cat = torch.cat([v, l, n], dim=-1)\n"," z = 
self.connect_sensory_to_relay(sensory_cat)\n","\n"," z = self.relay_layer(z.unsqueeze(1)).squeeze(1)\n"," z = self.connect_relay_to_inter(z)\n"," z = self.interneuron(z)\n","\n"," m = self.memory(z.unsqueeze(1))\n"," p = self.place(position_idx if position_idx is not None else torch.tensor([0]))\n"," e = self.emotion(z.unsqueeze(1))\n"," f = self.feedback(z)\n","\n"," combined = z + m + p + e + f\n"," out = self.task_heads[task](combined)\n"," return out\n","\n"," def remember(self, inputs, labels, task):\n"," self.replay.add(inputs, labels, task)\n","\n","# === Main Test Block ===\n","if __name__ == \"__main__\":\n"," input_dims = {'vision': 32, 'language': 16, 'numeric': 8}\n"," output_dims = {'classification': 5, 'regression': 1, 'binary': 1}\n"," agent = ModularBrainAgent(input_dims, hidden_dim=64, output_dims=output_dims)\n","\n"," tasks = list(output_dims.keys())\n","\n"," for step in range(250):\n"," task = random.choice(tasks)\n"," inputs = {\n"," 'vision': torch.randn(1, 32),\n"," 'language': torch.randn(1, 16),\n"," 'numeric': torch.randn(1, 8)\n"," }\n"," labels = torch.randint(0, output_dims[task], (1,)) if task == 'classification' else torch.randn(1, output_dims[task])\n"," output = agent(inputs, task)\n"," loss = F.cross_entropy(output, labels) if task == 'classification' else F.mse_loss(output, labels)\n"," print(f\"Step {step:02d} | Task: {task:13s} | Loss: {loss.item():.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ziYXMGMEBPzD","executionInfo":{"status":"ok","timestamp":1750795814513,"user_tz":-60,"elapsed":1203,"user":{"displayName":"Aliyu Lawan Halliru","userId":"08436070427613420807"}},"outputId":"55134f70-d1d5-41a1-9dea-4814a365b4ee","collapsed":true},"id":"ziYXMGMEBPzD","execution_count":57,"outputs":[{"output_type":"stream","name":"stdout","text":["Step 00 | Task: regression | Loss: 0.1261\n","Step 01 | Task: binary | Loss: 0.0077\n","Step 02 | Task: classification | Loss: 2.2710\n","Step 03 | Task: classification 
| Loss: 2.6086\n","Step 04 | Task: regression | Loss: 0.9435\n","Step 05 | Task: regression | Loss: 1.0335\n","Step 06 | Task: classification | Loss: 1.3449\n","Step 07 | Task: classification | Loss: 2.6086\n","Step 08 | Task: binary | Loss: 4.1092\n","Step 09 | Task: regression | Loss: 0.1559\n","Step 10 | Task: regression | Loss: 6.2626\n","Step 11 | Task: regression | Loss: 0.1366\n","Step 12 | Task: regression | Loss: 2.2743\n","Step 13 | Task: binary | Loss: 0.7352\n","Step 14 | Task: classification | Loss: 1.7174\n","Step 15 | Task: regression | Loss: 4.1072\n","Step 16 | Task: regression | Loss: 0.9142\n","Step 17 | Task: classification | Loss: 1.6857\n","Step 18 | Task: classification | Loss: 1.0320\n","Step 19 | Task: binary | Loss: 0.0361\n","Step 20 | Task: binary | Loss: 2.0402\n","Step 21 | Task: regression | Loss: 3.3533\n","Step 22 | Task: classification | Loss: 1.3963\n","Step 23 | Task: classification | Loss: 2.2710\n","Step 24 | Task: binary | Loss: 2.1556\n","Step 25 | Task: classification | Loss: 1.2474\n","Step 26 | Task: binary | Loss: 0.0622\n","Step 27 | Task: binary | Loss: 1.7286\n","Step 28 | Task: regression | Loss: 0.1215\n","Step 29 | Task: classification | Loss: 0.9833\n","Step 30 | Task: binary | Loss: 6.0931\n","Step 31 | Task: binary | Loss: 0.1991\n","Step 32 | Task: classification | Loss: 2.2448\n","Step 33 | Task: binary | Loss: 0.0533\n","Step 34 | Task: binary | Loss: 2.5611\n","Step 35 | Task: classification | Loss: 0.7496\n","Step 36 | Task: binary | Loss: 6.7756\n","Step 37 | Task: binary | Loss: 0.0002\n","Step 38 | Task: classification | Loss: 1.7373\n","Step 39 | Task: classification | Loss: 2.4029\n","Step 40 | Task: binary | Loss: 0.3080\n","Step 41 | Task: classification | Loss: 1.1328\n","Step 42 | Task: classification | Loss: 1.0772\n","Step 43 | Task: binary | Loss: 1.8347\n","Step 44 | Task: regression | Loss: 4.5785\n","Step 45 | Task: classification | Loss: 2.7342\n","Step 46 | Task: classification | Loss: 
1.4711\n","Step 47 | Task: classification | Loss: 2.7029\n","Step 48 | Task: classification | Loss: 2.7029\n","Step 49 | Task: classification | Loss: 0.5279\n","Step 50 | Task: regression | Loss: 3.5593\n","Step 51 | Task: binary | Loss: 0.4833\n","Step 52 | Task: classification | Loss: 2.0750\n","Step 53 | Task: regression | Loss: 0.1442\n","Step 54 | Task: classification | Loss: 1.2820\n","Step 55 | Task: regression | Loss: 2.0941\n","Step 56 | Task: regression | Loss: 0.6135\n","Step 57 | Task: regression | Loss: 0.1448\n","Step 58 | Task: binary | Loss: 0.9394\n","Step 59 | Task: regression | Loss: 0.1174\n","Step 60 | Task: classification | Loss: 1.6808\n","Step 61 | Task: binary | Loss: 2.0150\n","Step 62 | Task: classification | Loss: 2.0750\n","Step 63 | Task: binary | Loss: 0.0019\n","Step 64 | Task: binary | Loss: 0.0001\n","Step 65 | Task: regression | Loss: 1.1479\n","Step 66 | Task: classification | Loss: 1.2820\n","Step 67 | Task: regression | Loss: 0.0808\n","Step 68 | Task: classification | Loss: 1.2820\n","Step 69 | Task: regression | Loss: 0.8548\n","Step 70 | Task: regression | Loss: 3.8744\n","Step 71 | Task: regression | Loss: 3.2210\n","Step 72 | Task: binary | Loss: 0.0207\n","Step 73 | Task: regression | Loss: 7.4246\n","Step 74 | Task: classification | Loss: 1.6808\n","Step 75 | Task: regression | Loss: 3.7108\n","Step 76 | Task: regression | Loss: 0.4187\n","Step 77 | Task: binary | Loss: 0.1630\n","Step 78 | Task: regression | Loss: 2.9547\n","Step 79 | Task: binary | Loss: 1.2919\n","Step 80 | Task: binary | Loss: 0.5321\n","Step 81 | Task: binary | Loss: 0.2894\n","Step 82 | Task: classification | Loss: 0.5721\n","Step 83 | Task: binary | Loss: 0.0276\n","Step 84 | Task: classification | Loss: 1.3449\n","Step 85 | Task: classification | Loss: 1.7001\n","Step 86 | Task: regression | Loss: 0.0304\n","Step 87 | Task: binary | Loss: 0.1510\n","Step 88 | Task: classification | Loss: 2.8726\n","Step 89 | Task: classification | Loss: 
2.8726\n","Step 90 | Task: binary | Loss: 1.6419\n","Step 91 | Task: regression | Loss: 0.4132\n","Step 92 | Task: classification | Loss: 0.5721\n","Step 93 | Task: binary | Loss: 0.3058\n","Step 94 | Task: binary | Loss: 0.1348\n","Step 95 | Task: regression | Loss: 1.6620\n","Step 96 | Task: regression | Loss: 0.0004\n","Step 97 | Task: binary | Loss: 0.4877\n","Step 98 | Task: regression | Loss: 5.9194\n","Step 99 | Task: binary | Loss: 3.9332\n","Step 100 | Task: regression | Loss: 0.9100\n","Step 101 | Task: classification | Loss: 2.8726\n","Step 102 | Task: classification | Loss: 2.8726\n","Step 103 | Task: regression | Loss: 1.3869\n","Step 104 | Task: regression | Loss: 1.1899\n","Step 105 | Task: classification | Loss: 1.4711\n","Step 106 | Task: regression | Loss: 0.2585\n","Step 107 | Task: binary | Loss: 0.1365\n","Step 108 | Task: classification | Loss: 1.4711\n","Step 109 | Task: binary | Loss: 1.5944\n","Step 110 | Task: binary | Loss: 4.2947\n","Step 111 | Task: classification | Loss: 2.8726\n","Step 112 | Task: classification | Loss: 1.4711\n","Step 113 | Task: regression | Loss: 1.2115\n","Step 114 | Task: classification | Loss: 0.5721\n","Step 115 | Task: regression | Loss: 0.8073\n","Step 116 | Task: classification | Loss: 0.5721\n","Step 117 | Task: regression | Loss: 2.2396\n","Step 118 | Task: classification | Loss: 2.4615\n","Step 119 | Task: classification | Loss: 1.4711\n","Step 120 | Task: binary | Loss: 0.0218\n","Step 121 | Task: regression | Loss: 1.2469\n","Step 122 | Task: regression | Loss: 1.4394\n","Step 123 | Task: binary | Loss: 0.8655\n","Step 124 | Task: classification | Loss: 2.8726\n","Step 125 | Task: classification | Loss: 2.7469\n","Step 126 | Task: binary | Loss: 1.0146\n","Step 127 | Task: classification | Loss: 2.8726\n","Step 128 | Task: regression | Loss: 0.6096\n","Step 129 | Task: regression | Loss: 0.6803\n","Step 130 | Task: binary | Loss: 0.4864\n","Step 131 | Task: binary | Loss: 0.2572\n","Step 132 | Task: 
classification | Loss: 2.7469\n","Step 133 | Task: binary | Loss: 0.0000\n","Step 134 | Task: regression | Loss: 2.4716\n","Step 135 | Task: classification | Loss: 2.4615\n","Step 136 | Task: classification | Loss: 0.5721\n","Step 137 | Task: regression | Loss: 0.0945\n","Step 138 | Task: regression | Loss: 0.0004\n","Step 139 | Task: regression | Loss: 1.0718\n","Step 140 | Task: binary | Loss: 0.3439\n","Step 141 | Task: classification | Loss: 1.4711\n","Step 142 | Task: regression | Loss: 0.3230\n","Step 143 | Task: regression | Loss: 1.1078\n","Step 144 | Task: binary | Loss: 0.9522\n","Step 145 | Task: regression | Loss: 0.0215\n","Step 146 | Task: regression | Loss: 1.1291\n","Step 147 | Task: classification | Loss: 2.8726\n","Step 148 | Task: binary | Loss: 0.0601\n","Step 149 | Task: classification | Loss: 2.4615\n","Step 150 | Task: regression | Loss: 5.1299\n","Step 151 | Task: classification | Loss: 2.7469\n","Step 152 | Task: regression | Loss: 2.8519\n","Step 153 | Task: binary | Loss: 1.3090\n","Step 154 | Task: regression | Loss: 0.5354\n","Step 155 | Task: regression | Loss: 1.1876\n","Step 156 | Task: regression | Loss: 3.8182\n","Step 157 | Task: binary | Loss: 1.9869\n","Step 158 | Task: regression | Loss: 0.2135\n","Step 159 | Task: classification | Loss: 2.4615\n","Step 160 | Task: binary | Loss: 0.1252\n","Step 161 | Task: classification | Loss: 0.5721\n","Step 162 | Task: binary | Loss: 1.4540\n","Step 163 | Task: binary | Loss: 0.6229\n","Step 164 | Task: classification | Loss: 1.4711\n","Step 165 | Task: classification | Loss: 2.8726\n","Step 166 | Task: binary | Loss: 0.0676\n","Step 167 | Task: classification | Loss: 2.7469\n","Step 168 | Task: binary | Loss: 0.5462\n","Step 169 | Task: binary | Loss: 0.9286\n","Step 170 | Task: regression | Loss: 0.1001\n","Step 171 | Task: regression | Loss: 0.9330\n","Step 172 | Task: regression | Loss: 4.4834\n","Step 173 | Task: classification | Loss: 2.7469\n","Step 174 | Task: regression | Loss: 
0.2288\n","Step 175 | Task: regression | Loss: 0.3698\n","Step 176 | Task: binary | Loss: 1.2039\n","Step 177 | Task: regression | Loss: 4.7919\n","Step 178 | Task: classification | Loss: 0.9678\n","Step 179 | Task: binary | Loss: 1.0903\n","Step 180 | Task: classification | Loss: 1.4711\n","Step 181 | Task: classification | Loss: 0.5721\n","Step 182 | Task: regression | Loss: 0.0162\n","Step 183 | Task: regression | Loss: 0.0749\n","Step 184 | Task: regression | Loss: 4.2979\n","Step 185 | Task: classification | Loss: 1.3449\n","Step 186 | Task: binary | Loss: 0.7585\n","Step 187 | Task: classification | Loss: 2.8726\n","Step 188 | Task: classification | Loss: 0.5721\n","Step 189 | Task: regression | Loss: 0.0401\n","Step 190 | Task: binary | Loss: 0.5200\n","Step 191 | Task: binary | Loss: 1.2920\n","Step 192 | Task: classification | Loss: 1.4711\n","Step 193 | Task: classification | Loss: 2.8726\n","Step 194 | Task: binary | Loss: 4.9661\n","Step 195 | Task: regression | Loss: 1.6592\n","Step 196 | Task: binary | Loss: 1.3990\n","Step 197 | Task: classification | Loss: 2.8726\n","Step 198 | Task: classification | Loss: 0.5721\n","Step 199 | Task: regression | Loss: 0.0019\n","Step 200 | Task: regression | Loss: 0.0853\n","Step 201 | Task: regression | Loss: 4.8802\n","Step 202 | Task: binary | Loss: 0.3916\n","Step 203 | Task: classification | Loss: 2.8726\n","Step 204 | Task: regression | Loss: 0.5499\n","Step 205 | Task: binary | Loss: 0.0333\n","Step 206 | Task: classification | Loss: 1.7001\n","Step 207 | Task: regression | Loss: 0.2609\n","Step 208 | Task: regression | Loss: 3.4729\n","Step 209 | Task: classification | Loss: 2.7469\n","Step 210 | Task: classification | Loss: 2.4615\n","Step 211 | Task: regression | Loss: 0.9893\n","Step 212 | Task: classification | Loss: 2.4615\n","Step 213 | Task: classification | Loss: 1.3449\n","Step 214 | Task: binary | Loss: 0.5581\n","Step 215 | Task: classification | Loss: 1.3449\n","Step 216 | Task: regression | 
# Evaluation loop with per-task metrics (runs in the same session as the
# cell above: relies on `agent`, `tasks`, and `output_dims` defined there).
for step in range(250):  # Increase to 100 or more for better plasticity
    task = random.choice(tasks)
    inputs = {
        'vision': torch.randn(1, 32),
        'language': torch.randn(1, 16),
        'numeric': torch.randn(1, 8)
    }
    # Per-task targets. The original drew Gaussian targets for the 'binary'
    # task and fed them to binary_cross_entropy_with_logits, which expects
    # targets in [0, 1]; `labels.bool()` of a Gaussian is also True for any
    # nonzero value, which is why the logged binary accuracy was ~always
    # 0.00. Draw proper 0/1 labels instead.
    if task == 'classification':
        labels = torch.randint(0, output_dims[task], (1,))
    elif task == 'binary':
        labels = torch.randint(0, 2, (1, 1)).float()
    else:  # regression
        labels = torch.randn(1, output_dims[task])

    output = agent(inputs, task)

    if task == 'classification':
        loss = F.cross_entropy(output, labels)
        pred = output.argmax(dim=1)
        acc = (pred == labels).float().mean().item()
        metrics = f"acc: {acc:.2f}"
    elif task == 'binary':
        loss = F.binary_cross_entropy_with_logits(output, labels)
        pred = torch.sigmoid(output) > 0.5
        acc = (pred == labels.bool()).float().mean().item()
        metrics = f"acc: {acc:.2f}"
    else:  # regression
        loss = F.mse_loss(output, labels)
        mae = F.l1_loss(output, labels).item()
        metrics = f"mae: {mae:.2f}"

    print(f"Step {step:02d} | Task: {task:13s} | Loss: {loss.item():.4f} | {metrics}")
17 | Task: binary | Loss: 0.7274 | acc: 0.00\n","Step 18 | Task: binary | Loss: 0.3592 | acc: 0.00\n","Step 19 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 20 | Task: binary | Loss: 0.6492 | acc: 0.00\n","Step 21 | Task: regression | Loss: 0.8288 | mae: 0.91\n","Step 22 | Task: regression | Loss: 1.1962 | mae: 1.09\n","Step 23 | Task: regression | Loss: 0.8713 | mae: 0.93\n","Step 24 | Task: binary | Loss: 0.6145 | acc: 0.00\n","Step 25 | Task: binary | Loss: 0.6797 | acc: 0.00\n","Step 26 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 27 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 28 | Task: binary | Loss: 0.4494 | acc: 0.00\n","Step 29 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 30 | Task: binary | Loss: 0.6192 | acc: 0.00\n","Step 31 | Task: binary | Loss: 0.7433 | acc: 0.00\n","Step 32 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 33 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 34 | Task: binary | Loss: 0.5510 | acc: 0.00\n","Step 35 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 36 | Task: binary | Loss: 0.7590 | acc: 0.00\n","Step 37 | Task: regression | Loss: 2.1471 | mae: 1.47\n","Step 38 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 39 | Task: binary | Loss: 0.7631 | acc: 0.00\n","Step 40 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 41 | Task: binary | Loss: 0.6592 | acc: 0.00\n","Step 42 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 43 | Task: binary | Loss: 0.5732 | acc: 0.00\n","Step 44 | Task: binary | Loss: 0.6986 | acc: 0.00\n","Step 45 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 46 | Task: regression | Loss: 0.0241 | mae: 0.16\n","Step 47 | Task: regression | Loss: 0.2965 | mae: 0.54\n","Step 48 | Task: binary | Loss: 0.4164 | acc: 0.00\n","Step 49 | Task: binary | Loss: 0.5815 | acc: 0.00\n","Step 50 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 51 | Task: classification | Loss: 1.7001 | 
acc: 0.00\n","Step 52 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 53 | Task: binary | Loss: 0.5822 | acc: 0.00\n","Step 54 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 55 | Task: regression | Loss: 0.1614 | mae: 0.40\n","Step 56 | Task: regression | Loss: 1.6873 | mae: 1.30\n","Step 57 | Task: regression | Loss: 1.9702 | mae: 1.40\n","Step 58 | Task: binary | Loss: 0.6774 | acc: 0.00\n","Step 59 | Task: regression | Loss: 0.7601 | mae: 0.87\n","Step 60 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 61 | Task: binary | Loss: 0.4534 | acc: 0.00\n","Step 62 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 63 | Task: binary | Loss: 0.6341 | acc: 0.00\n","Step 64 | Task: binary | Loss: 0.5761 | acc: 0.00\n","Step 65 | Task: regression | Loss: 0.0252 | mae: 0.16\n","Step 66 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 67 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 68 | Task: regression | Loss: 0.1416 | mae: 0.38\n","Step 69 | Task: regression | Loss: 1.9614 | mae: 1.40\n","Step 70 | Task: binary | Loss: 0.7310 | acc: 0.00\n","Step 71 | Task: binary | Loss: 0.5481 | acc: 0.00\n","Step 72 | Task: regression | Loss: 1.8252 | mae: 1.35\n","Step 73 | Task: binary | Loss: 0.5886 | acc: 0.00\n","Step 74 | Task: binary | Loss: 0.5808 | acc: 0.00\n","Step 75 | Task: regression | Loss: 0.5157 | mae: 0.72\n","Step 76 | Task: regression | Loss: 1.2680 | mae: 1.13\n","Step 77 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 78 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 79 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 80 | Task: binary | Loss: 0.5873 | acc: 0.00\n","Step 81 | Task: regression | Loss: 0.1986 | mae: 0.45\n","Step 82 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 83 | Task: regression | Loss: 0.0000 | mae: 0.01\n","Step 84 | Task: binary | Loss: 0.8143 | acc: 0.00\n","Step 85 | Task: binary | Loss: 0.6048 | acc: 0.00\n","Step 86 | Task: 
classification | Loss: 2.2710 | acc: 0.00\n","Step 87 | Task: binary | Loss: 0.7519 | acc: 0.00\n","Step 88 | Task: binary | Loss: 0.7430 | acc: 0.00\n","Step 89 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 90 | Task: regression | Loss: 0.2550 | mae: 0.50\n","Step 91 | Task: regression | Loss: 6.5722 | mae: 2.56\n","Step 92 | Task: binary | Loss: 0.5408 | acc: 0.00\n","Step 93 | Task: regression | Loss: 0.0581 | mae: 0.24\n","Step 94 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 95 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 96 | Task: regression | Loss: 8.9453 | mae: 2.99\n","Step 97 | Task: regression | Loss: 1.1854 | mae: 1.09\n","Step 98 | Task: regression | Loss: 5.0669 | mae: 2.25\n","Step 99 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 100 | Task: regression | Loss: 0.0519 | mae: 0.23\n","Step 101 | Task: regression | Loss: 6.2753 | mae: 2.51\n","Step 102 | Task: regression | Loss: 5.4473 | mae: 2.33\n","Step 103 | Task: regression | Loss: 0.0252 | mae: 0.16\n","Step 104 | Task: binary | Loss: 0.7823 | acc: 0.00\n","Step 105 | Task: binary | Loss: 0.7246 | acc: 0.00\n","Step 106 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 107 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 108 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 109 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 110 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 111 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 112 | Task: regression | Loss: 0.1359 | mae: 0.37\n","Step 113 | Task: binary | Loss: 0.6517 | acc: 0.00\n","Step 114 | Task: regression | Loss: 1.4003 | mae: 1.18\n","Step 115 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 116 | Task: binary | Loss: 0.7363 | acc: 0.00\n","Step 117 | Task: regression | Loss: 1.3225 | mae: 1.15\n","Step 118 | Task: regression | Loss: 0.4564 | mae: 0.68\n","Step 119 | Task: binary | Loss: 0.6187 | acc: 0.00\n","Step 
120 | Task: binary | Loss: 0.5629 | acc: 0.00\n","Step 121 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 122 | Task: binary | Loss: 0.5499 | acc: 0.00\n","Step 123 | Task: regression | Loss: 0.0002 | mae: 0.01\n","Step 124 | Task: binary | Loss: 0.7817 | acc: 0.00\n","Step 125 | Task: regression | Loss: 6.6409 | mae: 2.58\n","Step 126 | Task: binary | Loss: 0.6378 | acc: 0.00\n","Step 127 | Task: binary | Loss: 0.7303 | acc: 0.00\n","Step 128 | Task: binary | Loss: 0.7928 | acc: 0.00\n","Step 129 | Task: regression | Loss: 1.8978 | mae: 1.38\n","Step 130 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 131 | Task: binary | Loss: 0.7529 | acc: 0.00\n","Step 132 | Task: binary | Loss: 0.6264 | acc: 0.00\n","Step 133 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 134 | Task: binary | Loss: 0.7641 | acc: 0.00\n","Step 135 | Task: regression | Loss: 0.0137 | mae: 0.12\n","Step 136 | Task: binary | Loss: 0.8169 | acc: 0.00\n","Step 137 | Task: regression | Loss: 1.4534 | mae: 1.21\n","Step 138 | Task: regression | Loss: 2.2067 | mae: 1.49\n","Step 139 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 140 | Task: regression | Loss: 2.0662 | mae: 1.44\n","Step 141 | Task: binary | Loss: 0.6768 | acc: 0.00\n","Step 142 | Task: binary | Loss: 0.6586 | acc: 0.00\n","Step 143 | Task: regression | Loss: 0.4004 | mae: 0.63\n","Step 144 | Task: regression | Loss: 0.3880 | mae: 0.62\n","Step 145 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 146 | Task: regression | Loss: 9.8429 | mae: 3.14\n","Step 147 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 148 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 149 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 150 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 151 | Task: binary | Loss: 0.6509 | acc: 0.00\n","Step 152 | Task: binary | Loss: 0.7141 | acc: 0.00\n","Step 153 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 154 | 
Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 155 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 156 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 157 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 158 | Task: binary | Loss: 0.7384 | acc: 0.00\n","Step 159 | Task: regression | Loss: 0.0036 | mae: 0.06\n","Step 160 | Task: binary | Loss: 0.5485 | acc: 0.00\n","Step 161 | Task: binary | Loss: 0.6691 | acc: 0.00\n","Step 162 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 163 | Task: binary | Loss: 0.5356 | acc: 0.00\n","Step 164 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 165 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 166 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 167 | Task: binary | Loss: 0.5236 | acc: 0.00\n","Step 168 | Task: regression | Loss: 0.8798 | mae: 0.94\n","Step 169 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 170 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 171 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 172 | Task: regression | Loss: 2.1451 | mae: 1.46\n","Step 173 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 174 | Task: regression | Loss: 0.0062 | mae: 0.08\n","Step 175 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 176 | Task: binary | Loss: 0.6449 | acc: 0.00\n","Step 177 | Task: regression | Loss: 0.0020 | mae: 0.04\n","Step 178 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 179 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 180 | Task: binary | Loss: 0.7070 | acc: 0.00\n","Step 181 | Task: binary | Loss: 0.6523 | acc: 0.00\n","Step 182 | Task: regression | Loss: 4.1291 | mae: 2.03\n","Step 183 | Task: binary | Loss: 0.6069 | acc: 0.00\n","Step 184 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 185 | Task: binary | Loss: 0.6218 | acc: 0.00\n","Step 186 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 187 | Task: classification | 
Loss: 1.3449 | acc: 0.00\n","Step 188 | Task: regression | Loss: 0.3541 | mae: 0.60\n","Step 189 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 190 | Task: binary | Loss: 0.7896 | acc: 0.00\n","Step 191 | Task: binary | Loss: 0.5875 | acc: 0.00\n","Step 192 | Task: regression | Loss: 0.4571 | mae: 0.68\n","Step 193 | Task: regression | Loss: 0.0112 | mae: 0.11\n","Step 194 | Task: binary | Loss: 0.6780 | acc: 0.00\n","Step 195 | Task: binary | Loss: 0.6498 | acc: 0.00\n","Step 196 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 197 | Task: binary | Loss: 0.5580 | acc: 0.00\n","Step 198 | Task: regression | Loss: 0.0986 | mae: 0.31\n","Step 199 | Task: regression | Loss: 0.2997 | mae: 0.55\n","Step 200 | Task: binary | Loss: 0.5741 | acc: 0.00\n","Step 201 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 202 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 203 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 204 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 205 | Task: regression | Loss: 2.3312 | mae: 1.53\n","Step 206 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 207 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 208 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 209 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 210 | Task: binary | Loss: 0.5065 | acc: 0.00\n","Step 211 | Task: binary | Loss: 0.5311 | acc: 0.00\n","Step 212 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 213 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 214 | Task: binary | Loss: 0.6261 | acc: 0.00\n","Step 215 | Task: binary | Loss: 0.7516 | acc: 0.00\n","Step 216 | Task: binary | Loss: 0.6156 | acc: 0.00\n","Step 217 | Task: regression | Loss: 2.5262 | mae: 1.59\n","Step 218 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 219 | Task: classification | Loss: 0.9678 | acc: 1.00\n","Step 220 | Task: regression | Loss: 1.7295 | mae: 1.32\n","Step 221 | 
Task: binary | Loss: 0.7137 | acc: 0.00\n","Step 222 | Task: binary | Loss: 0.5417 | acc: 0.00\n","Step 223 | Task: binary | Loss: 0.8578 | acc: 0.00\n","Step 224 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 225 | Task: binary | Loss: 0.8036 | acc: 0.00\n","Step 226 | Task: regression | Loss: 0.4451 | mae: 0.67\n","Step 227 | Task: regression | Loss: 0.6359 | mae: 0.80\n","Step 228 | Task: classification | Loss: 2.6086 | acc: 0.00\n","Step 229 | Task: regression | Loss: 3.0294 | mae: 1.74\n","Step 230 | Task: binary | Loss: 0.5698 | acc: 0.00\n","Step 231 | Task: regression | Loss: 0.1400 | mae: 0.37\n","Step 232 | Task: binary | Loss: 0.6625 | acc: 0.00\n","Step 233 | Task: regression | Loss: 0.8981 | mae: 0.95\n","Step 234 | Task: regression | Loss: 1.6829 | mae: 1.30\n","Step 235 | Task: binary | Loss: 0.6379 | acc: 0.00\n","Step 236 | Task: regression | Loss: 0.8514 | mae: 0.92\n","Step 237 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 238 | Task: regression | Loss: 3.9088 | mae: 1.98\n","Step 239 | Task: regression | Loss: 1.0906 | mae: 1.04\n","Step 240 | Task: binary | Loss: 0.6551 | acc: 0.00\n","Step 241 | Task: regression | Loss: 0.9411 | mae: 0.97\n","Step 242 | Task: regression | Loss: 0.1226 | mae: 0.35\n","Step 243 | Task: classification | Loss: 1.7001 | acc: 0.00\n","Step 244 | Task: binary | Loss: 0.6715 | acc: 0.00\n","Step 245 | Task: regression | Loss: 0.1092 | mae: 0.33\n","Step 246 | Task: binary | Loss: 0.5053 | acc: 0.00\n","Step 247 | Task: classification | Loss: 1.3449 | acc: 0.00\n","Step 248 | Task: classification | Loss: 2.2710 | acc: 0.00\n","Step 249 | Task: classification | Loss: 0.9678 | acc: 1.00\n"]}]}],"metadata":{"colab":{"provenance":[{"file_id":"1ozSNP2Eodi2WhmyYnlQA0y-1Pfm51UkW","timestamp":1750788847560}]},"language_info":{"name":"python"},"kernelspec":{"name":"python3","display_name":"Python 3"}},"nbformat":4,"nbformat_minor":5}