# core/graph_mamba.py
import torch
import torch.nn as nn
from .mamba_block import MambaBlock
from .graph_sequencer import GraphSequencer, PositionalEncoder
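
# Expected config structure, inferred from the keys accessed in this file
# (the layout comes from the code below; concrete values are user-supplied):
#
#   config = {
#       'model': {
#           'd_model': ..., 'n_layers': ..., 'dropout': ...,
#           'd_state': ..., 'd_conv': ..., 'expand': ...,
#       },
#       'ordering': {
#           # 'spectral', 'degree', or 'community'; anything else falls back to BFS
#           'strategy': ...,
#       },
#   }
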
class GraphMamba(nn.Module):
    """
    Production Graph-Mamba model.

    Device-safe implementation with lazily initialized input projection
    and classification head (created on first use).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_model = config['model']['d_model']
        self.n_layers = config['model']['n_layers']
        self.dropout = config['model']['dropout']
        self.ordering_strategy = config['ordering']['strategy']

        # Input projection (dynamic input dimension, initialized on first forward)
        self.input_proj = None

        # Positional encoding
        self.pos_encoder = PositionalEncoder()
        self.pos_embed = nn.Linear(11, self.d_model)  # 1 + 10 distances

        # Mamba layers
        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model=self.d_model,
                d_state=config['model']['d_state'],
                d_conv=config['model']['d_conv'],
                expand=config['model']['expand']
            )
            for _ in range(self.n_layers)
        ])

        # Layer norms
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model)
            for _ in range(self.n_layers)
        ])

        # Dropout
        self.dropout_layer = nn.Dropout(self.dropout)

        # Graph sequencer
        self.sequencer = GraphSequencer()

        # Classification head (for demo, initialized on first use)
        self.classifier = None

    def _init_input_proj(self, input_dim, device):
        """Initialize input projection dynamically."""
        if self.input_proj is None:
            self.input_proj = nn.Linear(input_dim, self.d_model).to(device)

    def _init_classifier(self, num_classes, device):
        """Initialize classifier dynamically."""
        if self.classifier is None:
            self.classifier = nn.Linear(self.d_model, num_classes).to(device)

    def forward(self, x, edge_index, batch=None):
        """
        Forward pass with device-safe handling.

        Args:
            x: Node features (num_nodes, input_dim)
            edge_index: Edge connectivity (2, num_edges)
            batch: Batch assignment (num_nodes,) - optional
        """
        input_dim = x.size(1)
        device = x.device

        # Move all components to the correct device
        self.to(device)

        # Initialize input projection if needed
        self._init_input_proj(input_dim, device)

        # Project input features
        h = self.input_proj(x)  # (num_nodes, d_model)

        if batch is None:
            # Single graph processing
            h = self._process_single_graph(h, edge_index)
        else:
            # Batch processing
            h = self._process_batch(h, edge_index, batch)

        return h

    def _process_single_graph(self, h, edge_index):
        """Process a single graph - device safe."""
        num_nodes = h.size(0)
        device = h.device

        # Ensure edge_index is on the correct device
        edge_index = edge_index.to(device)

        # Get node ordering according to the configured strategy
        if self.ordering_strategy == "spectral":
            order = self.sequencer.spectral_ordering(edge_index, num_nodes)
        elif self.ordering_strategy == "degree":
            order = self.sequencer.degree_ordering(edge_index, num_nodes)
        elif self.ordering_strategy == "community":
            order = self.sequencer.community_ordering(edge_index, num_nodes)
        else:  # default to BFS
            order = self.sequencer.bfs_ordering(edge_index, num_nodes)

        # Ensure order is on the correct device
        order = order.to(device)

        # Add positional encoding
        seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
        seq_pos = seq_pos.to(device)
        distances = distances.to(device)
        pos_features = torch.cat([seq_pos, distances], dim=1)  # (num_nodes, 11)
        pos_embed = self.pos_embed(pos_features)

        # Reorder nodes for sequential processing and add positional encoding
        h_ordered = h[order] + pos_embed[order]
        h_ordered = h_ordered.unsqueeze(0)  # (1, num_nodes, d_model)

        # Process through Mamba layers with pre-norm residual connections
        for mamba, ln in zip(self.mamba_layers, self.layer_norms):
            h_ordered = h_ordered + self.dropout_layer(mamba(ln(h_ordered)))

        # Restore original node order
        h_out = h_ordered.squeeze(0)  # (num_nodes, d_model)
        inverse_order = torch.argsort(order)
        h_final = h_out[inverse_order]

        return h_final

    def _process_batch(self, h, edge_index, batch):
        """Process batched graphs - device safe."""
        device = h.device
        batch = batch.to(device)
        edge_index = edge_index.to(device)

        batch_size = batch.max().item() + 1
        outputs = []

        for b in range(batch_size):
            # Extract subgraph nodes
            mask = batch == b
            batch_h = h[mask]

            # Get edges for this graph
            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            batch_edges = edge_index[:, edge_mask]

            if batch_edges.shape[1] > 0:
                # Reindex edges to local indices
                node_indices = torch.where(mask)[0]
                node_map = torch.zeros(h.size(0), dtype=torch.long, device=device)
                node_map[node_indices] = torch.arange(batch_h.size(0), device=device)
                batch_edges_local = node_map[batch_edges]
            else:
                # Graph with no edges
                batch_edges_local = torch.empty((2, 0), dtype=torch.long, device=device)

            # Process subgraph
            batch_output = self._process_single_graph(batch_h, batch_edges_local)
            outputs.append(batch_output)

        # Reconstruct full batch in original node order
        h_out = torch.zeros_like(h)
        for b, output in enumerate(outputs):
            mask = batch == b
            h_out[mask] = output

        return h_out

    def get_graph_embedding(self, h, batch=None):
        """Get graph-level representation."""
        if batch is None:
            # Single graph - mean pooling
            return h.mean(dim=0, keepdim=True)
        else:
            # Batched graphs - manual pooling to avoid dependencies
            device = h.device
            batch = batch.to(device)
            batch_size = batch.max().item() + 1

            graph_embeddings = []
            for b in range(batch_size):
                mask = batch == b
                if mask.any():
                    graph_emb = h[mask].mean(dim=0)
                    graph_embeddings.append(graph_emb)
                else:
                    graph_embeddings.append(torch.zeros(h.size(1), device=device))

            return torch.stack(graph_embeddings)
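

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The config
    # values below are illustrative assumptions, not tuned defaults, and the
    # relative imports above mean this has to be run as a module from the
    # package root (e.g. `python -m core.graph_mamba`), assuming MambaBlock,
    # GraphSequencer, and PositionalEncoder behave as used in this file.
    config = {
        'model': {
            'd_model': 64,    # assumed embedding width
            'n_layers': 2,    # assumed number of Mamba layers
            'dropout': 0.1,
            'd_state': 16,
            'd_conv': 4,
            'expand': 2,
        },
        'ordering': {'strategy': 'bfs'},  # falls back to BFS ordering
    }
    model = GraphMamba(config)

    # Toy graph: 6 nodes with 8-dimensional features on a directed cycle.
    x = torch.randn(6, 8)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]], dtype=torch.long)

    node_repr = model(x, edge_index)                    # (6, d_model)
    graph_repr = model.get_graph_embedding(node_repr)   # (1, d_model)
    print(node_repr.shape, graph_repr.shape)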