PDE-Transformer
This repository contains the pre-trained weights of PDE-Transformer, a transformer-based foundation model for physics simulations on regular grids. PDE-Transformer combines recent architectural improvements of diffusion transformers with adjustments specific to large-scale physical simulations, providing a scalable and versatile general-purpose architecture for physics modeling.
Project Links
- Project Webpage: https://tum-pbs.github.io/pde-transformer/landing.html
- Paper: PDE-Transformer: Efficient and Versatile Transformers for Physics Simulations
- GitHub Repository: https://github.com/tum-pbs/pde-transformer
Model Description
PDE-Transformer is designed to efficiently process and predict the evolution of physical systems described by partial differential equations (PDEs). It can handle multiple types of PDEs, different resolutions, domain extents, boundary conditions, and includes deep conditioning mechanisms for PDE- and task-specific information.
Key features:
- Multi-scale architecture with token down- and upsampling for efficient modeling
- Shifted window attention for improved scaling to high-resolution data
- Mixed Channel (MC) and Separate Channel (SC) representations for handling multiple physical quantities (see the data-layout sketch after this list)
- Flexible conditioning mechanism for PDE parameters, boundary conditions, and simulation metadata
- Pre-training and fine-tuning capabilities for transfer learning across different physics domains
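To make the difference between the two channel representations concrete, the following minimal sketch shows how multiple physical quantities could be laid out as input tensors. It is purely a data-layout illustration in plain PyTorch; the variable names and shapes are examples, not the library's API.
import torch
# Hypothetical example: two physical quantities (velocity components) on a 256x256 grid.
velocity_x = torch.randn(1, 256, 256)
velocity_y = torch.randn(1, 256, 256)
# Mixed Channel (MC): all quantities share one token representation,
# so they are stacked along the channel dimension of a single tensor.
mc_input = torch.stack([velocity_x, velocity_y], dim=1)  # shape (1, 2, 256, 256)
# Separate Channel (SC): each quantity is embedded independently and related
# via channel-wise axial attention, so the number of fields can vary freely.
sc_fields = [velocity_x.unsqueeze(1), velocity_y.unsqueeze(1)]  # list of (1, 1, 256, 256) fields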
Installation
PDE-Transformer models and additional tools for training/inference can be installed via pip:
pip install pdetransformer
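A quick sanity check after installation (a minimal sketch; it only assumes the distribution name from the pip command above and that the loading example below moves the model to the GPU):
from importlib.metadata import version
import torch
print("pdetransformer version:", version("pdetransformer"))  # distribution installed by pip
print("CUDA available:", torch.cuda.is_available())           # the loading example below calls .cuda()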
Architecture Variants
PDE-Transformer comes in three model sizes, each with a separate channel (SC) and a mixed channel (MC) variant; see the table under Available Models below.
How To Load Pretrained Models
A pretrained PDETransformer model can be loaded via:
from pdetransformer.core.mixed_channels import PDETransformer
import torch
# Load pre-trained model
subfolder = 'mc-s'
model = PDETransformer.from_pretrained('thuerey-group/pde-transformer', subfolder=subfolder).cuda()
# Example forward pass: 1 sample, 2 physical channels (MC variant), 256x256 grid
x = torch.randn((1, 2, 256, 256), dtype=torch.float32).cuda()
predictions = model(x)
The model variant is selected via the subfolder argument; see the list of pretrained models below. For more information, see the documentation.
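Beyond a single forward pass, a common use case is an autoregressive rollout, where each prediction is fed back as the next input. The sketch below assumes, as in the example above, that model(x) accepts a (batch, channels, height, width) tensor and returns a tensor of the same shape; the exact forward signature and any additional conditioning inputs are described in the documentation.
import torch
from pdetransformer.core.mixed_channels import PDETransformer
# Load the small mixed-channel checkpoint, as in the example above.
model = PDETransformer.from_pretrained('thuerey-group/pde-transformer', subfolder='mc-s').cuda()
model.eval()
# Hypothetical initial state: 1 sample, 2 physical channels, 256x256 grid.
state = torch.randn((1, 2, 256, 256), dtype=torch.float32).cuda()
# Roll out 10 steps by feeding each prediction back in as the next input.
trajectory = [state]
with torch.no_grad():
    for _ in range(10):
        state = model(state)        # assumes the output has the same shape as the input
        trajectory.append(state)
rollout = torch.stack(trajectory, dim=1)  # (1, 11, 2, 256, 256)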
Available Models
Model | Channels | Size | Hidden Dim | Heads | Parameters | Training Epochs | Checkpoint Size |
---|---|---|---|---|---|---|---|
SC-S | Separate | Small | 96 | 4 | ~46M | 100 | ~133MB |
SC-B | Separate | Base | 192 | 8 | ~178M | 100 | ~522MB |
SC-L | Separate | Large | 384 | 16 | ~701M | 100 | ~2.07GB |
MC-S | Mixed | Small | 96 | 4 | ~33M | 100 | ~187MB |
MC-B | Mixed | Base | 192 | 8 | ~130M | 100 | ~716MB |
MC-L | Mixed | Large | 384 | 16 | ~518M | 100 | ~2.81GB |
Model Specifications of Pretrained Models
- Separate Channel (SC): Embeds each physical channel independently and relates them via channel-wise axial attention. The number of input/output channels is variable.
- Mixed Channel (MC): Embeds all physical channels within the same token representation. The pretrained MC models use 2 input/output channels.
- Patch Size: Each 4×4 patch is embedded into a spatio-temporal token (see the sketch after this list).
- Window Size: 8×8 for windowed attention
- Boundary Conditions: Supports both periodic and non-periodic boundary conditions
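As an illustration of the patch embedding, the sketch below maps a 2-channel 256×256 field into a sequence of tokens using a strided convolution with the 4×4 patch size and the hidden dimension of the small (S) models. It is a stand-in written in plain PyTorch, it ignores the temporal dimension, and it is not the library's actual embedding code.
import torch
import torch.nn as nn
patch_size = 4     # 4x4 patches, as listed above
in_channels = 2    # the MC models use 2 input/output channels
hidden_dim = 96    # hidden dimension of the S models
# A strided convolution embeds non-overlapping 4x4 patches into tokens.
patch_embed = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(1, in_channels, 256, 256)    # one frame on a 256x256 grid
tokens = patch_embed(x)                      # (1, 96, 64, 64): one token per 4x4 patch
tokens = tokens.flatten(2).transpose(1, 2)   # (1, 4096, 96): token sequence for the transformer
print(tokens.shape)  # torch.Size([1, 4096, 96])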
Citation
If you use PDE-Transformer in your research, please cite:
@inproceedings{holzschuh2024pde,
  title     = {PDE-Transformer: Efficient and Versatile Transformers for Physics Simulations},
  author    = {Holzschuh, Benjamin and Liu, Qiang and Kohl, Georg and Thuerey, Nils},
  booktitle = {Forty-second International Conference on Machine Learning, {ICML} 2025, Vancouver, Canada, July 13-19, 2025},
  year      = {2025}
}
License
This project is licensed under the MIT License. See the LICENSE file for details.
Note: This is a research project from the Technical University of Munich (TUM) Physics-based Simulation Group. For questions and support, please refer to the GitHub repository or contact the authors.