PrxteinMPNN

A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.

Model Description

PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.

Key Features:

  • Fully modular Equinox implementation
  • JAX-based for GPU acceleration and automatic differentiation
  • Multiple pre-trained model variants (original and soluble)
  • Checkpoints at multiple training epochs (002, 010, 020, 030)

Available Models

All models share the same architecture and differ only in training duration and data:

Original Models

  • original_v_48_002 - Trained for 2 epochs
  • original_v_48_010 - Trained for 10 epochs
  • original_v_48_020 - Trained for 20 epochs (recommended)
  • original_v_48_030 - Trained for 30 epochs

Soluble Models

  • soluble_v_48_002 - Trained for 2 epochs on soluble proteins
  • soluble_v_48_010 - Trained for 10 epochs on soluble proteins
  • soluble_v_48_020 - Trained for 20 epochs on soluble proteins (recommended)
  • soluble_v_48_030 - Trained for 30 epochs on soluble proteins

Installation

pip install jax equinox huggingface_hub

Usage

Basic Usage

import jax
import jax.numpy as jnp
import equinox as eqx
from huggingface_hub import hf_hub_download

# Download model from HuggingFace
model_path = hf_hub_download(
    repo_id="maraxen/prxteinmpnn",
    filename="eqx/original_v_48_020.eqx",
    repo_type="model",
)

# Create model structure (must match saved architecture)
from prxteinmpnn.eqx_new import PrxteinMPNN

key = jax.random.PRNGKey(0)
model = PrxteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_features=512,
    num_encoder_layers=3,
    num_decoder_layers=3,
    vocab_size=21,
    k_neighbors=48,
    key=key,
)

# Load weights
model = eqx.tree_deserialise_leaves(model_path, model)

# Use model for inference
# ... (see full documentation for inference examples)

Using the High-Level API

from prxteinmpnn.io.weights import load_model

# Automatically downloads and loads the model
model = load_model(
    model_version="v_48_020",
    model_weights="original"
)

Model Architecture

Hyperparameters:

  • Node features: 128
  • Edge features: 128
  • Hidden features: 512
  • Encoder layers: 3
  • Decoder layers: 3
  • K-nearest neighbors: 48
  • Vocabulary size: 21 (20 amino acids + 1 unknown)
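
The 21-token vocabulary can be sketched as a simple index mapping. The token ordering below is the canonical ProteinMPNN alphabet (20 standard amino acids plus "X" for unknown) and should be treated as illustrative; confirm the exact ordering against the package source before relying on it.

```python
# Assumed ordering: the canonical ProteinMPNN alphabet, 20 standard
# amino acids followed by "X" for unknown residues (illustrative only).
ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
aa_to_idx = {aa: i for i, aa in enumerate(ALPHABET)}

print(len(ALPHABET))       # 21, matching vocab_size above
print(aa_to_idx["A"])      # 0
print(aa_to_idx["X"])      # 20 (unknown token)
```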

Architecture:

  • Message-passing encoder for structural features
  • Autoregressive decoder for sequence generation
  • Attention-based edge updates
  • LayerNorm and residual connections
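
The encoder passes messages over a k-nearest-neighbor graph built from backbone coordinates (k_neighbors=48 above). A minimal NumPy sketch of that neighbor-graph construction, using toy coordinates and k=2 for readability (the actual model computes richer distance features from all backbone atoms):

```python
import numpy as np

def knn_edges(coords, k):
    """Return the indices of the k nearest neighbors (excluding self)
    for each residue, given (N, 3) coordinates."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)      # (N, N) pairwise distances
    np.fill_diagonal(dist, np.inf)            # a residue is not its own neighbor
    return np.argsort(dist, axis=-1)[:, :k]   # (N, k) neighbor indices

# Four toy "residues" along the x-axis; k=2 for illustration (model uses 48).
coords = np.array([[0.0, 0, 0], [1.0, 0, 0], [5.0, 0, 0], [6.0, 0, 0]])
print(knn_edges(coords, 2))  # [[1 2] [0 2] [3 1] [2 1]]
```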

Training Data

The models were trained on protein structures from the Protein Data Bank (PDB):

  • Original models: Standard PDB training set
  • Soluble models: Filtered for soluble, well-expressed proteins

Performance

These models achieve state-of-the-art performance on:

  • Native sequence recovery
  • Structural compatibility (predicted structure vs. designed sequence)
  • Expressibility and stability (for soluble models)

Citation

If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:

@article{dauparas2022robust,
  title={Robust deep learning--based protein sequence design using ProteinMPNN},
  author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
  journal={Science},
  volume={378},
  number={6615},
  pages={49--56},
  year={2022},
  publisher={American Association for the Advancement of Science}
}

License

MIT License - See LICENSE file for details.

Technical Details

File Format

Models are saved using Equinox's tree_serialise_leaves format (.eqx files), which:

  • Preserves PyTree structure
  • Ensures bit-perfect reproducibility
  • Is compatible with JAX's functional programming paradigm
  • Supports efficient serialization/deserialization
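
This is why the usage example above constructs a `PrxteinMPNN(...)` template before calling `tree_deserialise_leaves`: only the leaves are stored on disk, and the PyTree structure comes from the template. A toy illustration of the underlying flatten/unflatten round trip, using a plain dict in place of a real model:

```python
import jax

# Toy PyTree standing in for a model. Serialization stores only the
# leaves; restoring requires an identically structured template.
params = {"encoder": {"w": 1.0, "b": 0.0}, "decoder": {"w": 2.0}}

leaves, treedef = jax.tree_util.tree_flatten(params)   # leaves + structure
restored = jax.tree_util.tree_unflatten(treedef, leaves)

print(restored == params)  # True: structure and values round-trip exactly
```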

Computational Requirements

  • Memory: ~30 MB per model
  • Inference: CPU-compatible, GPU-accelerated
  • Batch processing: Supported via jax.vmap
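
Batching with `jax.vmap` maps a single-example function over a leading batch axis. A minimal sketch with a hypothetical per-structure scoring function standing in for a real model call:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function (stand-in for a model forward pass).
def score(x):
    return jnp.sum(x ** 2)

batch = jnp.arange(6.0).reshape(3, 2)   # 3 "structures", 2 features each
scores = jax.vmap(score)(batch)          # one score per batch element

print(scores.shape)  # (3,)
```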

Updates

Latest (v2.0):

  • Migrated to unified Equinox architecture
  • All models now in .eqx format
  • Improved modularity and type safety
  • Full JAX compatibility with JIT, vmap, and grad

For more information, examples, and tutorials, visit the GitHub repository.
