WGAN-GP model trained on the MNIST dataset using JAX in Colab.
Training Progression
Details
This model is based on WGAN-GP (Wasserstein GAN with Gradient Penalty, Gulrajani et al., 2017).
The model was trained for ~9h40m on a GCE VM instance (n1-standard-4, 1 x NVIDIA T4).
The Critic consists of 4 Convolutional Layers that use strides for downsampling, with Leaky ReLU activations. It uses neither Batch Normalization nor Dropout.
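Below is a minimal pure-JAX sketch of a critic that matches this description. Several details are assumptions for illustration and are not the notebook's actual configuration: the kernel sizes, the channel widths, the Leaky ReLU slope of 0.2, and the final 4x4 VALID convolution that reduces the feature map to one score. The names `init_critic` and `critic_apply` are hypothetical.

```python
import jax
import jax.numpy as jnp

# HYPOTHETICAL layer config: (kernel, stride, padding, out_channels).
CRITIC_LAYERS = [
    (4, 2, "SAME", 64),   # 28x28 -> 14x14
    (4, 2, "SAME", 128),  # 14x14 -> 7x7
    (4, 2, "SAME", 256),  # 7x7   -> 4x4
    (4, 1, "VALID", 1),   # 4x4   -> 1x1, one scalar score per sample
]

def init_critic(key, in_channels=1):
    """Return a list of (kernel, bias) pairs; kernels use HWIO layout."""
    params = []
    for k, _, _, out_ch in CRITIC_LAYERS:
        key, sub = jax.random.split(key)
        w = 0.02 * jax.random.normal(sub, (k, k, in_channels, out_ch))
        params.append((w, jnp.zeros((out_ch,))))
        in_channels = out_ch
    return params

def critic_apply(params, x):
    """x: (batch, 28, 28, 1) NHWC images -> (batch,) unbounded critic scores."""
    h = x
    for i, ((w, b), (_, stride, padding, _)) in enumerate(zip(params, CRITIC_LAYERS)):
        h = jax.lax.conv_general_dilated(
            h, w, window_strides=(stride, stride), padding=padding,
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        ) + b
        if i < len(params) - 1:
            # Leaky ReLU on hidden layers; no normalization and no dropout.
            h = jax.nn.leaky_relu(h, negative_slope=0.2)
    return h.reshape(h.shape[0])  # (batch, 1, 1, 1) -> (batch,)
```

No layer mixes information across the batch, so each score depends only on its own input sample. This independence is what keeps the per-sample gradient penalty well defined, and it is the reason WGAN-GP critics avoid Batch Normalization.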
The Generator consists of 4 Transposed Convolutional Layers with ReLU activation and Batch Normalization.
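The generator can be sketched the same way with `jax.lax.conv_transpose`. The card only states that there are 4 Transposed Convolutional Layers with ReLU and Batch Normalization. Everything else below is an assumption: the latent size of 100, the kernel sizes and channel widths, the training-mode batch normalization (batch statistics only, with no running averages), and the final `tanh` output. The names `init_generator` and `generator_apply` are hypothetical.

```python
import jax
import jax.numpy as jnp

LATENT_DIM = 100  # HYPOTHETICAL latent size
# HYPOTHETICAL layer config: (kernel, stride, padding, out_channels).
GEN_LAYERS = [
    (7, 1, "VALID", 256),  # 1x1   -> 7x7
    (4, 2, "SAME", 128),   # 7x7   -> 14x14
    (4, 2, "SAME", 64),    # 14x14 -> 28x28
    (3, 1, "SAME", 1),     # 28x28 -> 28x28, single image channel
]

def init_generator(key, in_channels=LATENT_DIM):
    params = []
    for k, _, _, out_ch in GEN_LAYERS:
        key, sub = jax.random.split(key)
        w = 0.02 * jax.random.normal(sub, (k, k, in_channels, out_ch))
        # gamma/beta are unused on the final layer, which has no batch norm.
        params.append({"w": w, "b": jnp.zeros((out_ch,)),
                       "gamma": jnp.ones((out_ch,)), "beta": jnp.zeros((out_ch,))})
        in_channels = out_ch
    return params

def batch_norm(h, gamma, beta, eps=1e-5):
    # Training-mode batch norm: normalize each channel over (batch, H, W).
    mean = jnp.mean(h, axis=(0, 1, 2), keepdims=True)
    var = jnp.var(h, axis=(0, 1, 2), keepdims=True)
    return gamma * (h - mean) / jnp.sqrt(var + eps) + beta

def generator_apply(params, z):
    """z: (batch, LATENT_DIM) -> images of shape (batch, 28, 28, 1) in [-1, 1]."""
    h = z.reshape(z.shape[0], 1, 1, -1)
    for i, (p, (_, stride, padding, _)) in enumerate(zip(params, GEN_LAYERS)):
        h = jax.lax.conv_transpose(
            h, p["w"], strides=(stride, stride), padding=padding,
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        ) + p["b"]
        if i < len(params) - 1:
            h = jax.nn.relu(batch_norm(h, p["gamma"], p["beta"]))
        else:
            h = jnp.tanh(h)  # HYPOTHETICAL output activation
    return h
```

Batch Normalization is acceptable here, unlike in the critic, because the gradient penalty is computed with respect to the critic's inputs only.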
The learning rate was kept constant at 1e-4 for the first 50,000 steps and then followed cosine annealing cycles with a peak LR of 1e-3.
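Below is a sketch of this schedule written as a plain function of the step count. The card gives the constant phase (1e-4 for 50,000 steps) and the peak (1e-3). It does not give the cycle length or the minimum LR of each cycle, so `CYCLE_STEPS = 10_000` and `MIN_LR = 1e-4` below are hypothetical values.

```python
import jax.numpy as jnp

CONSTANT_STEPS = 50_000  # from the card: constant LR for the first 50k steps
BASE_LR = 1e-4           # from the card: LR during the constant phase
PEAK_LR = 1e-3           # from the card: peak LR of each cosine cycle
CYCLE_STEPS = 10_000     # HYPOTHETICAL cycle length (not stated in the card)
MIN_LR = 1e-4            # HYPOTHETICAL floor of each cosine cycle

def learning_rate(step):
    step = jnp.asarray(step)
    # Fraction of the way through the current cycle (0.0 at every restart).
    t = jnp.mod(jnp.maximum(step - CONSTANT_STEPS, 0), CYCLE_STEPS) / CYCLE_STEPS
    # Half-cosine decay from PEAK_LR toward MIN_LR within each cycle.
    cosine = MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1.0 + jnp.cos(jnp.pi * t))
    return jnp.where(step < CONSTANT_STEPS, BASE_LR, cosine)
```

Each cycle restarts at the peak and decays toward `MIN_LR`, in the style of warm restarts. Optax's `join_schedules` can assemble an equivalent schedule from library pieces if preferred.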
The gradient penalty coefficient λ was set to 10, the same value used in the original paper.
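The gradient penalty with λ = 10 can be sketched as follows, using the WGAN-GP paper's formulation. The method interpolates between real and generated samples with a per-sample ε ~ U[0, 1] and penalizes (‖∇D(x̂)‖₂ − 1)². The interface `critic_fn(params, x)`, which returns one score per sample, is an assumption. This is an illustration, not the notebook's code.

```python
import jax
import jax.numpy as jnp

LAMBDA_GP = 10.0  # from the card: gradient penalty coefficient

def gradient_penalty(critic_fn, params, real, fake, key, lam=LAMBDA_GP):
    batch = real.shape[0]
    # One interpolation coefficient per sample, broadcast over image dims.
    eps = jax.random.uniform(key, (batch,) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # Gradient of the summed scores = per-sample input gradients, because
    # the critic has no cross-sample layers (no Batch Normalization).
    grads = jax.grad(lambda x: jnp.sum(critic_fn(params, x)))(x_hat)
    norms = jnp.sqrt(jnp.sum(grads.reshape(batch, -1) ** 2, axis=1) + 1e-12)
    return lam * jnp.mean((norms - 1.0) ** 2)

def critic_loss(critic_fn, params, real, fake, key):
    # The critic minimizes D(fake) - D(real) plus the gradient penalty.
    wasserstein = jnp.mean(critic_fn(params, fake)) - jnp.mean(critic_fn(params, real))
    return wasserstein + gradient_penalty(critic_fn, params, real, fake, key)
```

The small constant inside the square root avoids an undefined gradient of the norm at zero when the penalty itself is differentiated during training.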
For more details, please refer to the Colab Notebook.