metadata
license: apache-2.0
tags:
  - image-classification
  - pytorch
  - timm
  - mlp-mixer
  - vision-transformer
  - transformer
  - gravitational-lensing
  - strong-lensing
  - astronomy
  - astrophysics
datasets:
  - parlange/gravit-c21-j24
metrics:
  - accuracy
  - auc
  - f1
paper:
  - title: 'GraViT: A Gravitational Lens Discovery Toolkit with Vision Transformers'
    url: https://arxiv.org/abs/2509.00226
    authors: Parlange et al.
model-index:
  - name: MLP-Mixer-c2
    results:
      - task:
          type: image-classification
          name: Strong Gravitational Lens Discovery
        dataset:
          type: common-test-sample
          name: Common Test Sample (More et al. 2024)
        metrics:
          - type: accuracy
            value: 0.7143
            name: Average Accuracy
          - type: auc
            value: 0.868
            name: Average AUC-ROC
          - type: f1
            value: 0.5146
            name: Average F1-Score

🌌 mlp-mixer-gravit-c2

🔭 This model is part of GraViT: Transfer Learning with Vision Transformers and MLP-Mixer for Strong Gravitational Lens Discovery.

🔗 GitHub Repository: https://github.com/parlange/gravit

πŸ›°οΈ Model Details

  • πŸ€– Model Type: MLP-Mixer
  • πŸ§ͺ Experiment: C2 - C21+J24-half
  • 🌌 Dataset: C21+J24
  • πŸͺ Fine-tuning Strategy: half

💻 Quick Start

```python
import torch
import timm

# Load the model directly from the Hugging Face Hub
model = timm.create_model(
    'hf-hub:parlange/mlp-mixer-gravit-c2',
    pretrained=True
)
model.eval()  # inference mode (disables stochastic depth)

# Example inference on a random 224x224 RGB input
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(dummy_input)                 # logits, one per class
    predictions = torch.softmax(output, dim=1)  # class probabilities
print(f"Lens probability: {predictions[0, 1].item():.4f}")  # class index 1 = lens
```
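For more than one cutout, the same softmax-then-index-1 step can be applied in mini-batches. The helper below is a hypothetical convenience function, not part of timm or GraViT; it assumes, as the Quick Start does, that the model returns one logit per class with the lens class at index 1.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def lens_probabilities(model: nn.Module, images: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
    """Return P(lens) for every image in an (N, 3, H, W) tensor, processed in mini-batches."""
    model.eval()  # disable stochastic depth / dropout for inference
    probs = []
    for start in range(0, images.shape[0], batch_size):
        logits = model(images[start:start + batch_size])  # (B, num_classes)
        probs.append(torch.softmax(logits, dim=1)[:, 1])  # keep the lens-class column
    return torch.cat(probs)
```

Usage: `p = lens_probabilities(model, batch)`, with `model` from the Quick Start and `batch` a preprocessed tensor of shape (N, 3, 224, 224). For real cutouts, the `resolve_data_config` and `create_transform` helpers in `timm.data` can build preprocessing from the model's pretrained config; whether that matches the GraViT training pipeline is an assumption worth checking in the repository.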

⚡️ Training Configuration

Training Dataset: C21+J24 (Cañameras et al. 2021 + Jaelani et al. 2024)
Fine-tuning Strategy: half

πŸ”§ Parameter πŸ“ Value
Batch Size 192
Learning Rate AdamW with ReduceLROnPlateau
Epochs 100
Patience 10
Optimizer AdamW
Scheduler ReduceLROnPlateau
Image Size 224x224
Fine Tune Mode half
Stochastic Depth Probability 0.1
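The optimizer, scheduler, epoch budget, and patience in the table can be wired together roughly as follows. This is a sketch under stated assumptions, not the GraViT training script: the initial learning rate (`lr=1e-4`) is a placeholder because the card does not give one, the scheduler uses PyTorch's default `ReduceLROnPlateau` settings, and "Patience 10" is assumed to mean early stopping on validation loss. `fit`, `train_one_epoch`, and `validation_loss` are hypothetical names; the two callables are supplied by the caller (for example, loops over a `DataLoader` with `batch_size=192`).

```python
import torch
import torch.nn as nn


def fit(model: nn.Module, train_one_epoch, validation_loss,
        max_epochs: int = 100, patience: int = 10, lr: float = 1e-4):
    """AdamW + ReduceLROnPlateau with early stopping on validation loss (sketch).

    `train_one_epoch(model, optimizer)` and `validation_loss(model) -> float`
    are hypothetical callables provided by the caller; `lr` is a placeholder.
    """
    params = [p for p in model.parameters() if p.requires_grad]  # respects frozen layers
    optimizer = torch.optim.AdamW(params, lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

    best_loss, stale_epochs, epochs_run = float("inf"), 0, 0
    best_state = None
    for _ in range(max_epochs):
        train_one_epoch(model, optimizer)
        val_loss = validation_loss(model)
        scheduler.step(val_loss)  # lowers the LR when val_loss stops improving
        epochs_run += 1
        if val_loss < best_loss:
            best_loss, stale_epochs = val_loss, 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            stale_epochs += 1
            if stale_epochs >= patience:  # early stopping
                break
    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best checkpoint
    return best_loss, epochs_run
```

In timm, stochastic depth is usually set when the model is created, e.g. `timm.create_model(..., drop_path_rate=0.1)`.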

📈 Training Curves

[Figure: Combined Training Metrics]

🏁 Final Epoch Training Metrics

| Metric | Training | Validation |
|---|---|---|
| 📉 Loss | 0.0650 | 0.0401 |
| 🎯 Accuracy | 0.9743 | 0.9854 |
| 📊 AUC-ROC | 0.9973 | 0.9991 |
| ⚖️ F1 Score | 0.9742 | 0.9855 |

β˜‘οΈ Evaluation Results

ROC Curves and Confusion Matrices

Performance across all test datasets (a through l) in the Common Test Sample (More et al. 2024):

ROC + Confusion Matrix - Dataset A ROC + Confusion Matrix - Dataset B ROC + Confusion Matrix - Dataset C ROC + Confusion Matrix - Dataset D ROC + Confusion Matrix - Dataset E ROC + Confusion Matrix - Dataset F ROC + Confusion Matrix - Dataset G ROC + Confusion Matrix - Dataset H ROC + Confusion Matrix - Dataset I ROC + Confusion Matrix - Dataset J ROC + Confusion Matrix - Dataset K ROC + Confusion Matrix - Dataset L
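Each confusion matrix comes from thresholding the lens probability. The sketch below assumes a threshold of 0.5 (the card does not state the one used) and shows how accuracy and F1 follow from the four counts. F1 depends only on TP, FP, and FN, so unlike accuracy it is unaffected by the number of correctly rejected non-lenses. Function names are illustrative.

```python
def confusion_counts(y_true, y_score, threshold=0.5):
    """Return (tp, fp, tn, fn), predicting "lens" when score >= threshold."""
    tp = fp = tn = fn = 0
    for label, score in zip(y_true, y_score):
        predicted_lens = score >= threshold
        if predicted_lens and label == 1:
            tp += 1
        elif predicted_lens:
            fp += 1
        elif label == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, tn, fn


def accuracy_and_f1(y_true, y_score, threshold=0.5):
    """Accuracy over all samples and F1 for the lens class."""
    tp, fp, tn, fn = confusion_counts(y_true, y_score, threshold)
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return accuracy, f1


print(accuracy_and_f1([1, 1, 0, 0, 0], [0.9, 0.3, 0.6, 0.2, 0.1]))  # (0.6, 0.5)
```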

📋 Performance Summary

Average performance across the 12 test datasets of the Common Test Sample (More et al. 2024):

| Metric | Value |
|---|---|
| 🎯 Average Accuracy | 0.7143 |
| 📈 Average AUC-ROC | 0.8680 |
| ⚖️ Average F1-Score | 0.5146 |

📘 Citation

If you use this model in your research, please cite:

```bibtex
@misc{parlange2025gravit,
      title={GraViT: Transfer Learning with Vision Transformers and MLP-Mixer for Strong Gravitational Lens Discovery},
      author={René Parlange and Juan C. Cuevas-Tello and Octavio Valenzuela and Omar de J. Cabrera-Rosas and Tomás Verdugo and Anupreeta More and Anton T. Jaelani},
      year={2025},
      eprint={2509.00226},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2509.00226},
}
```

Model Card Contact

For questions about this model, please contact the author via GitHub: https://github.com/parlange/