import torch
import torch.nn as nn

from .mamba_block import MambaBlock
from .graph_sequencer import GraphSequencer, PositionalEncoder
|
|
class GraphMamba(nn.Module):
    """
    Production Graph-Mamba model.

    Device-safe implementation: parameters are moved to the input tensor's
    device on every forward pass, and the input projection and classifier
    head are initialized lazily once their dimensions are known.
    """
|
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.d_model = config['model']['d_model']
        self.n_layers = config['model']['n_layers']
        self.dropout = config['model']['dropout']
        self.ordering_strategy = config['ordering']['strategy']

        # Input projection is created lazily in forward() once the input
        # feature dimension is known.
        self.input_proj = None

        # Positional encoding: sequence positions and graph distances are
        # concatenated into an 11-dimensional feature vector and projected
        # to d_model.
        self.pos_encoder = PositionalEncoder()
        self.pos_embed = nn.Linear(11, self.d_model)

        # Stack of Mamba blocks, applied with pre-norm residual connections
        # in forward().
        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model=self.d_model,
                d_state=config['model']['d_state'],
                d_conv=config['model']['d_conv'],
                expand=config['model']['expand']
            )
            for _ in range(self.n_layers)
        ])

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model)
            for _ in range(self.n_layers)
        ])

        self.dropout_layer = nn.Dropout(self.dropout)

        # Converts graph connectivity into a node sequence (BFS, spectral,
        # degree, or community ordering).
        self.sequencer = GraphSequencer()

        # Classification head is created lazily once num_classes is known.
        self.classifier = None
|
    def _init_input_proj(self, input_dim, device):
        """Lazily initialize the input projection on first use."""
        if self.input_proj is None:
            self.input_proj = nn.Linear(input_dim, self.d_model).to(device)

    def _init_classifier(self, num_classes, device):
        """Lazily initialize the classification head on first use."""
        if self.classifier is None:
            self.classifier = nn.Linear(self.d_model, num_classes).to(device)
|
    def forward(self, x, edge_index, batch=None):
        """
        Forward pass with device-safe handling.

        Args:
            x: Node features (num_nodes, input_dim)
            edge_index: Edge connectivity (2, num_edges)
            batch: Batch assignment (num_nodes,) - optional

        Returns:
            Node embeddings (num_nodes, d_model)
        """
        input_dim = x.size(1)
        device = x.device

        # Keep all parameters on the same device as the input.
        self.to(device)

        # Create the input projection on the first call, now that the input
        # dimension is known.
        self._init_input_proj(input_dim, device)

        # Project node features to the model dimension.
        h = self.input_proj(x)

        if batch is None:
            h = self._process_single_graph(h, edge_index)
        else:
            h = self._process_batch(h, edge_index, batch)

        return h
|
    def _process_single_graph(self, h, edge_index):
        """Process a single graph - device safe."""
        num_nodes = h.size(0)
        device = h.device

        edge_index = edge_index.to(device)

        # Turn the graph into a node sequence using the configured strategy;
        # unrecognized strategies fall back to BFS ordering.
        if self.ordering_strategy == "spectral":
            order = self.sequencer.spectral_ordering(edge_index, num_nodes)
        elif self.ordering_strategy == "degree":
            order = self.sequencer.degree_ordering(edge_index, num_nodes)
        elif self.ordering_strategy == "community":
            order = self.sequencer.community_ordering(edge_index, num_nodes)
        else:
            order = self.sequencer.bfs_ordering(edge_index, num_nodes)

        order = order.to(device)

        # Positional features: sequence positions and graph distances,
        # projected into the model dimension.
        seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
        seq_pos = seq_pos.to(device)
        distances = distances.to(device)

        pos_features = torch.cat([seq_pos, distances], dim=1)
        pos_embed = self.pos_embed(pos_features)

        # Reorder nodes into the sequence and add a batch dimension for Mamba.
        h_ordered = h[order] + pos_embed[order]
        h_ordered = h_ordered.unsqueeze(0)

        # Pre-norm residual Mamba blocks with dropout.
        for mamba, ln in zip(self.mamba_layers, self.layer_norms):
            h_ordered = h_ordered + self.dropout_layer(mamba(ln(h_ordered)))

        h_out = h_ordered.squeeze(0)

        # Undo the ordering so outputs line up with the input node indices.
        inverse_order = torch.argsort(order)
        h_final = h_out[inverse_order]

        return h_final
|
    def _process_batch(self, h, edge_index, batch):
        """Process batched graphs - device safe."""
        device = h.device
        batch = batch.to(device)
        edge_index = edge_index.to(device)

        batch_size = batch.max().item() + 1
        outputs = []

        for b in range(batch_size):
            # Select the nodes of graph b.
            mask = batch == b
            batch_h = h[mask]

            # Keep only edges whose endpoints both belong to graph b.
            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            batch_edges = edge_index[:, edge_mask]

            if batch_edges.shape[1] > 0:
                # Remap global node indices to local, per-graph indices.
                node_indices = torch.where(mask)[0]
                node_map = torch.zeros(h.size(0), dtype=torch.long, device=device)
                node_map[node_indices] = torch.arange(batch_h.size(0), device=device)
                batch_edges_local = node_map[batch_edges]
            else:
                batch_edges_local = torch.empty((2, 0), dtype=torch.long, device=device)

            batch_output = self._process_single_graph(batch_h, batch_edges_local)
            outputs.append(batch_output)

        # Scatter per-graph outputs back to the original node positions.
        h_out = torch.zeros_like(h)
        for b, output in enumerate(outputs):
            mask = batch == b
            h_out[mask] = output

        return h_out
|
    def get_graph_embedding(self, h, batch=None):
        """Get a graph-level representation by mean-pooling node embeddings."""
        if batch is None:
            # Single graph: average over all nodes.
            return h.mean(dim=0, keepdim=True)
        else:
            device = h.device
            batch = batch.to(device)
            batch_size = batch.max().item() + 1

            graph_embeddings = []
            for b in range(batch_size):
                mask = batch == b
                if mask.any():
                    graph_emb = h[mask].mean(dim=0)
                    graph_embeddings.append(graph_emb)
                else:
                    # Empty graph in the batch: use a zero embedding.
                    graph_embeddings.append(torch.zeros(h.size(1), device=device))

            return torch.stack(graph_embeddings)
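# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library). The config
# keys below mirror exactly what __init__ reads; the concrete values
# (d_model=64, n_layers=2, d_state=16, d_conv=4, expand=2, dropout=0.1,
# strategy='bfs') and the toy graph are assumed examples only. Because this
# module uses relative imports, run it as a module (python -m <package>.<module>)
# or copy the snippet into a separate script that imports GraphMamba.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        'model': {
            'd_model': 64,    # example value
            'n_layers': 2,    # example value
            'dropout': 0.1,   # example value
            'd_state': 16,    # example value
            'd_conv': 4,      # example value
            'expand': 2,      # example value
        },
        'ordering': {'strategy': 'bfs'},  # unknown strategies also fall back to BFS
    }

    model = GraphMamba(config)

    # Toy graph: 4 nodes, 8 input features, a directed 4-cycle of edges.
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])

    h = model(x, edge_index)                  # (4, d_model) node embeddings
    graph_emb = model.get_graph_embedding(h)  # (1, d_model) graph embedding
    print(h.shape, graph_emb.shape)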