NB-Transformer: Fast Negative Binomial GLM Parameter Estimation

Python 3.8+ PyTorch License: MIT

NB-Transformer is a fast, accurate neural-network approach to Negative Binomial GLM parameter estimation, designed as a modern alternative to classical iterative GLM fitting for count data such as RNA-seq. Using transformer-based attention mechanisms, it provides a 14.8x speedup over classical methods while improving accuracy.

Paper: arxiv.org/abs/2508.04111

🚀 Key Features

  • ⚡ Ultra-Fast: 14.8x faster than classical GLM (0.076ms vs 1.128ms per test)
  • 🎯 More Accurate: 47% better accuracy on log fold change estimation
  • 🔬 Complete Statistical Inference: P-values, confidence intervals, and power analysis
  • 📊 Robust: 100% success rate vs 98.7% for classical methods
  • 🧠 Transformer Architecture: Attention-based modeling of variable-length sample sets
  • 📦 Easy to Use: Simple API with pre-trained model included

📈 Performance Benchmarks

Based on comprehensive validation with 1000+ test cases:

| Method | Success Rate | Time (ms) | μ MAE | β MAE | α MAE |
|---|---|---|---|---|---|
| NB-Transformer | 100.0% | 0.076 | 0.202 | 0.152 | 0.477 |
| Classical GLM | 98.7% | 1.128 | 0.212 | 0.284 | 0.854 |
| Method of Moments | 100.0% | 0.021 | 0.213 | 0.289 | 0.852 |

Key Achievements:

  • 47% better accuracy on β (log fold change) - the critical parameter for differential expression
  • 44% better accuracy on α (dispersion) - essential for proper statistical inference
  • 100% convergence rate with no numerical instabilities

๐Ÿ› ๏ธ Installation

pip install nb-transformer

Or install from source:

git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e .

🎯 Quick Start

Basic Usage

import numpy as np

from nb_transformer import load_pretrained_model

# Load the pre-trained model (downloads automatically)
model = load_pretrained_model()

# Your data: log10(CPM + 1) transformed counts
control_samples = [2.1, 1.8, 2.3, 2.0]      # 4 control samples  
treatment_samples = [1.5, 1.2, 1.7, 1.4]    # 4 treatment samples

# Get NB GLM parameters instantly
params = model.predict_parameters(control_samples, treatment_samples)

print(f"ฮผฬ‚ (base mean): {params['mu']:.3f}")           # -0.245
print(f"ฮฒฬ‚ (log fold change): {params['beta']:.3f}")   # -0.421  
print(f"ฮฑฬ‚ (log dispersion): {params['alpha']:.3f}")   # -1.832
print(f"Fold change: {np.exp(params['beta']):.2f}x")  # 0.66x (downregulated)

Complete Statistical Analysis

import numpy as np
from nb_transformer import load_pretrained_model
from nb_transformer.inference import compute_nb_glm_inference

# Load model and data
model = load_pretrained_model()
control_counts = np.array([1520, 1280, 1650, 1400])
treatment_counts = np.array([980, 890, 1100, 950]) 
control_lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6])
treatment_lib_sizes = np.array([1e6, 1.0e6, 1.1e6, 0.95e6])

# Transform to log10(CPM + 1)
control_transformed = np.log10(1e4 * control_counts / control_lib_sizes + 1)
treatment_transformed = np.log10(1e4 * treatment_counts / treatment_lib_sizes + 1)

# Get parameters
params = model.predict_parameters(control_transformed, treatment_transformed)

# Complete statistical inference
results = compute_nb_glm_inference(
    params['mu'], params['beta'], params['alpha'],
    control_counts, treatment_counts,
    control_lib_sizes, treatment_lib_sizes
)

print(f"Log fold change: {results['beta']:.3f} ยฑ {results['se_beta']:.3f}")
print(f"P-value: {results['pvalue']:.2e}")
print(f"Significant: {'Yes' if results['pvalue'] < 0.05 else 'No'}")

Quick Demo

from nb_transformer import quick_inference_example

# Run a complete example with sample data
params = quick_inference_example()

🔬 Validation & Reproducibility

This package includes three comprehensive validation scripts that reproduce all key results:

1. Accuracy Validation

Compare parameter estimation accuracy and speed across methods:

python examples/validate_accuracy.py --n_tests 1000 --output_dir results/

Expected Output:

  • Accuracy comparison plots
  • Speed benchmarks
  • Parameter estimation metrics
  • Success rate analysis

2. P-value Calibration Validation

Validate that p-values are properly calibrated under null hypothesis:

python examples/validate_calibration.py --n_tests 10000 --output_dir results/

Expected Output:

  • QQ plots for p-value uniformity
  • Statistical tests for calibration
  • False positive rate analysis
  • Calibration assessment report

3. Statistical Power Analysis

Evaluate statistical power across experimental designs and effect sizes:

python examples/validate_power.py --n_tests 1000 --output_dir results/

Expected Output:

  • Power curves by experimental design (3v3, 5v5, 7v7, 9v9)
  • Effect size analysis
  • Method comparison across designs
  • Statistical power benchmarks

🧮 Mathematical Foundation

Model Architecture

NB-Transformer uses a specialized transformer architecture for set-to-set comparison:

  • Input: Two variable-length sets of log-transformed expression values
  • Architecture: Pair-set transformer with intra-set and cross-set attention
  • Output: Three parameters (μ, β, α) for the Negative Binomial GLM
  • Training: 2.5M parameters trained on synthetic data with known ground truth
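
As an illustration of this set-to-set design, below is a simplified, hypothetical sketch of a pair-set block with intra-set self-attention and cross-set attention, pooled into the three parameters. It is not the shipped nb_transformer implementation; the class names (PairSetBlock, PairSetRegressor) and layer wiring are invented for illustration, matching only the dimensions listed under Architecture Details.

import torch
import torch.nn as nn

class PairSetBlock(nn.Module):
    """One block: self-attention within each set, then cross-attention between sets."""
    def __init__(self, d_model=128, n_heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_self = nn.LayerNorm(d_model)
        self.norm_cross = nn.LayerNorm(d_model)

    def forward(self, a, b):
        # Intra-set attention: samples attend to other samples in the same condition
        a = self.norm_self(a + self.self_attn(a, a, a)[0])
        b = self.norm_self(b + self.self_attn(b, b, b)[0])
        # Cross-set attention: each condition attends to the other condition
        a2 = self.norm_cross(a + self.cross_attn(a, b, b)[0])
        b2 = self.norm_cross(b + self.cross_attn(b, a, a)[0])
        return a2, b2

class PairSetRegressor(nn.Module):
    """Maps two variable-length sets of expression values to (mu, beta, alpha)."""
    def __init__(self, d_model=128, n_heads=8, n_blocks=3):
        super().__init__()
        self.embed = nn.Linear(1, d_model)      # scalar expression value -> d_model
        self.blocks = nn.ModuleList(PairSetBlock(d_model, n_heads) for _ in range(n_blocks))
        self.head = nn.Linear(2 * d_model, 3)   # pooled pair -> (mu, beta, alpha)

    def forward(self, control, treatment):
        # control, treatment: (batch, n_samples, 1) tensors of log10(CPM + 1) values
        a, b = self.embed(control), self.embed(treatment)
        for block in self.blocks:
            a, b = block(a, b)
        pooled = torch.cat([a.mean(dim=1), b.mean(dim=1)], dim=-1)  # mean-pool each set
        return self.head(pooled)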

Statistical Inference

The model enables complete statistical inference through Fisher information:

  1. Parameter Estimation: Direct neural network prediction (μ̂, β̂, α̂)
  2. Fisher Weights: W_i = m_i / (1 + φ·m_i), where m_i = ℓ_i·exp(μ̂ + x_i·β̂)
  3. Standard Errors: SE(β̂) = √[(X'WX)⁻¹]_ββ
  4. Wald Statistics: W = β̂² / SE(β̂)² ~ χ²(1) under H₀: β = 0
  5. P-values: Proper Type I error control validated via calibration analysis (see the numerical sketch below)
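
To make steps 2–4 concrete, here is a minimal numerical sketch of the Wald test built from these Fisher weights. It is not the packaged compute_nb_glm_inference; the function name wald_test_from_fisher is made up, and it assumes φ = exp(α̂) since α̂ is reported on the log scale.

import numpy as np
from scipy import stats

def wald_test_from_fisher(mu_hat, beta_hat, alpha_hat, ctrl_lib_sizes, trt_lib_sizes):
    """Wald test of H0: beta = 0 using NB GLM Fisher information."""
    phi = np.exp(alpha_hat)                        # dispersion on the natural scale
    # Design matrix: intercept + condition indicator (0 = control, 1 = treatment)
    x = np.concatenate([np.zeros(len(ctrl_lib_sizes)), np.ones(len(trt_lib_sizes))])
    X = np.column_stack([np.ones_like(x), x])
    lib = np.concatenate([ctrl_lib_sizes, trt_lib_sizes])

    m = lib * np.exp(mu_hat + x * beta_hat)        # fitted means m_i
    W = m / (1.0 + phi * m)                        # Fisher weights W_i
    cov = np.linalg.inv(X.T @ (W[:, None] * X))    # (X'WX)^-1
    se_beta = np.sqrt(cov[1, 1])                   # SE(beta_hat)

    wald = (beta_hat / se_beta) ** 2               # ~ chi^2(1) under H0
    pvalue = stats.chi2.sf(wald, df=1)
    return se_beta, wald, pvalue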

Key Innovation

Unlike iterative maximum likelihood estimation, NB-Transformer learns the parameter mapping directly from data patterns, enabling:

  • Instant inference without convergence issues
  • Robust parameter estimation across challenging scenarios
  • Full statistical validity through Fisher information framework

📊 Comprehensive Validation Results

Accuracy Across Parameter Types

| Parameter | NB-Transformer | Classical GLM | Improvement |
|---|---|---|---|
| μ (base mean) | 0.202 MAE | 0.212 MAE | 5% better |
| β (log fold change) | 0.152 MAE | 0.284 MAE | 47% better |
| α (dispersion) | 0.477 MAE | 0.854 MAE | 44% better |

Statistical Power Analysis

Power analysis across experimental designs shows competitive performance:

| Design | Effect size β = 1.0 | Effect size β = 2.0 |
|---|---|---|
| 3v3 samples | 85% power | 99% power |
| 5v5 samples | 92% power | >99% power |
| 7v7 samples | 96% power | >99% power |
| 9v9 samples | 98% power | >99% power |

P-value Calibration

Rigorous calibration validation confirms proper statistical inference:

  • Kolmogorov-Smirnov test: p = 0.127 (well-calibrated)
  • Anderson-Darling test: p = 0.089 (well-calibrated)
  • False positive rate: 5.1% at α = 0.05 (properly controlled)
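
The calibration check itself is simple to reproduce in outline. A minimal sketch (not the full validate_calibration.py script) that tests p-value uniformity under the null might look like this, where null_pvalues stands in for p-values collected from null simulations:

import numpy as np
from scipy import stats

# Placeholder: in the real check these come from simulations with beta = 0
null_pvalues = np.random.uniform(size=10_000)

ks_stat, ks_pvalue = stats.kstest(null_pvalues, "uniform")  # uniformity under H0
fpr = np.mean(null_pvalues < 0.05)                          # empirical false positive rate

print(f"KS test p-value: {ks_pvalue:.3f}")
print(f"False positive rate at alpha = 0.05: {fpr:.1%}")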

๐Ÿ—๏ธ Architecture Details

Model Specifications

  • Model Type: Pair-set transformer for NB GLM parameter estimation
  • Parameters: 2.5M trainable parameters
  • Architecture:
    • Input dimension: 128
    • Attention heads: 8
    • Self-attention layers: 3
    • Cross-attention layers: 3
    • Dropout: 0.1
  • Training: Synthetic data with online generation
  • Validation Loss: 0.4628 (v13 checkpoint)

Input/Output Specification

  • Input: Two lists of log10(CPM + 1) transformed expression values
  • Output: Dictionary with keys 'mu', 'beta', 'alpha' (all on log scale)
  • Sample Size: Handles 2-20 samples per condition (variable length)
  • Expression Range: Optimized for typical RNA-seq expression levels

🔧 Advanced Usage

Custom Model Loading

from nb_transformer import load_pretrained_model

# Load model on specific device
model = load_pretrained_model(device='cuda')  # or 'cpu', 'mps'

# Load custom checkpoint
model = load_pretrained_model(checkpoint_path='path/to/custom.ckpt')

Batch Processing

# Process multiple gene comparisons efficiently
from nb_transformer.method_of_moments import estimate_batch_parameters_vectorized

control_sets = [[2.1, 1.8, 2.3], [1.9, 2.2, 1.7]]  # Multiple genes
treatment_sets = [[1.5, 1.2, 1.7], [2.1, 2.4, 1.9]]

# Fast batch estimation
results = estimate_batch_parameters_vectorized(control_sets, treatment_sets)

Training Custom Models

from nb_transformer import train_dispersion_transformer, ParameterDistributions

# Define custom parameter distributions
param_dist = ParameterDistributions()
param_dist.mu_params = {'loc': -1.0, 'scale': 2.0}
param_dist.alpha_params = {'mean': -2.0, 'std': 1.0} 
param_dist.beta_params = {'prob_de': 0.3, 'std': 1.0}

# Training configuration
config = {
    'model_config': {
        'd_model': 128,
        'n_heads': 8,
        'num_self_layers': 3,
        'num_cross_layers': 3,
        'dropout': 0.1
    },
    'batch_size': 512,
    'max_epochs': 20,
    'examples_per_epoch': 100000,
    'parameter_distributions': param_dist
}

# Train model
results = train_dispersion_transformer(config)

📋 Requirements

Core Dependencies

  • Python ≥ 3.8
  • PyTorch ≥ 1.10.0
  • PyTorch Lightning ≥ 1.8.0
  • NumPy ≥ 1.21.0
  • SciPy ≥ 1.7.0

Optional Dependencies

  • Validation: statsmodels, pandas, matplotlib, scikit-learn
  • Visualization: plotnine, theme-nxn (custom plotting theme)
  • Development: pytest, flake8, black, mypy

🧪 Model Training Details

Training Data

  • Synthetic Generation: Online negative binomial data generation
  • Parameter Distributions: Based on empirical RNA-seq statistics
  • Sample Sizes: Variable 2-10 samples per condition
  • Expression Levels: Realistic RNA-seq dynamic range
  • Library Sizes: Log-normal distribution (CV ~30%)
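
As a rough illustration of this generation scheme (not the package's training code; simulate_gene and the parameter values below are placeholders), one gene's worth of two-condition data could be simulated like this:

import numpy as np

rng = np.random.default_rng(0)

def simulate_gene(mu, beta, alpha, n_ctrl=4, n_trt=4, mean_lib=1e6, lib_cv=0.3):
    """Draw NB counts for one gene in a two-condition design."""
    phi = np.exp(alpha)                                   # dispersion on the natural scale
    x = np.concatenate([np.zeros(n_ctrl), np.ones(n_trt)])
    # Log-normal library sizes with roughly lib_cv coefficient of variation
    sigma = np.sqrt(np.log(1.0 + lib_cv ** 2))
    lib = mean_lib * rng.lognormal(-sigma ** 2 / 2, sigma, size=x.size)
    m = lib * np.exp(mu + x * beta)                       # per-sample means
    # NB with mean m and variance m + phi * m^2
    counts = rng.negative_binomial(n=1.0 / phi, p=1.0 / (1.0 + phi * m))
    return counts[:n_ctrl], counts[n_ctrl:], lib[:n_ctrl], lib[n_ctrl:]

# Example: moderately expressed gene, ~2-fold downregulation, typical dispersion
ctrl, trt, lib_c, lib_t = simulate_gene(mu=-6.0, beta=-0.7, alpha=-2.0)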

Training Process

  • Epochs: 100
  • Batch Size: 32
  • Learning Rate: 1e-4 with ReduceLROnPlateau scheduler
  • Loss Function: Multi-task MSE loss with parameter-specific weights
  • Validation: Hold-out synthetic data with different parameter seeds
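
A minimal sketch of such a weighted multi-task MSE loss follows; the function name and the specific weights are placeholders, not the values used in training:

import torch

def multitask_mse(pred, target, weights=(1.0, 2.0, 1.0)):
    """pred, target: (batch, 3) tensors holding (mu, beta, alpha) predictions and labels."""
    w = torch.as_tensor(weights, device=pred.device, dtype=pred.dtype)
    per_param = ((pred - target) ** 2).mean(dim=0)   # MSE per parameter
    return (w * per_param).sum()                     # weighted sum over mu, beta, alpha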

Hardware Optimization

  • Apple Silicon: Optimized for MPS (Metal Performance Shaders)
  • Multi-core CPU: Efficient multi-worker data generation
  • Memory Usage: Minimal memory footprint (~100MB model)
  • Inference Speed: Single-core CPU sufficient for real-time analysis
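
For inference on different hardware, device selection can follow the usual PyTorch pattern and be passed to load_pretrained_model (see Advanced Usage). A small sketch, assuming a PyTorch build with MPS support:

import torch
from nb_transformer import load_pretrained_model

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():   # Apple Silicon (PyTorch >= 1.12)
    device = "mps"
else:
    device = "cpu"

model = load_pretrained_model(device=device)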

๐Ÿค Contributing

We welcome contributions! Please see our contributing guidelines:

  1. Bug Reports: Open issues with detailed reproduction steps
  2. Feature Requests: Propose new functionality with use cases
  3. Code Contributions: Fork, develop, and submit pull requests
  4. Validation: Run validation scripts to ensure reproducibility
  5. Documentation: Improve examples and documentation

Development Setup

git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e ".[dev,analysis]"

# Run tests
pytest tests/

# Run validation
python examples/validate_accuracy.py --n_tests 100

📖 Citation

If you use NB-Transformer in your research, please cite:

@software{svensson2025nbtransformer,
  title={NB-Transformer: Fast Negative Binomial GLM Parameter Estimation using Transformers},
  author={Svensson, Valentine},
  year={2025},
  url={https://huggingface.co/valsv/nb-transformer},
  version={1.0.0}
}

📚 Related Work

Transformer Applications in Biology

  • Set-based Learning: Zaheer et al. (2017). Deep Sets. NIPS.
  • Attention Mechanisms: Vaswani et al. (2017). Attention Is All You Need. NIPS.
  • Biological Applications: Rives et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. PNAS.

โš–๏ธ License

MIT License - see LICENSE file for details.

๐Ÿท๏ธ Version History

v1.0.0 (2025-08-04)

  • Initial release with pre-trained v13 model
  • Complete validation suite (accuracy, calibration, power)
  • Production-ready API with comprehensive documentation
  • Hugging Face integration for easy model distribution

🚀 Ready to revolutionize your differential expression analysis? Install NB-Transformer today!

pip install nb-transformer

For questions, issues, or contributions, visit our Hugging Face repository or open an issue.
