Almusawee committed
Commit 57e9a98 · verified · 1 Parent(s): fb4bacf

Create SynCo_modular_brain_agent_with_spikes_and_plasticity.py

SynCo_modular_brain_agent_with_spikes_and_plasticity.py ADDED
@@ -0,0 +1,252 @@
# MIT License
#
# Copyright (c) 2025 ALMUSAWIY Halliru
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# === V3 Modular Brain Agent with Plasticity - Block 1 ===

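# Overview: this module defines plastic linear layers (Hebbian / STDP-style
# traces), surrogate-gradient spiking neurons (LIFNeuron, AdaptiveLIF), an
# attention-based RelayLayer, LSTM-based WorkingMemory and NeuroendocrineModulator
# modules, a PlaceGrid embedding, a MirrorComparator, an AutonomicFeedback layer,
# a ReplayBuffer, and the ModularBrainAgent that wires them together, followed by
# a small random-data smoke test.
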
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import deque

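# Plasticity notes: PlasticLinear below keeps a running trace of local activity
# correlations and folds that trace into its weights on every training-mode
# forward pass, independently of any gradient-based optimizer.
#   Hebbian rule:    trace <- (1 - eta) * trace + eta * (y^T x) / batch
#   STDP-style rule: same update, but driven by the change in postsynaptic
#                    activity (y_t - y_{t-1}) rather than by y itself.
# Illustrative usage sketch (shapes only):
#   layer = PlasticLinear(8, 4, plasticity_type="hebbian")
#   y = layer(torch.randn(2, 8))   # in train mode this also nudges layer.weight
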
# === Plastic Synapse Mechanisms ===
class PlasticLinear(nn.Module):
    def __init__(self, in_features, out_features, plasticity_type="hebbian", learning_rate=0.01):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.plasticity_type = plasticity_type
        self.eta = learning_rate
        # Register the plasticity trace and the previous postsynaptic activity as
        # buffers so they follow the module across devices and appear in state_dict.
        self.register_buffer('trace', torch.zeros(out_features, in_features))
        self.register_buffer('prev_y', torch.zeros(out_features))

    def forward(self, x):
        y = F.linear(x, self.weight, self.bias)
        if self.training:
            x_detached = x.detach()
            y_detached = y.detach()
            if self.plasticity_type == "hebbian":
                hebb = torch.einsum('bi,bj->ij', y_detached, x_detached) / x.size(0)
                self.trace = (1 - self.eta) * self.trace + self.eta * hebb
                with torch.no_grad():
                    self.weight += self.trace
            elif self.plasticity_type == "stdp":
                dy = y_detached - self.prev_y
                stdp = torch.einsum('bi,bj->ij', dy, x_detached) / x.size(0)
                self.trace = (1 - self.eta) * self.trace + self.eta * stdp
                with torch.no_grad():
                    self.weight += self.trace
            # Average over the batch so prev_y keeps its registered (out_features,) shape.
            self.prev_y = y_detached.mean(dim=0)
        return y

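# Surrogate-gradient note: SpikeFunction below emits a hard 0/1 spike in the
# forward pass (a Heaviside step at 0, applied to membrane - threshold) and, in
# the backward pass, passes gradients only where |membrane - threshold| < 1
# (a rectangular surrogate). This is what makes the spiking layers trainable
# with autograd despite the non-differentiable spike.
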
# === Spiking Surrogate Functions and Base Neurons ===
class SpikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Rectangular surrogate gradient: pass gradients only near the threshold.
        return grad_output * (input.abs() < 1).float()

spike_fn = SpikeFunction.apply

class LIFNeuron(nn.Module):
    def __init__(self, tau=2.0):
        super().__init__()
        self.tau = tau
        self.mem = 0.0  # membrane potential, carried across calls

    def forward(self, x):
        decay = torch.exp(torch.tensor(-1.0 / self.tau))
        if torch.is_tensor(self.mem):
            # Carry membrane state across calls without retaining the autograd graph.
            self.mem = self.mem.detach()
        self.mem = self.mem * decay + x
        out = spike_fn(self.mem - 1.0)
        self.mem = self.mem * (1.0 - out.detach())  # reset units that spiked
        return out

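# AdaptiveLIF below extends the LIF dynamics with a per-unit adaptive threshold
# that rises by beta after every spike (a simple form of spike-frequency
# adaptation), so frequently active units become progressively harder to drive.
# Note that in this version the threshold only increases; there is no decay term.
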
# === Adaptive LIF Neuron ===
class AdaptiveLIF(nn.Module):
    def __init__(self, size, tau=2.0, beta=0.2):
        super().__init__()
        self.size = size
        self.tau = tau
        self.beta = beta
        self.mem = torch.zeros(size)     # membrane potential, carried across calls
        self.thresh = torch.ones(size)   # adaptive firing threshold

    def forward(self, x):
        decay = torch.exp(torch.tensor(-1.0 / self.tau))
        # Detach carried state so the autograd graph does not accumulate across calls.
        self.mem = self.mem.detach() * decay + x
        out = spike_fn(self.mem - self.thresh)
        self.thresh = self.thresh.detach() + self.beta * out.detach()
        self.mem = self.mem * (1.0 - out.detach())
        return out

# === Relay Layer with Attention ===
class RelayLayer(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        self.lif = LIFNeuron()

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        return self.lif(attn_out)

# === Working Memory ===
class WorkingMemory(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out[:, -1]

# === Place Cell Grid ===
class PlaceGrid(nn.Module):
    def __init__(self, grid_size=10, embedding_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(grid_size**2, embedding_dim)

    def forward(self, index):
        return self.embedding(index)

# === Mirror Comparator ===
class MirrorComparator(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.cos = nn.CosineSimilarity(dim=1)

    def forward(self, x1, x2):
        return self.cos(x1, x2).unsqueeze(1)

# === Neuroendocrine Module ===
class NeuroendocrineModulator(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out[:, -1]

# === Autonomic Feedback Module ===
class AutonomicFeedback(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.feedback = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return torch.tanh(self.feedback(x))

# === Replay Buffer ===
class ReplayBuffer:
    def __init__(self, capacity=1000):
        self.buffer = deque(maxlen=capacity)

    def add(self, inputs, labels, task):
        self.buffer.append((inputs, labels, task))

    def sample(self, batch_size):
        indices = random.sample(range(len(self.buffer)), batch_size)
        batch = [self.buffer[i] for i in indices]
        inputs, labels, tasks = zip(*batch)
        return inputs, labels, tasks

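# Note: the agent below stores experiences via remember(), but nothing in this
# file calls ReplayBuffer.sample() yet; a rehearsal loop would sample stored
# (inputs, labels, task) tuples and replay them alongside new data.
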
# === Full Modular Brain Agent with Plasticity ===
class ModularBrainAgent(nn.Module):
    def __init__(self, input_dims, hidden_dim, output_dims):
        super().__init__()
        self.vision_encoder = nn.Linear(input_dims['vision'], hidden_dim)
        self.language_encoder = nn.Linear(input_dims['language'], hidden_dim)
        self.numeric_encoder = nn.Linear(input_dims['numeric'], hidden_dim)

        # Plastic synapses (Hebbian and STDP)
        self.connect_sensory_to_relay = PlasticLinear(hidden_dim * 3, hidden_dim, plasticity_type='hebbian')
        self.relay_layer = RelayLayer(hidden_dim)
        self.connect_relay_to_inter = PlasticLinear(hidden_dim, hidden_dim, plasticity_type='stdp')

        self.interneuron = AdaptiveLIF(hidden_dim)
        self.memory = WorkingMemory(hidden_dim, hidden_dim)
        self.place = PlaceGrid(grid_size=10, embedding_dim=hidden_dim)
        self.comparator = MirrorComparator(hidden_dim)
        self.emotion = NeuroendocrineModulator(hidden_dim, hidden_dim)
        self.feedback = AutonomicFeedback(hidden_dim)

        self.task_heads = nn.ModuleDict({
            task: nn.Linear(hidden_dim, out_dim)
            for task, out_dim in output_dims.items()
        })

        self.replay = ReplayBuffer()

    def forward(self, inputs, task, position_idx=None):
        # Encode each modality into the shared hidden dimension.
        v = self.vision_encoder(inputs['vision'])
        l = self.language_encoder(inputs['language'])
        n = self.numeric_encoder(inputs['numeric'])

        sensory_cat = torch.cat([v, l, n], dim=-1)
        z = self.connect_sensory_to_relay(sensory_cat)

        # Relay (attention) -> plastic projection -> adaptive spiking interneuron.
        z = self.relay_layer(z.unsqueeze(1)).squeeze(1)
        z = self.connect_relay_to_inter(z)
        z = self.interneuron(z)

        # Contextual signals: working memory, place code, neuroendocrine state, feedback.
        m = self.memory(z.unsqueeze(1))
        p = self.place(position_idx if position_idx is not None else torch.tensor([0]))
        e = self.emotion(z.unsqueeze(1))
        f = self.feedback(z)

        combined = z + m + p + e + f
        out = self.task_heads[task](combined)
        return out

    def remember(self, inputs, labels, task):
        self.replay.add(inputs, labels, task)

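# The smoke test below feeds random tensors through the agent and prints the
# per-step loss. There is no optimizer and no backward pass: with the module in
# its default training mode, the only learned parameters that change are the
# PlasticLinear weights, via their built-in plasticity rules; the spiking neurons
# only update internal state (membrane potentials, thresholds).
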
# === Main Test Block ===
if __name__ == "__main__":
    input_dims = {'vision': 32, 'language': 16, 'numeric': 8}
    output_dims = {'classification': 5, 'regression': 1, 'binary': 1}
    agent = ModularBrainAgent(input_dims, hidden_dim=64, output_dims=output_dims)

    tasks = list(output_dims.keys())

    for step in range(250):
        task = random.choice(tasks)
        inputs = {
            'vision': torch.randn(1, 32),
            'language': torch.randn(1, 16),
            'numeric': torch.randn(1, 8)
        }
        # Random placeholder targets; this loop is only a smoke test.
        labels = torch.randint(0, output_dims[task], (1,)) if task == 'classification' else torch.randn(1, output_dims[task])
        output = agent(inputs, task)
        loss = F.cross_entropy(output, labels) if task == 'classification' else F.mse_loss(output, labels)
        print(f"Step {step:03d} | Task: {task:14s} | Loss: {loss.item():.4f}")