SuperOcc: Toward Cohesive Temporal Modeling for Superquadric-based Occupancy Prediction
Paper • 2601.15644 • Published
A modular, end-to-end pipeline for reconstructing semantically consistent 3D traffic scenes from a single RGB image. Designed for near real-time inference (≥15 FPS on RTX 3090) with all GNN components under 500K parameters.
```
RGB Image (H×W×3)
           │
           ▼
┌─────────────────────────┐
│ Stage 1: Input          │
│ Augmentation            │ → 5-channel tensor [RGB + Positional + Edge]
│ • Positional Encoding   │
│ • Sobel/Canny Edges     │
└──────────┬──────────────┘
           │
           ▼
┌─────────────────────────┐
│ Stage 2: Segmentation   │
│ • Lightweight UNet      │ → Semantic map S (H×W×K)
│ • Edge Weighting        │ → S'(x,y) = S(x,y) * (1 + α*C(x,y))
│ • Boundary Head (SBCB)  │
└──────────┬──────────────┘
           │
           ▼
┌─────────────────────────┐
│ Stage 3: Primitives     │
│ • Connected Components  │ → Cuboids, Cylinders, Cones, Planes
│ • PCA-based Fitting     │ → Scene Graph (nodes + edges)
│ • Graph Construction    │
└──────────┬──────────────┘
           │
           ▼
┌─────────────────────────┐
│ Stage 4: GNN            │
│ • GraphSAGE / GATv2     │ → Refined relational features
│ • Edge Feature Inject   │ → Improved spatial consistency
│ • LayerNorm + Dropout   │
└──────────┬──────────────┘
           │
           ▼
┌─────────────────────────┐
│ Stage 5: Point Cloud    │
│ • Surface Sampling      │ → 2K-20K 3D points
│ • Gaussian Noise        │ → Class/Instance/Primitive labels
│ • PLY Export            │ → Optional GNN features per point
└─────────────────────────┘
```
```bash
pip install torch torchvision torch_geometric scipy scikit-learn numpy
```
```python
import torch
from traffic3d.models.pipeline import Traffic3DPipeline

# Initialize pipeline
pipeline = Traffic3DPipeline(
    num_classes=19,            # Cityscapes classes
    base_ch=32,                # Lightweight UNet (4.3M params)
    gnn_type='sage',           # or 'gat', 'hybrid'
    edge_method='sobel',       # or 'canny'
    points_per_primitive=512,
)

# Forward pass
rgb = torch.randint(0, 256, (1, 3, 512, 1024), dtype=torch.uint8)
results = pipeline(rgb, training=False)

# Access outputs
segmentation = results['seg_outputs']['segmentation']  # [1, 512, 1024]
primitives = results['primitives'][0]                  # List of Primitive objects
point_cloud = results['point_clouds'][0]               # PointCloudOutput

# Save point cloud
from traffic3d.models.point_cloud import PointCloudGenerator
PointCloudGenerator.save_ply(point_cloud, 'scene.ply')
```
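For reference, writing labeled points to ASCII PLY takes only a few lines. The writer below is a hypothetical minimal stand-in (points plus a single integer `label` property), not the repo's `save_ply` implementation:

```python
import numpy as np

def write_ply_ascii(path, points, labels=None):
    """Minimal ASCII PLY writer: points is (N, 3) float, labels is (N,) int."""
    n = len(points)
    props = ["property float x", "property float y", "property float z"]
    if labels is not None:
        props.append("property int label")
    header = "\n".join(
        ["ply", "format ascii 1.0", f"element vertex {n}"] + props + ["end_header"]
    )
    with open(path, "w") as f:
        f.write(header + "\n")
        for i in range(n):
            row = "{:.6f} {:.6f} {:.6f}".format(*points[i])
            if labels is not None:
                row += f" {int(labels[i])}"
            f.write(row + "\n")

pts = np.random.rand(100, 3).astype(np.float32)
write_ply_ascii("scene_demo.ply", pts, labels=np.zeros(100, dtype=np.int64))
```

The resulting file opens directly in MeshLab or CloudCompare.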
| Channel | Description | Purpose |
|---|---|---|
| 0-2 | RGB (normalized) | Visual features |
| 3 | Positional Encoding P(x,y) | Vertical depth prior (top=far, bottom=near) |
| 4 | Edge Confidence C(x,y) | Boundary detection for edge weighting |
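The augmentation stage can be sketched in plain PyTorch. The version below is illustrative only (Sobel-only edges, per-image max normalization is an assumption), not the repo's `input_augmentation` module:

```python
import torch
import torch.nn.functional as F

def augment_input(rgb):
    """Build the 5-channel input from a uint8 RGB tensor of shape (B, 3, H, W)."""
    b, _, h, w = rgb.shape
    x = rgb.float() / 255.0                              # channels 0-2: normalized RGB
    # Channel 3: vertical positional prior, 0 at the top (far), 1 at the bottom (near)
    pos = torch.linspace(0.0, 1.0, h).view(1, 1, h, 1).expand(b, 1, h, w)
    # Channel 4: Sobel gradient magnitude on the grayscale image, scaled to [0, 1]
    gray = x.mean(dim=1, keepdim=True)
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3)
    ky = kx.transpose(2, 3)
    gx = F.conv2d(gray, kx, padding=1)
    gy = F.conv2d(gray, ky, padding=1)
    edge = torch.sqrt(gx ** 2 + gy ** 2)
    edge = edge / edge.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
    return torch.cat([x, pos, edge], dim=1)              # (B, 5, H, W)

out = augment_input(torch.randint(0, 256, (1, 3, 64, 128), dtype=torch.uint8))
```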
Lightweight UNet with edge weighting and auxiliary boundary supervision (SBCB-style, zero inference overhead):

S'(x,y) = S(x,y) * (1 + α * C(x,y))

L_total = L_ce_edge + λ * L_boundary (λ = 0.4)

| Object Type | Primitive | Fitting Method |
|---|---|---|
| Vehicles/Buildings | Cuboid | PCA-based orientation |
| Pedestrians | Cylinder | Bounding extent |
| Trees | Cone | Bounding extent |
| Road/Sky | Plane | PCA normal estimation |
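PCA-based fitting for cuboids reduces to an eigendecomposition of the point covariance. The sketch below is illustrative (function name and conventions are assumptions), not the repo's `primitive_extraction` code:

```python
import numpy as np

def fit_cuboid_pca(points):
    """Estimate centroid, orientation (principal axes), and extents of an (N, 3) point set."""
    centroid = points.mean(axis=0)
    centered = points - centroid
    # Principal axes from the covariance eigenvectors
    eigvals, eigvecs = np.linalg.eigh(np.cov(centered.T))
    axes = eigvecs[:, ::-1]                    # columns sorted by variance, descending
    # Extents: range of the points projected onto each axis
    proj = centered @ axes
    size = proj.max(axis=0) - proj.min(axis=0)
    return centroid, axes, size

# Axis-aligned box of points: extents should come back near (4, 2, 1)
np.random.seed(0)
pts = np.random.rand(2000, 3) * np.array([4.0, 2.0, 1.0])
centroid, axes, size = fit_cuboid_pca(pts)
```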
Node features (26-D): [class_embedding(16), centroid(3), size(3), orientation(4)]

Edge features (5-D): [distance, adjacency_flag, relative_position(3)]

| Model | Architecture | Parameters | Description |
|---|---|---|---|
| GraphSAGE | EdgeAwareSAGEConv × 2 | ~29K | Custom MessagePassing with edge injection |
| GATv2 | GATv2Conv (4-head + 1-head) | ~29K | Dynamic attention with native edge_dim |
| Hybrid | SAGE + GAT + learned gate | ~62K | Automatic blending of both approaches |
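A minimal plain-PyTorch sketch of the edge-injection idea behind `EdgeAwareSAGEConv`: each neighbor message is computed from the neighbor feature concatenated with the edge feature, mean-aggregated, then combined with the node's own feature as in GraphSAGE. The class below is a stand-in (dimensions match the node/edge layout above), not the repo's implementation:

```python
import torch
import torch.nn as nn

class EdgeAwareSAGELayer(nn.Module):
    """Illustrative SAGE-style layer with edge-feature injection."""
    def __init__(self, node_dim, edge_dim, out_dim):
        super().__init__()
        self.msg = nn.Linear(node_dim + edge_dim, out_dim)
        self.upd = nn.Linear(node_dim + out_dim, out_dim)

    def forward(self, x, edge_index, edge_attr):
        # x: (N, node_dim); edge_index: (2, E) as [src, dst]; edge_attr: (E, edge_dim)
        src, dst = edge_index
        m = torch.relu(self.msg(torch.cat([x[src], edge_attr], dim=1)))
        # Mean-aggregate messages per destination node
        agg = torch.zeros(x.size(0), m.size(1))
        agg.index_add_(0, dst, m)
        deg = torch.zeros(x.size(0)).index_add_(0, dst, torch.ones(dst.size(0)))
        agg = agg / deg.clamp(min=1).unsqueeze(1)
        return torch.relu(self.upd(torch.cat([x, agg], dim=1)))

layer = EdgeAwareSAGELayer(node_dim=26, edge_dim=5, out_dim=32)
x = torch.randn(4, 26)                          # 4 primitives
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
edge_attr = torch.randn(4, 5)                   # distance, adjacency, rel. position
out = layer(x, edge_index, edge_attr)
```

GATv2Conv in PyTorch Geometric handles edge features natively via its `edge_dim` argument, which is why the GAT variant needs no custom message passing.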
```python
import torch
from traffic3d.models.pipeline import Traffic3DPipeline, Traffic3DTrainer

pipeline = Traffic3DPipeline(num_classes=19)
trainer = Traffic3DTrainer(pipeline, device=torch.device('cuda'))

# Phase 1: segmentation pre-training
trainer.phase1_pretrain_segmentation(train_loader, epochs=30, lr=1e-3)
# Phase 2: edge-weighted fine-tuning with boundary supervision
trainer.phase2_finetune_edge_weighted(train_loader, epochs=15, lr=5e-4, lambda_boundary=0.4)
# Phase 3: GNN training on extracted scene graphs
trainer.phase3_train_gnn(graph_dataset, epochs=50, lr=1e-3)
# Phase 4: joint end-to-end fine-tuning
trainer.phase4_end_to_end(train_loader, epochs=10, lr=1e-4)
```
| Loss | Formula | Use |
|---|---|---|
| EdgeWeightedCE | CE * (1 + α*C(x,y)) | Segmentation with boundary focus |
| BoundaryLoss | Binary CE on boundary (on-the-fly GT) | Boundary refinement |
| CombinedSegLoss | L_ce + λ * L_boundary (λ=0.4) | Full segmentation training |
| RelationalConsistency | Contrastive on GNN features | Scene graph training |
| ChamferDistance | Bidirectional nearest-neighbor | 3D quality evaluation |
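EdgeWeightedCE amounts to a one-line modification of per-pixel cross-entropy. A sketch (function name and mean reduction are assumptions):

```python
import torch
import torch.nn.functional as F

def edge_weighted_ce(logits, target, edge_conf, alpha=1.0):
    """Per-pixel cross-entropy scaled by (1 + alpha * C(x, y)),
    so boundary pixels contribute more.
    logits: (B, K, H, W); target: (B, H, W) long; edge_conf: (B, H, W) in [0, 1]."""
    ce = F.cross_entropy(logits, target, reduction='none')   # (B, H, W)
    return (ce * (1.0 + alpha * edge_conf)).mean()

logits = torch.randn(2, 19, 8, 8)
target = torch.randint(0, 19, (2, 8, 8))
edge_conf = torch.rand(2, 8, 8)
loss = edge_weighted_ce(logits, target, edge_conf, alpha=1.0)
```

Since the weight is at least 1 everywhere, the weighted loss never falls below the plain cross-entropy.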
| Metric | Target | Description |
|---|---|---|
| 3D IoU | ~0.68 | 3D bounding box overlap |
| Centroid L2 | ~0.49m | Primitive position accuracy |
| Edge Graph Accuracy | ~78% | Scene graph correctness (F1) |
| Chamfer Distance | ~0.041 | Point cloud reconstruction quality |
| Boundary IoU | +15% | Improvement over non-edge baseline |
| FPS | ≥15 | RTX 3090 real-time throughput |
```python
import torch
from traffic3d.utils.evaluation import AblationStudy

ablation = AblationStudy(device=torch.device('cuda'))
results = ablation.run_all()  # Ablates: λ, GNN architecture, edge method, points per primitive
print(ablation.summary_table())
```
| GNN | GNN Parameters | Under 500K | Total Pipeline |
|---|---|---|---|
| sage | 28,736 | ✓ | 4.36M |
| gat | 29,312 | ✓ | 4.36M |
| hybrid | 62,016 | ✓ | 4.39M |
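Parameter budgets like these can be checked with a generic trainable-parameter counter. The module below is an illustrative stand-in, not the repo's GNN:

```python
import torch.nn as nn

def count_params(module):
    """Count trainable parameters, as used to check the <500K GNN budget."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Illustrative stand-in module (not the repo's GNN): two small linear layers
gnn_like = nn.Sequential(nn.Linear(26, 64), nn.ReLU(), nn.Linear(64, 32))
n = count_params(gnn_like)   # 26*64 + 64 + 64*32 + 32 = 3808
```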
| Dataset | Use | Classes |
|---|---|---|
| Cityscapes | Primary training | 19 semantic |
| BDD100K | Robustness testing | 19 semantic |
| CARLA | Synthetic 3D GT supervision | Configurable |
```
traffic3d/
├── __init__.py
├── models/
│   ├── input_augmentation.py     # Stage 1: Positional + Edge encoding
│   ├── segmentation.py           # Stage 2: Lightweight UNet + edge weighting
│   ├── primitive_extraction.py   # Stage 3: Primitives + scene graph
│   ├── gnn_refinement.py         # Stage 4: GraphSAGE / GATv2 / Hybrid GNN
│   ├── point_cloud.py            # Stage 5: Surface sampling + PLY export
│   └── pipeline.py               # End-to-end pipeline + 4-phase trainer
├── losses/
│   └── __init__.py               # EdgeCE, BoundaryLoss, ChamferDistance, etc.
├── utils/
│   └── evaluation.py             # Metrics, Evaluator, AblationStudy
└── data/ and configs/
```
MIT License