GeneJEPA: A Predictive World Model of the Transcriptome
GeneJEPA is a Joint-Embedding Predictive Architecture (JEPA) trained for self-supervised representation learning on scRNA-seq.
It uses a Perceiver-style encoder to handle sparse, high-dimensional gene count vectors and learns from masked block prediction.
Why? To produce compact cell embeddings you can use for clustering, transfer learning, linear probes, and other downstream biological tasks.
Repository contents
This model repo intentionally contains artifacts only (no training code):
- `genejepa-epoch=49.ckpt` – final PyTorch Lightning checkpoint (student encoder + predictor + EMA state, etc.)
- `gene_metadata.parquet` – mapping between foundation token IDs and gene identifiers used to build the embedding vocab.
- `global_stats.json` – global `log1p(counts)` normalization stats (`mean`, `std`) computed over a large sample of training data (see the loading sketch below this list).
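As a quick illustration, here is a minimal sketch of how the normalization stats could be applied to raw counts once the file is downloaded. The scalar `mean` and `std` key names are assumptions; check `global_stats.json` for the actual field names.

```python
import json

import numpy as np

# Assumed key names ("mean", "std"); inspect global_stats.json for the real ones.
with open("global_stats.json") as f:
    stats = json.load(f)

def normalize_counts(counts: np.ndarray, stats: dict) -> np.ndarray:
    """Apply the global normalization described above: log1p, then z-score."""
    x = np.log1p(counts)
    return (x - stats["mean"]) / stats["std"]
```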
Model summary
- Backbone: Perceiver-style encoder over tokenized genes (identity + Fourier features of expression value)
- Latents: 512
- Dimensionality: 768
- Blocks: 24 transformer blocks on the latent array
- Heads: 12
- Masking: stochastic, block-wise targets with context complement
- Predictor: BYOL-style MLP head
- EMA teacher: maintained during training (for targets)
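For intuition about the masking scheme, below is a minimal sketch of stochastic block-wise target sampling with a context complement over a cell's gene tokens. The block count, block fraction, and function name are illustrative assumptions, not the repository's training code.

```python
import numpy as np

def sample_blocks(num_tokens, num_blocks=4, block_frac=0.15, rng=None):
    """Sample contiguous target blocks over a cell's gene tokens; the context
    is the complement of all target positions (illustrative sketch only)."""
    rng = rng or np.random.default_rng()
    block_len = max(1, int(block_frac * num_tokens))
    target_mask = np.zeros(num_tokens, dtype=bool)
    for _ in range(num_blocks):
        start = rng.integers(0, max(1, num_tokens - block_len + 1))
        target_mask[start:start + block_len] = True
    context_mask = ~target_mask  # context = complement of the target blocks
    return context_mask, target_mask
```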
Default tokenizer Fourier settings: `N_f=64`, `min_freq=0.1`, `max_freq=100.0`, `freq_scale=1.0`.
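As a rough illustration of how these settings could be realized, the sketch below encodes a scalar expression value with sin/cos Fourier features. The log-spaced frequency schedule and the function name are assumptions, since the tokenizer code is not included in this repo.

```python
import numpy as np

def fourier_features(x, n_f=64, min_freq=0.1, max_freq=100.0, freq_scale=1.0):
    """Encode expression value(s) as sin/cos Fourier features (assumed log-spaced freqs)."""
    freqs = freq_scale * np.logspace(np.log10(min_freq), np.log10(max_freq), n_f)
    angles = 2 * np.pi * np.asarray(x)[..., None] * freqs          # (..., n_f)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (..., 2*n_f)
```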
Download artifacts
```python
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="genejepa-epoch=49.ckpt")
meta_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="gene_metadata.parquet")
stats_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                             filename="global_stats.json")
```