NB-Transformer: Fast Negative Binomial GLM Parameter Estimation
NB-Transformer is a fast, accurate neural network approach to Negative Binomial GLM parameter estimation, designed as a modern replacement for classical GLM fitting in the statistical analysis of counts. Using transformer-based attention over sample sets, it provides a 14.8x speedup over classical methods while improving estimation accuracy.
Paper: arxiv.org/abs/2508.04111
Key Features
- Ultra-Fast: 14.8x faster than classical GLM (0.076 ms vs 1.128 ms per test)
- More Accurate: 47% better accuracy on log fold change estimation
- Complete Statistical Inference: P-values, confidence intervals, and power analysis
- Robust: 100% success rate vs 98.7% for classical methods
- Transformer Architecture: Attention-based modeling of variable-length sample sets
- Easy to Use: Simple API with pre-trained model included
Performance Benchmarks
Based on comprehensive validation with 1000+ test cases:
Method | Success Rate | Time (ms) | μ MAE | β MAE | α MAE |
---|---|---|---|---|---|
NB-Transformer | 100.0% | 0.076 | 0.202 | 0.152 | 0.477 |
Classical GLM | 98.7% | 1.128 | 0.212 | 0.284 | 0.854 |
Method of Moments | 100.0% | 0.021 | 0.213 | 0.289 | 0.852 |
Key Achievements:
- 47% better accuracy on β (log fold change) - the critical parameter for differential expression
- 44% better accuracy on α (dispersion) - essential for proper statistical inference
- 100% convergence rate with no numerical instabilities
Installation
pip install nb-transformer
Or install from source:
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e .
Quick Start
Basic Usage
import numpy as np
from nb_transformer import load_pretrained_model
# Load the pre-trained model (downloads automatically)
model = load_pretrained_model()
# Your data: log10(CPM + 1) transformed counts
control_samples = [2.1, 1.8, 2.3, 2.0] # 4 control samples
treatment_samples = [1.5, 1.2, 1.7, 1.4] # 4 treatment samples
# Get NB GLM parameters instantly
params = model.predict_parameters(control_samples, treatment_samples)
print(f"ฮผฬ (base mean): {params['mu']:.3f}") # -0.245
print(f"ฮฒฬ (log fold change): {params['beta']:.3f}") # -0.421
print(f"ฮฑฬ (log dispersion): {params['alpha']:.3f}") # -1.832
print(f"Fold change: {np.exp(params['beta']):.2f}x") # 0.66x (downregulated)
Complete Statistical Analysis
import numpy as np
from nb_transformer import load_pretrained_model
from nb_transformer.inference import compute_nb_glm_inference
# Load model and data
model = load_pretrained_model()
control_counts = np.array([1520, 1280, 1650, 1400])
treatment_counts = np.array([980, 890, 1100, 950])
control_lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6])
treatment_lib_sizes = np.array([1e6, 1.0e6, 1.1e6, 0.95e6])
# Transform to log10(CPM + 1)
control_transformed = np.log10(1e4 * control_counts / control_lib_sizes + 1)
treatment_transformed = np.log10(1e4 * treatment_counts / treatment_lib_sizes + 1)
# Get parameters
params = model.predict_parameters(control_transformed, treatment_transformed)
# Complete statistical inference
results = compute_nb_glm_inference(
    params['mu'], params['beta'], params['alpha'],
    control_counts, treatment_counts,
    control_lib_sizes, treatment_lib_sizes
)
print(f"Log fold change: {results['beta']:.3f} ยฑ {results['se_beta']:.3f}")
print(f"P-value: {results['pvalue']:.2e}")
print(f"Significant: {'Yes' if results['pvalue'] < 0.05 else 'No'}")
Quick Demo
from nb_transformer import quick_inference_example
# Run a complete example with sample data
params = quick_inference_example()
Validation & Reproducibility
This package includes three comprehensive validation scripts that reproduce all key results:
1. Accuracy Validation
Compare parameter estimation accuracy and speed across methods:
python examples/validate_accuracy.py --n_tests 1000 --output_dir results/
Expected Output:
- Accuracy comparison plots
- Speed benchmarks
- Parameter estimation metrics
- Success rate analysis
2. P-value Calibration Validation
Validate that p-values are properly calibrated under null hypothesis:
python examples/validate_calibration.py --n_tests 10000 --output_dir results/
Expected Output:
- QQ plots for p-value uniformity
- Statistical tests for calibration
- False positive rate analysis
- Calibration assessment report
3. Statistical Power Analysis
Evaluate statistical power across experimental designs and effect sizes:
python examples/validate_power.py --n_tests 1000 --output_dir results/
Expected Output:
- Power curves by experimental design (3v3, 5v5, 7v7, 9v9)
- Effect size analysis
- Method comparison across designs
- Statistical power benchmarks
Mathematical Foundation
Model Architecture
NB-Transformer uses a specialized transformer architecture for set-to-set comparison (a minimal sketch follows this list):
- Input: Two variable-length sets of log-transformed expression values
- Architecture: Pair-set transformer with intra-set and cross-set attention
- Output: Three parameters (μ, β, α) for the Negative Binomial GLM
- Training: 2.5M parameters trained on synthetic data with known ground truth
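The released training code defines the actual network, but the following minimal PyTorch sketch illustrates the pair-set idea: embed each sample, refine each set with intra-set self-attention, exchange information through cross-set attention, pool, and regress the three parameters. The dimensions mirror the specifications under Architecture Details below; the class and all names are illustrative, not the package's API.
import torch
import torch.nn as nn
class PairSetSketch(nn.Module):
    """Illustrative pair-set transformer: intra-set self-attention, cross-set attention, mean pooling, and a (mu, beta, alpha) head."""
    def __init__(self, d_model=128, n_heads=8, num_self_layers=3, num_cross_layers=3, dropout=0.1):
        super().__init__()
        self.embed = nn.Linear(1, d_model)  # one scalar expression value -> d_model features
        enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True)
        self.self_enc = nn.TransformerEncoder(enc_layer, num_layers=num_self_layers)
        self.cross = nn.ModuleList([nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) for _ in range(num_cross_layers)])
        self.head = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.GELU(), nn.Linear(d_model, 3))
    def forward(self, control, treatment):
        # control: (batch, n_ctrl), treatment: (batch, n_trt); set sizes may differ
        a = self.self_enc(self.embed(control.unsqueeze(-1)))    # intra-set attention over control samples
        b = self.self_enc(self.embed(treatment.unsqueeze(-1)))  # intra-set attention over treatment samples
        for attn in self.cross:                                 # cross-set attention in both directions
            a2, _ = attn(a, b, b)
            b2, _ = attn(b, a, a)
            a, b = a + a2, b + b2
        pooled = torch.cat([a.mean(dim=1), b.mean(dim=1)], dim=-1)  # permutation-invariant pooling
        mu, beta, alpha = self.head(pooled).unbind(dim=-1)
        return mu, beta, alpha
model = PairSetSketch()
mu, beta, alpha = model(torch.randn(1, 4), torch.randn(1, 4))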
Statistical Inference
The model enables complete statistical inference through Fisher information (a worked sketch follows this list):
- Parameter Estimation: Direct neural network prediction (μ̂, β̂, α̂)
- Fisher Weights: W_i = m_i / (1 + φ·m_i), where m_i = ℓ_i·exp(μ̂ + x_i·β̂) and φ = exp(α̂) is the dispersion
- Standard Errors: SE(β̂) = √[(X'WX)⁻¹]_ββ
- Wald Statistics: W = β̂² / SE(β̂)² ~ χ²(1) under H₀: β = 0
- P-values: Proper Type I error control validated via calibration analysis
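To make these formulas concrete, here is a self-contained sketch of the Fisher-information step for a two-group design. It is not the package's compute_nb_glm_inference implementation; it assumes α̂ is the log of the dispersion φ and that library sizes enter the fitted mean as multiplicative offsets, with μ̂ on the corresponding log scale.
import numpy as np
from scipy import stats
def wald_from_fisher(mu_hat, beta_hat, alpha_hat, lib_sizes, is_treatment):
    """Standard error and Wald p-value for beta via Fisher information (illustrative sketch of the formulas above)."""
    x = is_treatment.astype(float)                 # 0 = control, 1 = treatment
    X = np.column_stack([np.ones_like(x), x])      # design matrix with intercept and group term
    phi = np.exp(alpha_hat)                        # dispersion; alpha_hat is on the log scale
    m = lib_sizes * np.exp(mu_hat + x * beta_hat)  # fitted means m_i = l_i * exp(mu + x_i * beta)
    W = m / (1.0 + phi * m)                        # NB GLM working weights
    cov = np.linalg.inv(X.T @ (W[:, None] * X))    # (X'WX)^{-1}
    se_beta = np.sqrt(cov[1, 1])
    wald = (beta_hat / se_beta) ** 2               # ~ chi^2(1) under H0: beta = 0
    return se_beta, stats.chi2.sf(wald, df=1)
lib_sizes = np.full(8, 1e6)
is_treatment = np.array([0, 0, 0, 0, 1, 1, 1, 1])
se, p = wald_from_fisher(-6.5, -0.45, -1.8, lib_sizes, is_treatment)  # illustrative estimates
print(f"SE(beta) = {se:.3f}, p = {p:.3g}")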
Key Innovation
Unlike iterative maximum likelihood estimation, NB-Transformer learns the parameter mapping directly from data patterns, enabling:
- Instant inference without convergence issues
- Robust parameter estimation across challenging scenarios
- Full statistical validity through Fisher information framework
Comprehensive Validation Results
Accuracy Across Parameter Types
Parameter | NB-Transformer | Classical GLM | Improvement |
---|---|---|---|
μ (base mean) | 0.202 MAE | 0.212 MAE | 5% better |
β (log fold change) | 0.152 MAE | 0.284 MAE | 47% better |
α (dispersion) | 0.477 MAE | 0.854 MAE | 44% better |
Statistical Power Analysis
Power analysis across experimental designs shows competitive performance:
Design | Effect Size β=1.0 | Effect Size β=2.0 |
---|---|---|
3v3 samples | 85% power | 99% power |
5v5 samples | 92% power | >99% power |
7v7 samples | 96% power | >99% power |
9v9 samples | 98% power | >99% power |
P-value Calibration
Rigorous calibration validation confirms proper statistical inference (a minimal uniformity check is sketched after this list):
- Kolmogorov-Smirnov test: p = 0.127 (well-calibrated)
- Anderson-Darling test: p = 0.089 (well-calibrated)
- False positive rate: 5.1% at α = 0.05 (properly controlled)
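The calibration scripts above produce these statistics; conceptually, the check asks whether p-values generated under the null hypothesis are uniform on [0, 1]. A hedged, stand-alone sketch of that check (the placeholder p-values would be replaced by real null simulations):
import numpy as np
from scipy import stats
rng = np.random.default_rng(0)
null_pvalues = rng.uniform(size=10_000)               # placeholder; use p-values simulated under H0: beta = 0
ks_stat, ks_p = stats.kstest(null_pvalues, "uniform") # Kolmogorov-Smirnov test against Uniform(0, 1)
fpr = np.mean(null_pvalues < 0.05)                    # empirical false positive rate at alpha = 0.05
print(f"KS p-value: {ks_p:.3f} (large values are consistent with uniformity)")
print(f"False positive rate at alpha = 0.05: {fpr:.1%}")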
Architecture Details
Model Specifications
- Model Type: Pair-set transformer for NB GLM parameter estimation
- Parameters: 2.5M trainable parameters
- Architecture:
  - Input dimension: 128
  - Attention heads: 8
  - Self-attention layers: 3
  - Cross-attention layers: 3
  - Dropout: 0.1
- Training: Synthetic data with online generation
- Validation Loss: 0.4628 (v13 checkpoint)
Input/Output Specification
- Input: Two lists of log10(CPM + 1) transformed expression values
- Output: Dictionary with keys 'mu', 'beta', 'alpha' (all on log scale; see the example after this list)
- Sample Size: Handles 2-20 samples per condition (variable length)
- Expression Range: Optimized for typical RNA-seq expression levels
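As a small usage example of this interface, using the same predict_parameters call shown in the Quick Start, the two conditions do not need equal numbers of replicates:
from nb_transformer import load_pretrained_model
model = load_pretrained_model()
# Unequal group sizes are fine: 3 control vs 5 treatment samples,
# each value already on the log10(CPM + 1) scale.
control = [2.1, 1.8, 2.3]
treatment = [1.5, 1.2, 1.7, 1.4, 1.6]
params = model.predict_parameters(control, treatment)
print(params["mu"], params["beta"], params["alpha"])  # estimates on the log scale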
Advanced Usage
Custom Model Loading
from nb_transformer import load_pretrained_model
# Load model on specific device
model = load_pretrained_model(device='cuda') # or 'cpu', 'mps'
# Load custom checkpoint
model = load_pretrained_model(checkpoint_path='path/to/custom.ckpt')
Batch Processing
# Process multiple gene comparisons efficiently
from nb_transformer.method_of_moments import estimate_batch_parameters_vectorized
control_sets = [[2.1, 1.8, 2.3], [1.9, 2.2, 1.7]] # Multiple genes
treatment_sets = [[1.5, 1.2, 1.7], [2.1, 2.4, 1.9]]
# Fast batch estimation
results = estimate_batch_parameters_vectorized(control_sets, treatment_sets)
Training Custom Models
from nb_transformer import train_dispersion_transformer, ParameterDistributions
# Define custom parameter distributions
param_dist = ParameterDistributions()
param_dist.mu_params = {'loc': -1.0, 'scale': 2.0}
param_dist.alpha_params = {'mean': -2.0, 'std': 1.0}
param_dist.beta_params = {'prob_de': 0.3, 'std': 1.0}
# Training configuration
config = {
    'model_config': {
        'd_model': 128,
        'n_heads': 8,
        'num_self_layers': 3,
        'num_cross_layers': 3,
        'dropout': 0.1
    },
    'batch_size': 512,
    'max_epochs': 20,
    'examples_per_epoch': 100000,
    'parameter_distributions': param_dist
}
# Train model
results = train_dispersion_transformer(config)
Requirements
Core Dependencies
- Python ≥ 3.8
- PyTorch ≥ 1.10.0
- PyTorch Lightning ≥ 1.8.0
- NumPy ≥ 1.21.0
- SciPy ≥ 1.7.0
Optional Dependencies
- Validation: statsmodels, pandas, matplotlib, scikit-learn
- Visualization: plotnine, theme-nxn (custom plotting theme)
- Development: pytest, flake8, black, mypy
Model Training Details
Training Data
- Synthetic Generation: Online negative binomial data generation (a minimal generation sketch follows this list)
- Parameter Distributions: Based on empirical RNA-seq statistics
- Sample Sizes: Variable 2-10 samples per condition
- Expression Levels: Realistic RNA-seq dynamic range
- Library Sizes: Log-normal distribution (CV ~30%)
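The actual generator lives in the training code; the recipe above can be sketched as follows. The priors, the library-size scale, and the 30% differentially-expressed fraction are illustrative assumptions, not the released training settings:
import numpy as np
rng = np.random.default_rng(42)
def simulate_training_example():
    """Draw one synthetic (control, treatment) pair with known NB GLM parameters (illustrative)."""
    n_ctrl, n_trt = rng.integers(2, 11, size=2)                  # variable 2-10 samples per condition
    mu = rng.normal(-6.5, 1.0)                                   # base log mean relative to library size (assumed scale)
    beta = rng.normal(0.0, 1.0) if rng.random() < 0.3 else 0.0   # log fold change; ~30% of genes DE (assumed)
    alpha = rng.normal(-2.0, 1.0)                                # log dispersion
    phi = np.exp(alpha)
    lib = rng.lognormal(np.log(1e6), 0.3, size=n_ctrl + n_trt)   # library sizes, roughly 30% CV
    x = np.r_[np.zeros(n_ctrl), np.ones(n_trt)]                  # group indicator
    m = lib * np.exp(mu + x * beta)                              # expected counts per sample
    r = 1.0 / phi                                                # NB parameterisation with variance m + phi * m^2
    counts = rng.negative_binomial(r, r / (r + m))
    transformed = np.log10(1e4 * counts / lib + 1)               # same transform used in the inference examples
    return transformed[:n_ctrl], transformed[n_ctrl:], (mu, beta, alpha)
ctrl_set, trt_set, true_params = simulate_training_example()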
Training Process
- Epochs: 100 epochs
- Batch Size: 32
- Learning Rate: 1e-4 with ReduceLROnPlateau scheduler
- Loss Function: Multi-task MSE loss with parameter-specific weights (a minimal sketch follows this list)
- Validation: Hold-out synthetic data with different parameter seeds
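The loss itself is straightforward; below is a hedged sketch of a parameter-weighted multi-task MSE. The per-parameter weights shown are illustrative, not the values used for the released checkpoint:
import torch
def multitask_mse(pred, target, weights=(1.0, 2.0, 1.0)):
    """Weighted MSE over the three outputs (mu, beta, alpha); pred and target are (batch, 3) tensors."""
    w = torch.tensor(weights, dtype=pred.dtype, device=pred.device)
    per_param = ((pred - target) ** 2).mean(dim=0)  # MSE for each of mu, beta, alpha
    return (w * per_param).sum()
loss = multitask_mse(torch.randn(32, 3), torch.randn(32, 3))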
Hardware Optimization
- Apple Silicon: Optimized for MPS (Metal Performance Shaders)
- Multi-core CPU: Efficient multi-worker data generation
- Memory Usage: Minimal memory footprint (~100MB model)
- Inference Speed: Single-core CPU sufficient for real-time analysis
Contributing
We welcome contributions! Please see our contributing guidelines:
- Bug Reports: Open issues with detailed reproduction steps
- Feature Requests: Propose new functionality with use cases
- Code Contributions: Fork, develop, and submit pull requests
- Validation: Run validation scripts to ensure reproducibility
- Documentation: Improve examples and documentation
Development Setup
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e ".[dev,analysis]"
# Run tests
pytest tests/
# Run validation
python examples/validate_accuracy.py --n_tests 100
Citation
If you use NB-Transformer in your research, please cite:
@software{svensson2025nbtransformer,
title={NB-Transformer: Fast Negative Binomial GLM Parameter Estimation using Transformers},
author={Svensson, Valentine},
year={2025},
url={https://huggingface.co/valsv/nb-transformer},
version={1.0.0}
}
Related Work
Transformer Applications in Biology
- Set-based Learning: Zaheer et al. (2017). Deep Sets. NIPS.
- Attention Mechanisms: Vaswani et al. (2017). Attention Is All You Need. NIPS.
- Biological Applications: Rives et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. PNAS.
License
MIT License - see LICENSE file for details.
Version History
v1.0.0 (2025-08-04)
- Initial release with pre-trained v13 model
- Complete validation suite (accuracy, calibration, power)
- Production-ready API with comprehensive documentation
- Hugging Face integration for easy model distribution
Ready to revolutionize your differential expression analysis? Install NB-Transformer today!
pip install nb-transformer
For questions, issues, or contributions, visit our Hugging Face repository or open an issue.