import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class ZIAModel(nn.Module):
    """Multi-modal intent classifier: projects gaze, heart-rate, EEG, and
    context features into a shared embedding space, fuses them, runs a
    transformer encoder, and classifies the pooled output."""

    def __init__(self, n_intents=10, d_model=128, nhead=8, num_layers=6, dim_feedforward=512):
        super().__init__()
        self.d_model = d_model

        # Per-modality linear encoders projecting raw features into the
        # shared d_model embedding space.
        self.gaze_encoder = nn.Linear(2, d_model)    # 2-D gaze features (e.g. x/y coordinates)
        self.hr_encoder = nn.Linear(1, d_model)      # scalar heart-rate signal
        self.eeg_encoder = nn.Linear(4, d_model)     # 4 EEG channels
        self.context_encoder = nn.Linear(32 + 3 + 20, d_model)  # 55-dim concatenated context features

        # Transformer encoder over the fused embedding sequence;
        # batch_first=True means inputs are (batch, seq_len, d_model).
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer = TransformerEncoder(encoder_layer, num_layers)

        # Classification head mapping pooled features to intent logits.
        self.fc = nn.Linear(d_model, n_intents)

    def forward(self, gaze, hr, eeg, context):
        # Expected shapes: gaze (B, T, 2), hr (B, T), eeg (B, T, 4),
        # context (B, T, 55); each is projected to (B, T, d_model).
        gaze_emb = self.gaze_encoder(gaze)
        hr_emb = self.hr_encoder(hr.unsqueeze(-1))  # add feature dim: (B, T) -> (B, T, 1)
        eeg_emb = self.eeg_encoder(eeg)
        context_emb = self.context_encoder(context)

        # Fuse modalities by averaging their embeddings.
        fused = (gaze_emb + hr_emb + eeg_emb + context_emb) / 4

        # Contextualize with self-attention, then mean-pool over time.
        output = self.transformer(fused)
        output = output.mean(dim=1)

        # Project the pooled representation to per-intent logits.
        logits = self.fc(output)
        return logits


if __name__ == "__main__":
    model = ZIAModel()
    print(model)
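    # A minimal smoke test with random dummy inputs. The batch size and
    # sequence length below are assumptions for illustration; the feature
    # sizes simply match the encoder definitions above.
    B, T = 2, 16
    gaze = torch.randn(B, T, 2)
    hr = torch.randn(B, T)
    eeg = torch.randn(B, T, 4)
    context = torch.randn(B, T, 32 + 3 + 20)
    logits = model(gaze, hr, eeg, context)
    print(logits.shape)  # expected: torch.Size([2, 10])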