Delete COMPLETE_modular_brain_agent_with_spikes_and_plasticity.ipynb (1)
Browse files
COMPLETE_modular_brain_agent_with_spikes_and_plasticity.ipynb (1)
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"cells":[{"cell_type":"code","source":["# MIT License\n","#\n","# Copyright (c) 2025 ALMUSAWIY Halliru\n","#\n","# Permission is hereby granted, free of charge, to any person obtaining a copy\n","# of this software and associated documentation files (the \"Software\"), to deal\n","# in the Software without restriction, including without limitation the rights\n","# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n","# copies of the Software, and to permit persons to whom the Software is\n","# furnished to do so, subject to the following conditions:\n","#\n","# The above copyright notice and this permission notice shall be included in all\n","# copies or substantial portions of the Software.\n","#\n","# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n","# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n","# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n","# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n","# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n","# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n","# SOFTWARE.\n","\n","# === V3 Modular Brain Agent with Plasticity - Block 1 ===\n","\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import numpy as np\n","import random\n","from torch.utils.data import DataLoader, Dataset\n","from collections import deque\n","from torchvision import datasets, transforms\n","\n","# === Plastic Synapse Mechanisms ===\n","class PlasticLinear(nn.Module):\n"," def __init__(self, in_features, out_features, plasticity_type=\"hebbian\", learning_rate=0.01):\n"," super().__init__()\n"," self.in_features = in_features\n"," self.out_features = out_features\n"," self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)\n"," self.bias = nn.Parameter(torch.zeros(out_features))\n"," 
self.plasticity_type = plasticity_type\n"," self.eta = learning_rate\n"," self.trace = torch.zeros(out_features, in_features)\n"," self.register_buffer('prev_y', torch.zeros(out_features))\n","\n"," def forward(self, x):\n"," y = F.linear(x, self.weight, self.bias)\n"," if self.training:\n"," x_detached = x.detach()\n"," y_detached = y.detach()\n"," if self.plasticity_type == \"hebbian\":\n"," hebb = torch.einsum('bi,bj->ij', y_detached, x_detached) / x.size(0)\n"," self.trace = (1 - self.eta) * self.trace + self.eta * hebb\n"," with torch.no_grad():\n"," self.weight += self.trace\n"," elif self.plasticity_type == \"stdp\":\n"," dy = y_detached - self.prev_y\n"," stdp = torch.einsum('bi,bj->ij', dy, x_detached) / x.size(0)\n"," self.trace = (1 - self.eta) * self.trace + self.eta * stdp\n"," with torch.no_grad():\n"," self.weight += self.trace\n"," self.prev_y = y_detached.clone()\n"," return y\n","\n","# === Spiking Surrogate Functions and Base Neurons ===\n","class SpikeFunction(torch.autograd.Function):\n"," @staticmethod\n"," def forward(ctx, input):\n"," ctx.save_for_backward(input)\n"," return (input > 0).float()\n","\n"," @staticmethod\n"," def backward(ctx, grad_output):\n"," input, = ctx.saved_tensors\n"," return grad_output * (abs(input) < 1).float()\n","\n","spike_fn = SpikeFunction.apply\n","\n","class LIFNeuron(nn.Module):\n"," def __init__(self, tau=2.0):\n"," super().__init__()\n"," self.tau = tau\n"," self.mem = 0\n","\n"," def forward(self, x):\n"," decay = torch.exp(torch.tensor(-1.0 / self.tau))\n"," self.mem = self.mem * decay + x\n"," out = spike_fn(self.mem - 1.0)\n"," self.mem = self.mem * (1.0 - out.detach())\n"," return out\n","\n","# === Adaptive LIF Neuron ===\n","class AdaptiveLIF(nn.Module):\n"," def __init__(self, size, tau=2.0, beta=0.2):\n"," super().__init__()\n"," self.size = size\n"," self.tau = tau\n"," self.beta = beta\n"," self.mem = torch.zeros(size)\n"," self.thresh = torch.ones(size)\n","\n"," def forward(self, x):\n"," decay 
= torch.exp(torch.tensor(-1.0 / self.tau))\n"," self.mem = self.mem * decay + x\n"," out = spike_fn(self.mem - self.thresh)\n"," self.thresh = self.thresh + self.beta * out\n"," self.mem = self.mem * (1.0 - out.detach())\n"," return out\n","\n","# === Relay Layer with Attention ===\n","class RelayLayer(nn.Module):\n"," def __init__(self, dim, heads=4):\n"," super().__init__()\n"," self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)\n"," self.lif = LIFNeuron()\n","\n"," def forward(self, x):\n"," attn_out, _ = self.attn(x, x, x)\n"," return self.lif(attn_out)\n","\n","# === Working Memory ===\n","class WorkingMemory(nn.Module):\n"," def __init__(self, input_dim, hidden_dim):\n"," super().__init__()\n"," self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)\n","\n"," def forward(self, x):\n"," out, _ = self.lstm(x)\n"," return out[:, -1]\n","\n","# === Place Cell Grid ===\n","class PlaceGrid(nn.Module):\n"," def __init__(self, grid_size=10, embedding_dim=64):\n"," super().__init__()\n"," self.embedding = nn.Embedding(grid_size**2, embedding_dim)\n","\n"," def forward(self, index):\n"," return self.embedding(index)\n","\n","# === Mirror Comparator ===\n","class MirrorComparator(nn.Module):\n"," def __init__(self, dim):\n"," super().__init__()\n"," self.cos = nn.CosineSimilarity(dim=1)\n","\n"," def forward(self, x1, x2):\n"," return self.cos(x1, x2).unsqueeze(1)\n","\n","# === Neuroendocrine Module ===\n","class NeuroendocrineModulator(nn.Module):\n"," def __init__(self, input_dim, hidden_dim):\n"," super().__init__()\n"," self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)\n","\n"," def forward(self, x):\n"," out, _ = self.lstm(x)\n"," return out[:, -1]\n","\n","# === Autonomic Feedback Module ===\n","class AutonomicFeedback(nn.Module):\n"," def __init__(self, input_dim):\n"," super().__init__()\n"," self.feedback = nn.Linear(input_dim, input_dim)\n","\n"," def forward(self, x):\n"," return 
torch.tanh(self.feedback(x))\n","\n","# === Replay Buffer ===\n","class ReplayBuffer:\n"," def __init__(self, capacity=1000):\n"," self.buffer = deque(maxlen=capacity)\n","\n"," def add(self, inputs, labels, task):\n"," self.buffer.append((inputs, labels, task))\n","\n"," def sample(self, batch_size):\n"," indices = random.sample(range(len(self.buffer)), batch_size)\n"," batch = [self.buffer[i] for i in indices]\n"," inputs, labels, tasks = zip(*batch)\n"," return inputs, labels, tasks\n","\n","# === Full Modular Brain Agent with Plasticity ===\n","class ModularBrainAgent(nn.Module):\n"," def __init__(self, input_dims, hidden_dim, output_dims):\n"," super().__init__()\n"," self.vision_encoder = nn.Linear(input_dims['vision'], hidden_dim)\n"," self.language_encoder = nn.Linear(input_dims['language'], hidden_dim)\n"," self.numeric_encoder = nn.Linear(input_dims['numeric'], hidden_dim)\n","\n"," # Plastic synapses (Hebbian and STDP)\n"," self.connect_sensory_to_relay = PlasticLinear(hidden_dim * 3, hidden_dim, plasticity_type='hebbian')\n"," self.relay_layer = RelayLayer(hidden_dim)\n"," self.connect_relay_to_inter = PlasticLinear(hidden_dim, hidden_dim, plasticity_type='stdp')\n","\n"," self.interneuron = AdaptiveLIF(hidden_dim)\n"," self.memory = WorkingMemory(hidden_dim, hidden_dim)\n"," self.place = PlaceGrid(grid_size=10, embedding_dim=hidden_dim)\n"," self.comparator = MirrorComparator(hidden_dim)\n"," self.emotion = NeuroendocrineModulator(hidden_dim, hidden_dim)\n"," self.feedback = AutonomicFeedback(hidden_dim)\n","\n"," self.task_heads = nn.ModuleDict({\n"," task: nn.Linear(hidden_dim, out_dim)\n"," for task, out_dim in output_dims.items()\n"," })\n","\n"," self.replay = ReplayBuffer()\n","\n"," def forward(self, inputs, task, position_idx=None):\n"," v = self.vision_encoder(inputs['vision'])\n"," l = self.language_encoder(inputs['language'])\n"," n = self.numeric_encoder(inputs['numeric'])\n","\n"," sensory_cat = torch.cat([v, l, n], dim=-1)\n"," z = 
self.connect_sensory_to_relay(sensory_cat)\n","\n"," z = self.relay_layer(z.unsqueeze(1)).squeeze(1)\n"," z = self.connect_relay_to_inter(z)\n"," z = self.interneuron(z)\n","\n"," m = self.memory(z.unsqueeze(1))\n"," p = self.place(position_idx if position_idx is not None else torch.tensor([0]))\n"," e = self.emotion(z.unsqueeze(1))\n"," f = self.feedback(z)\n","\n"," combined = z + m + p + e + f\n"," out = self.task_heads[task](combined)\n"," return out\n","\n"," def remember(self, inputs, labels, task):\n"," self.replay.add(inputs, labels, task)\n","\n","# === Main Test Block ===\n","if __name__ == \"__main__\":\n"," input_dims = {'vision': 32, 'language': 16, 'numeric': 8}\n"," output_dims = {'classification': 5, 'regression': 1, 'binary': 1}\n"," agent = ModularBrainAgent(input_dims, hidden_dim=64, output_dims=output_dims)\n","\n"," tasks = list(output_dims.keys())\n","\n"," for step in range(250):\n"," task = random.choice(tasks)\n"," inputs = {\n"," 'vision': torch.randn(1, 32),\n"," 'language': torch.randn(1, 16),\n"," 'numeric': torch.randn(1, 8)\n"," }\n"," labels = torch.randint(0, output_dims[task], (1,)) if task == 'classification' else torch.randn(1, output_dims[task])\n"," output = agent(inputs, task)\n"," loss = F.cross_entropy(output, labels) if task == 'classification' else F.mse_loss(output, labels)\n"," print(f\"Step {step:02d} | Task: {task:13s} | Loss: {loss.item():.4f}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ziYXMGMEBPzD","executionInfo":{"status":"ok","timestamp":1750792622086,"user_tz":-60,"elapsed":671,"user":{"displayName":"Aliyu Lawan Halliru","userId":"08436070427613420807"}},"outputId":"be6bde14-ea3e-4845-a6fb-975894a1086a","collapsed":true},"id":"ziYXMGMEBPzD","execution_count":56,"outputs":[{"output_type":"stream","name":"stdout","text":["Step 00 | Task: regression | Loss: 0.6114\n","Step 01 | Task: classification | Loss: 1.9526\n","Step 02 | Task: binary | Loss: 0.0619\n","Step 03 | Task: regression | 
Loss: 0.4528\n","Step 04 | Task: classification | Loss: 1.5287\n","Step 05 | Task: regression | Loss: 0.1885\n","Step 06 | Task: binary | Loss: 0.8592\n","Step 07 | Task: regression | Loss: 0.0030\n","Step 08 | Task: regression | Loss: 2.2366\n","Step 09 | Task: regression | Loss: 0.1509\n","Step 10 | Task: classification | Loss: 2.3092\n","Step 11 | Task: binary | Loss: 0.9322\n","Step 12 | Task: binary | Loss: 1.2702\n","Step 13 | Task: binary | Loss: 1.4081\n","Step 14 | Task: regression | Loss: 0.0778\n","Step 15 | Task: regression | Loss: 0.1083\n","Step 16 | Task: regression | Loss: 0.4360\n","Step 17 | Task: classification | Loss: 1.5915\n","Step 18 | Task: regression | Loss: 3.2734\n","Step 19 | Task: regression | Loss: 0.0337\n","Step 20 | Task: binary | Loss: 0.0791\n","Step 21 | Task: binary | Loss: 0.2117\n","Step 22 | Task: binary | Loss: 5.6710\n","Step 23 | Task: binary | Loss: 0.0450\n","Step 24 | Task: classification | Loss: 0.9686\n","Step 25 | Task: regression | Loss: 0.1319\n","Step 26 | Task: regression | Loss: 1.3254\n","Step 27 | Task: classification | Loss: 1.4636\n","Step 28 | Task: regression | Loss: 1.0090\n","Step 29 | Task: classification | Loss: 2.1761\n","Step 30 | Task: regression | Loss: 1.5670\n","Step 31 | Task: classification | Loss: 2.2772\n","Step 32 | Task: binary | Loss: 1.6179\n","Step 33 | Task: binary | Loss: 5.2599\n","Step 34 | Task: classification | Loss: 0.6671\n","Step 35 | Task: binary | Loss: 0.3045\n","Step 36 | Task: classification | Loss: 0.9672\n","Step 37 | Task: regression | Loss: 1.4661\n","Step 38 | Task: binary | Loss: 1.9450\n","Step 39 | Task: regression | Loss: 1.5602\n","Step 40 | Task: classification | Loss: 0.9481\n","Step 41 | Task: classification | Loss: 2.6518\n","Step 42 | Task: classification | Loss: 0.9192\n","Step 43 | Task: binary | Loss: 0.7554\n","Step 44 | Task: classification | Loss: 2.6885\n","Step 45 | Task: regression | Loss: 0.9646\n","Step 46 | Task: regression | Loss: 2.5860\n","Step 
47 | Task: classification | Loss: 1.5601\n","Step 48 | Task: regression | Loss: 0.1791\n","Step 49 | Task: binary | Loss: 0.8173\n","Step 50 | Task: classification | Loss: 1.6362\n","Step 51 | Task: binary | Loss: 0.0094\n","Step 52 | Task: classification | Loss: 2.6885\n","Step 53 | Task: binary | Loss: 0.2068\n","Step 54 | Task: binary | Loss: 2.4436\n","Step 55 | Task: binary | Loss: 0.0122\n","Step 56 | Task: classification | Loss: 0.9483\n","Step 57 | Task: binary | Loss: 1.6579\n","Step 58 | Task: classification | Loss: 1.4136\n","Step 59 | Task: regression | Loss: 0.0385\n","Step 60 | Task: regression | Loss: 0.0363\n","Step 61 | Task: classification | Loss: 1.6362\n","Step 62 | Task: regression | Loss: 0.3630\n","Step 63 | Task: regression | Loss: 2.8187\n","Step 64 | Task: binary | Loss: 0.8714\n","Step 65 | Task: regression | Loss: 2.8570\n","Step 66 | Task: binary | Loss: 0.2519\n","Step 67 | Task: binary | Loss: 0.0046\n","Step 68 | Task: regression | Loss: 0.2388\n","Step 69 | Task: binary | Loss: 0.0953\n","Step 70 | Task: binary | Loss: 0.0001\n","Step 71 | Task: binary | Loss: 0.5775\n","Step 72 | Task: classification | Loss: 2.2497\n","Step 73 | Task: classification | Loss: 2.2497\n","Step 74 | Task: classification | Loss: 0.9483\n","Step 75 | Task: regression | Loss: 0.0641\n","Step 76 | Task: regression | Loss: 0.5261\n","Step 77 | Task: regression | Loss: 2.8726\n","Step 78 | Task: regression | Loss: 0.2244\n","Step 79 | Task: regression | Loss: 0.0225\n","Step 80 | Task: classification | Loss: 1.4136\n","Step 81 | Task: classification | Loss: 1.6362\n","Step 82 | Task: binary | Loss: 0.0065\n","Step 83 | Task: binary | Loss: 0.0118\n","Step 84 | Task: regression | Loss: 2.3820\n","Step 85 | Task: binary | Loss: 4.2463\n","Step 86 | Task: binary | Loss: 0.2670\n","Step 87 | Task: binary | Loss: 0.3418\n","Step 88 | Task: binary | Loss: 3.4286\n","Step 89 | Task: classification | Loss: 2.6711\n","Step 90 | Task: classification | Loss: 
0.9483\n","Step 91 | Task: binary | Loss: 0.0477\n","Step 92 | Task: regression | Loss: 2.6107\n","Step 93 | Task: classification | Loss: 1.6362\n","Step 94 | Task: regression | Loss: 0.1060\n","Step 95 | Task: binary | Loss: 0.2471\n","Step 96 | Task: regression | Loss: 0.0640\n","Step 97 | Task: regression | Loss: 0.0057\n","Step 98 | Task: classification | Loss: 1.4136\n","Step 99 | Task: classification | Loss: 2.6711\n","Step 100 | Task: classification | Loss: 1.9526\n","Step 101 | Task: binary | Loss: 1.3364\n","Step 102 | Task: regression | Loss: 0.2104\n","Step 103 | Task: binary | Loss: 0.4321\n","Step 104 | Task: classification | Loss: 1.5287\n","Step 105 | Task: regression | Loss: 0.8773\n","Step 106 | Task: regression | Loss: 1.6732\n","Step 107 | Task: binary | Loss: 0.0427\n","Step 108 | Task: regression | Loss: 2.7213\n","Step 109 | Task: classification | Loss: 2.2497\n","Step 110 | Task: binary | Loss: 0.2391\n","Step 111 | Task: regression | Loss: 0.5446\n","Step 112 | Task: classification | Loss: 2.6711\n","Step 113 | Task: classification | Loss: 1.4136\n","Step 114 | Task: regression | Loss: 0.3452\n","Step 115 | Task: binary | Loss: 0.1353\n","Step 116 | Task: regression | Loss: 0.1175\n","Step 117 | Task: regression | Loss: 0.4142\n","Step 118 | Task: binary | Loss: 0.0000\n","Step 119 | Task: classification | Loss: 1.0401\n","Step 120 | Task: binary | Loss: 0.1260\n","Step 121 | Task: classification | Loss: 1.6362\n","Step 122 | Task: regression | Loss: 0.0367\n","Step 123 | Task: regression | Loss: 0.0916\n","Step 124 | Task: regression | Loss: 1.0657\n","Step 125 | Task: regression | Loss: 1.0912\n","Step 126 | Task: regression | Loss: 0.0192\n","Step 127 | Task: classification | Loss: 2.9280\n","Step 128 | Task: classification | Loss: 2.2497\n","Step 129 | Task: binary | Loss: 0.3649\n","Step 130 | Task: regression | Loss: 0.2958\n","Step 131 | Task: regression | Loss: 0.0648\n","Step 132 | Task: binary | Loss: 0.0208\n","Step 133 | Task: 
binary | Loss: 1.8152\n","Step 134 | Task: regression | Loss: 2.8926\n","Step 135 | Task: binary | Loss: 4.0353\n","Step 136 | Task: classification | Loss: 1.6362\n","Step 137 | Task: classification | Loss: 2.2497\n","Step 138 | Task: classification | Loss: 0.9483\n","Step 139 | Task: regression | Loss: 0.1746\n","Step 140 | Task: classification | Loss: 2.2497\n","Step 141 | Task: binary | Loss: 0.0110\n","Step 142 | Task: regression | Loss: 0.2563\n","Step 143 | Task: regression | Loss: 0.6897\n","Step 144 | Task: classification | Loss: 1.6362\n","Step 145 | Task: binary | Loss: 1.9309\n","Step 146 | Task: regression | Loss: 0.0295\n","Step 147 | Task: classification | Loss: 2.6711\n","Step 148 | Task: regression | Loss: 1.0839\n","Step 149 | Task: regression | Loss: 3.2254\n","Step 150 | Task: regression | Loss: 2.1562\n","Step 151 | Task: regression | Loss: 0.8985\n","Step 152 | Task: binary | Loss: 2.8646\n","Step 153 | Task: binary | Loss: 0.6221\n","Step 154 | Task: classification | Loss: 1.4136\n","Step 155 | Task: classification | Loss: 2.6711\n","Step 156 | Task: classification | Loss: 1.6362\n","Step 157 | Task: binary | Loss: 1.2010\n","Step 158 | Task: binary | Loss: 0.0754\n","Step 159 | Task: binary | Loss: 0.6228\n","Step 160 | Task: regression | Loss: 1.4618\n","Step 161 | Task: classification | Loss: 2.6711\n","Step 162 | Task: classification | Loss: 2.9092\n","Step 163 | Task: regression | Loss: 0.4006\n","Step 164 | Task: regression | Loss: 0.1939\n","Step 165 | Task: binary | Loss: 1.1743\n","Step 166 | Task: classification | Loss: 2.6711\n","Step 167 | Task: regression | Loss: 3.2660\n","Step 168 | Task: classification | Loss: 2.6711\n","Step 169 | Task: classification | Loss: 1.6362\n","Step 170 | Task: binary | Loss: 0.2367\n","Step 171 | Task: classification | Loss: 2.6711\n","Step 172 | Task: regression | Loss: 0.3036\n","Step 173 | Task: regression | Loss: 0.3184\n","Step 174 | Task: regression | Loss: 0.6776\n","Step 175 | Task: binary | 
Loss: 0.5469\n","Step 176 | Task: binary | Loss: 1.3678\n","Step 177 | Task: classification | Loss: 2.6711\n","Step 178 | Task: binary | Loss: 0.1856\n","Step 179 | Task: binary | Loss: 0.3055\n","Step 180 | Task: classification | Loss: 2.6711\n","Step 181 | Task: classification | Loss: 1.4136\n","Step 182 | Task: classification | Loss: 2.2497\n","Step 183 | Task: classification | Loss: 2.6711\n","Step 184 | Task: regression | Loss: 0.5866\n","Step 185 | Task: binary | Loss: 0.9470\n","Step 186 | Task: binary | Loss: 0.2032\n","Step 187 | Task: binary | Loss: 1.3610\n","Step 188 | Task: classification | Loss: 1.4136\n","Step 189 | Task: binary | Loss: 0.3655\n","Step 190 | Task: classification | Loss: 1.6362\n","Step 191 | Task: regression | Loss: 0.0086\n","Step 192 | Task: regression | Loss: 0.7909\n","Step 193 | Task: regression | Loss: 0.1362\n","Step 194 | Task: binary | Loss: 1.4826\n","Step 195 | Task: classification | Loss: 1.6362\n","Step 196 | Task: binary | Loss: 0.8777\n","Step 197 | Task: binary | Loss: 2.6137\n","Step 198 | Task: regression | Loss: 0.0167\n","Step 199 | Task: binary | Loss: 0.1310\n","Step 200 | Task: regression | Loss: 1.1864\n","Step 201 | Task: regression | Loss: 0.0054\n","Step 202 | Task: binary | Loss: 0.1925\n","Step 203 | Task: binary | Loss: 0.3395\n","Step 204 | Task: classification | Loss: 1.6362\n","Step 205 | Task: classification | Loss: 1.6362\n","Step 206 | Task: binary | Loss: 0.3006\n","Step 207 | Task: classification | Loss: 2.6711\n","Step 208 | Task: regression | Loss: 0.0026\n","Step 209 | Task: binary | Loss: 0.5486\n","Step 210 | Task: regression | Loss: 0.4880\n","Step 211 | Task: classification | Loss: 1.5287\n","Step 212 | Task: binary | Loss: 0.1283\n","Step 213 | Task: regression | Loss: 0.2531\n","Step 214 | Task: classification | Loss: 0.7186\n","Step 215 | Task: binary | Loss: 1.4874\n","Step 216 | Task: binary | Loss: 1.3092\n","Step 217 | Task: classification | Loss: 0.9483\n","Step 218 | Task: binary 
# === Evaluation loop with per-task metrics ===
# Relies on `agent`, `tasks`, and `output_dims` defined in the previous cell.
for step in range(100):  # increase to 100 or more for better plasticity
    task = random.choice(tasks)
    inputs = {
        'vision': torch.randn(1, 32),
        'language': torch.randn(1, 16),
        'numeric': torch.randn(1, 8)
    }
    # Fix: binary targets must lie in {0, 1}. The original drew them from
    # torch.randn, which makes BCE-with-logits ill-defined (targets outside
    # [0, 1]) and the accuracy comparison against labels.bool() meaningless
    # (any nonzero float is True).
    if task == 'classification':
        labels = torch.randint(0, output_dims[task], (1,))
    elif task == 'binary':
        labels = torch.randint(0, 2, (1, output_dims[task])).float()
    else:  # regression
        labels = torch.randn(1, output_dims[task])

    output = agent(inputs, task)

    if task == 'classification':
        loss = F.cross_entropy(output, labels)
        pred = output.argmax(dim=1)
        acc = (pred == labels).float().mean().item()
        metrics = f"acc: {acc:.2f}"
    elif task == 'binary':
        loss = F.binary_cross_entropy_with_logits(output, labels)
        pred = torch.sigmoid(output) > 0.5
        acc = (pred == labels.bool()).float().mean().item()
        metrics = f"acc: {acc:.2f}"
    else:  # regression
        loss = F.mse_loss(output, labels)
        mae = F.l1_loss(output, labels).item()
        metrics = f"mae: {mae:.2f}"

    print(f"Step {step:02d} | Task: {task:13s} | Loss: {loss.item():.4f} | {metrics}")
regression | Loss: 3.9219 | mae: 1.98\n","Step 18 | Task: regression | Loss: 1.4991 | mae: 1.22\n","Step 19 | Task: classification | Loss: 1.6657 | acc: 0.00\n","Step 20 | Task: regression | Loss: 0.0939 | mae: 0.31\n","Step 21 | Task: classification | Loss: 1.6657 | acc: 0.00\n","Step 22 | Task: regression | Loss: 0.8141 | mae: 0.90\n","Step 23 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 24 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 25 | Task: binary | Loss: 0.7571 | acc: 1.00\n","Step 26 | Task: binary | Loss: 0.9775 | acc: 1.00\n","Step 27 | Task: binary | Loss: 0.7558 | acc: 1.00\n","Step 28 | Task: regression | Loss: 5.6336 | mae: 2.37\n","Step 29 | Task: binary | Loss: 0.8145 | acc: 1.00\n","Step 30 | Task: regression | Loss: 1.7607 | mae: 1.33\n","Step 31 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 32 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 33 | Task: binary | Loss: 0.7005 | acc: 1.00\n","Step 34 | Task: regression | Loss: 2.1473 | mae: 1.47\n","Step 35 | Task: binary | Loss: 0.6090 | acc: 1.00\n","Step 36 | Task: binary | Loss: 0.5324 | acc: 1.00\n","Step 37 | Task: regression | Loss: 1.4789 | mae: 1.22\n","Step 38 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 39 | Task: regression | Loss: 0.3726 | mae: 0.61\n","Step 40 | Task: regression | Loss: 0.9034 | mae: 0.95\n","Step 41 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 42 | Task: regression | Loss: 0.1888 | mae: 0.43\n","Step 43 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 44 | Task: binary | Loss: 0.7300 | acc: 1.00\n","Step 45 | Task: binary | Loss: 0.7485 | acc: 1.00\n","Step 46 | Task: regression | Loss: 7.2684 | mae: 2.70\n","Step 47 | Task: binary | Loss: 1.0089 | acc: 1.00\n","Step 48 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 49 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 50 | Task: regression | Loss: 1.6282 | mae: 1.28\n","Step 51 | Task: regression | Loss: 
0.7651 | mae: 0.87\n","Step 52 | Task: regression | Loss: 1.1417 | mae: 1.07\n","Step 53 | Task: classification | Loss: 2.7771 | acc: 0.00\n","Step 54 | Task: binary | Loss: 0.8191 | acc: 1.00\n","Step 55 | Task: regression | Loss: 1.7363 | mae: 1.32\n","Step 56 | Task: regression | Loss: 0.6425 | mae: 0.80\n","Step 57 | Task: regression | Loss: 4.0519 | mae: 2.01\n","Step 58 | Task: classification | Loss: 2.7771 | acc: 0.00\n","Step 59 | Task: regression | Loss: 1.7607 | mae: 1.33\n","Step 60 | Task: regression | Loss: 0.8861 | mae: 0.94\n","Step 61 | Task: binary | Loss: 0.8241 | acc: 1.00\n","Step 62 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 63 | Task: regression | Loss: 2.5003 | mae: 1.58\n","Step 64 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 65 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 66 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 67 | Task: regression | Loss: 3.5996 | mae: 1.90\n","Step 68 | Task: regression | Loss: 2.0555 | mae: 1.43\n","Step 69 | Task: binary | Loss: 0.9101 | acc: 1.00\n","Step 70 | Task: binary | Loss: 0.7864 | acc: 1.00\n","Step 71 | Task: regression | Loss: 0.8772 | mae: 0.94\n","Step 72 | Task: regression | Loss: 0.0098 | mae: 0.10\n","Step 73 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 74 | Task: regression | Loss: 0.3767 | mae: 0.61\n","Step 75 | Task: regression | Loss: 0.0136 | mae: 0.12\n","Step 76 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 77 | Task: regression | Loss: 0.4406 | mae: 0.66\n","Step 78 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 79 | Task: binary | Loss: 0.8423 | acc: 1.00\n","Step 80 | Task: binary | Loss: 0.8438 | acc: 1.00\n","Step 81 | Task: regression | Loss: 0.7293 | mae: 0.85\n","Step 82 | Task: classification | Loss: 0.7606 | acc: 1.00\n","Step 83 | Task: regression | Loss: 0.1307 | mae: 0.36\n","Step 84 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 85 | Task: binary | Loss: 0.8713 | 
acc: 1.00\n","Step 86 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 87 | Task: binary | Loss: 0.5357 | acc: 1.00\n","Step 88 | Task: classification | Loss: 1.9228 | acc: 0.00\n","Step 89 | Task: binary | Loss: 0.6586 | acc: 1.00\n","Step 90 | Task: regression | Loss: 0.7393 | mae: 0.86\n","Step 91 | Task: regression | Loss: 0.0342 | mae: 0.18\n","Step 92 | Task: binary | Loss: 0.6178 | acc: 1.00\n","Step 93 | Task: regression | Loss: 0.8341 | mae: 0.91\n","Step 94 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 95 | Task: classification | Loss: 1.6657 | acc: 0.00\n","Step 96 | Task: binary | Loss: 0.7904 | acc: 1.00\n","Step 97 | Task: binary | Loss: 0.6778 | acc: 1.00\n","Step 98 | Task: classification | Loss: 2.0013 | acc: 0.00\n","Step 99 | Task: binary | Loss: 0.7135 | acc: 1.00\n"]}]}],"metadata":{"colab":{"provenance":[{"file_id":"1ozSNP2Eodi2WhmyYnlQA0y-1Pfm51UkW","timestamp":1750788847560}]},"language_info":{"name":"python"},"kernelspec":{"name":"python3","display_name":"Python 3"}},"nbformat":4,"nbformat_minor":5}
|
|
|
|