# zero-input-ai / zia_model.py
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class ZIAModel(nn.Module):
    """Zero-Input AI intent-prediction model.

    Projects four input modalities (gaze, heart rate, EEG, context) into a
    shared ``d_model`` embedding space, fuses them by simple averaging, runs
    the fused sequence through a Transformer encoder, mean-pools over the
    sequence dimension, and emits logits over ``n_intents`` intent classes.
    """

    def __init__(self, n_intents=10, d_model=128, nhead=8, num_layers=6,
                 dim_feedforward=512, gaze_dim=2, eeg_dim=4, context_dim=55,
                 dropout=0.1):
        """Build the modality encoders, transformer, and output head.

        Args:
            n_intents: Number of output intent classes.
            d_model: Shared embedding width for all modalities.
            nhead: Number of attention heads (must divide ``d_model``).
            num_layers: Number of stacked transformer encoder layers.
            dim_feedforward: Width of each encoder layer's feed-forward sublayer.
            gaze_dim: Per-timestep gaze feature size (default 2, e.g. x/y).
            eeg_dim: Per-timestep EEG feature size.
            context_dim: Context feature size. Default 55 = 32 (time)
                + 3 (location) + 20 (usage), matching the original layout.
            dropout: Dropout rate inside the encoder layers.
        """
        super().__init__()
        self.d_model = d_model
        # Modality-specific encoders: one linear projection per raw stream
        # into the shared d_model space. HR is a scalar stream (see forward),
        # so its input width is fixed at 1.
        self.gaze_encoder = nn.Linear(gaze_dim, d_model)
        self.hr_encoder = nn.Linear(1, d_model)
        self.eeg_encoder = nn.Linear(eeg_dim, d_model)
        self.context_encoder = nn.Linear(context_dim, d_model)
        # Transformer over the fused sequence; batch_first keeps the
        # [batch, seq, d_model] layout used throughout.
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout=dropout, batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers)
        # Output head: pooled embedding -> intent logits.
        self.fc = nn.Linear(d_model, n_intents)

    def forward(self, gaze, hr, eeg, context):
        """Predict intent logits from the four modality streams.

        Args:
            gaze: Tensor [batch, seq, gaze_dim].
            hr: Tensor [batch, seq] — scalar heart-rate per timestep
                (a feature dim is added internally).
            eeg: Tensor [batch, seq, eeg_dim].
            context: Tensor [batch, seq, context_dim].

        Returns:
            Tensor [batch, n_intents] of unnormalized intent logits.
        """
        # Encode each modality into [batch, seq, d_model].
        gaze_emb = self.gaze_encoder(gaze)
        hr_emb = self.hr_encoder(hr.unsqueeze(-1))  # add trailing feature dim
        eeg_emb = self.eeg_encoder(eeg)
        context_emb = self.context_encoder(context)
        # Fuse modalities by simple elementwise averaging.
        fused = (gaze_emb + hr_emb + eeg_emb + context_emb) / 4
        # Contextualize with the transformer, then mean-pool over sequence.
        output = self.transformer(fused)
        output = output.mean(dim=1)
        # Project pooled representation to intent logits.
        return self.fc(output)
# Example usage: instantiate the model and show its module structure.
if __name__ == "__main__":
    zia = ZIAModel()
    print(zia)