---
license: mit
language:
- en
tags:
- physics
- PDEs
- surrogate
---

# PDE-Transformer

This repository contains pre-trained weights of **PDE-Transformer**, a transformer-based foundation model for physics simulations on regular grids. PDE-Transformer combines recent architectural improvements of diffusion transformers with adjustments specific for large-scale physical simulations to provide a scalable and versatile general-purpose architecture for physics modeling.

## 🌐 Project Links

- **Project Webpage**: [https://tum-pbs.github.io/pde-transformer/landing.html](https://tum-pbs.github.io/pde-transformer/landing.html)
- **Paper**: [Efficient and Versatile Transformers for Physics Simulations](https://arxiv.org/abs/2505.24717v1)
- **GitHub Repository**: [https://github.com/tum-pbs/pde-transformer](https://github.com/tum-pbs/pde-transformer)

## 📝 Model Description

PDE-Transformer is designed to efficiently process and predict the evolution of physical systems described by partial differential equations (PDEs). It can handle multiple types of PDEs, different resolutions, domain extents, boundary conditions, and includes deep conditioning mechanisms for PDE- and task-specific information.
Key features:

- **Multi-scale architecture** with token down- and upsampling for efficient modeling
- **Shifted window attention** for improved scaling to high-resolution data
- **Mixed Channel (MC) and Separate Channel (SC)** representations for handling multiple physical quantities
- **Flexible conditioning mechanism** for PDE parameters, boundary conditions, and simulation metadata
- **Pre-training and fine-tuning capabilities** for transfer learning across different physics domains

## Installation

PDE-Transformer models and additional tools for training and inference can be installed via pip:

```bash
pip install pdetransformer
```

## 🏗️ Architecture Variants

PDE-Transformer comes in three model sizes (small, base, and large), each available as a separate channel (SC) and a mixed channel (MC) variant.

### How To Load Pretrained Models

A pretrained PDETransformer can be loaded as follows:

```python
import torch
from pdetransformer.core.mixed_channels import PDETransformer

# Load pre-trained model; the subfolder selects the variant.
# Note: .cuda() requires a CUDA-capable GPU.
subfolder = 'mc-s'
model = PDETransformer.from_pretrained('thuerey-group/pde-transformer', subfolder=subfolder).cuda()

# For physics simulation: batch of 1, 2 physical channels, 256x256 grid
x = torch.randn((1, 2, 256, 256), dtype=torch.float32).cuda()
predictions = model(x)
```

The model variant is selected via the subfolder; see the list of pretrained models under "Available Models". For more information, see the [documentation](https://tum-pbs.github.io/pde-transformer/).
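The mapping from a variant name to its subfolder can be sketched as a small helper. This is only an illustration: `subfolder_for` and `KNOWN_VARIANTS` are hypothetical names that are not part of `pdetransformer`, and only the `mc-s` subfolder is confirmed by the loading example above. The other subfolder names are assumed to follow the same lowercase pattern as the model names in the table of pretrained models.

```python
# Variant names as listed in the "Available Models" table.
KNOWN_VARIANTS = ("SC-S", "SC-B", "SC-L", "MC-S", "MC-B", "MC-L")


def subfolder_for(variant: str) -> str:
    """Return the assumed repository subfolder for a model variant name.

    Hypothetical helper: only 'mc-s' is confirmed by the loading example;
    the remaining names assume the same lowercase pattern.
    """
    name = variant.strip().upper()
    if name not in KNOWN_VARIANTS:
        raise ValueError(
            f"Unknown variant {variant!r}; expected one of {KNOWN_VARIANTS}"
        )
    return name.lower()


print(subfolder_for("MC-S"))  # prints: mc-s
```

The returned string would then be passed as the `subfolder` argument of `from_pretrained`. Validating the name first gives a clear error message before any download is attempted.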
### Available Models

| Model | Channels | Size | Hidden Dim | Heads | Parameters | Training Epochs | Model Size |
|-------|----------|------|------------|-------|------------|-----------------|------------|
| **SC-S** | Separate | Small | 96 | 4 | ~46M | 100 | ~133MB |
| **SC-B** | Separate | Base | 192 | 8 | ~178M | 100 | ~522MB |
| **SC-L** | Separate | Large | 384 | 16 | ~701M | 100 | ~2.07GB |
| **MC-S** | Mixed | Small | 96 | 4 | ~33M | 100 | ~187MB |
| **MC-B** | Mixed | Base | 192 | 8 | ~130M | 100 | ~716MB |
| **MC-L** | Mixed | Large | 384 | 16 | ~518M | 100 | ~2.81GB |

### Model Specifications of Pretrained Models

- **Separate Channel (SC)**: Embeds different physical channels independently with channel-wise axial attention. The number of input/output channels is variable.
- **Mixed Channel (MC)**: Embeds all physical channels within the same token representation. The pretrained models use 2 input/output channels.
- **Patch Size**: Each 4×4 patch is embedded into one spatio-temporal token.
- **Window Size**: 8×8 for windowed attention.
- **Boundary Conditions**: Supports both periodic and non-periodic boundary conditions.

## Citation

If you use PDE-Transformer in your research, please cite:

```bibtex
@inproceedings{holzschuh2024pde,
  title = {PDE-Transformer: Efficient and Versatile Transformers for Physics Simulations},
  author = {Holzschuh, Benjamin and Liu, Qiang and Kohl, Georg and Thuerey, Nils},
  booktitle = {Forty-second International Conference on Machine Learning, {ICML} 2025, Vancouver, Canada, July 13-19, 2025},
  year = {2025}
}
```

## 📄 License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.

---

**Note**: This is a research project from the Technical University of Munich (TUM) Physics-based Simulation Group. For questions and support, please refer to the GitHub repository or contact the authors.