BioClinical ModernBERT: an example of continued pre-training of ModernBERT

Community Article Published September 10, 2025

Table of Contents

  1. Overview
  2. Hardware requirements
  3. Environment setup
  4. Data preparation (CSV → tokenized MDS)
  5. Merge & index MDS datasets
  6. Configure training (YAML)
  7. Checkpoints (download & load)
  8. Launch training
  9. Adaptation & downstream tasks
  10. Citation

Overview

Following the release of BioClinical ModernBERT, we've received feedback asking for instructions on how to perform continued pre-training of ModernBERT with data from other domains. This guide walks through the steps we used.

In addition to showing how BioClinical ModernBERT itself was trained, this guide also explains how you can adapt the same workflow to your own biomedical or clinical datasets, extending BioClinical ModernBERT to your own data.

Note: We focus on continued pre-training (Masked Language Modeling). From the resulting encoder, you can fine-tune for downstream tasks (classification, NER, embeddings, etc.) using standard BERT recipes.

Hardware requirements

  • GPUs: ModernBERT benefits from FlashAttention (FA2 or FA3). Optimized GPUs include NVIDIA T4, A10, L4, A100, H100, RTX 3090/4090.  

  • What we used: 8× NVIDIA H100 (SXM5).

Environment setup

The ModernBERT repo already contains the files needed to run the Composer training framework. Start by cloning it:

git clone https://github.com/AnswerDotAI/ModernBERT.git
cd ModernBERT

And install the bert24 environment as indicated in the ModernBERT repo:

conda env create -f environment.yaml
# if the conda environment errors out set channel priority to flexible:
# conda config --set channel_priority flexible
conda activate bert24
# if using H100s clone and build flash attention 3
# git clone https://github.com/Dao-AILab/flash-attention.git
# cd flash-attention/hopper
# python setup.py install
# install flash attention 2 (model uses FA3+FA2 or just FA2 if FA3 isn't supported)
pip install "flash_attn==2.6.3" --no-build-isolation
# or download a precompiled wheel from https://github.com/Dao-AILab/flash-attention/releases/tag/v2.6.3
# or limit the number of parallel compilation jobs
# MAX_JOBS=8 pip install "flash_attn==2.6.3" --no-build-isolation

Data preparation (CSV → tokenized MDS)

The Composer framework supports different types of datasets. Following the ModernBERT repository's recommendations for high throughput, our implementation leverages tokenized MDS NoStreamingDataset.

You can find the script used to convert CSV files to tokenized MDS datasets in our GitHub repository. Feel free to adapt it to different input sources as needed.

Starting with the following folder structure:

└── data
    └── convert_dataset_to_mds.py
    └── train
    │   └── dataset1.csv
    │   └── dataset2.csv
    └── eval
        └── dataset3.csv

You can call convert_dataset_to_mds.py for each dataset.

The repository details the use of each parameter. For instance, row_count is an optional parameter giving the number of rows the conversion script has to process, which lets it display a processing time estimate without loading the entire dataset into memory.

python convert_dataset_to_mds.py \
  --dataset train/dataset1.csv \
  --row_count 42 \
  --out_token_counts ../tokenized_datasets/train/token_counts.txt \
  --out_dir ../tokenized_datasets/train/

python convert_dataset_to_mds.py \
  --dataset train/dataset2.csv \
  --row_count 42 \
  --out_token_counts ../tokenized_datasets/train/token_counts.txt \
  --out_dir ../tokenized_datasets/train/

python convert_dataset_to_mds.py \
  --dataset eval/dataset3.csv \
  --row_count 42 \
  --out_token_counts ../tokenized_datasets/eval/token_counts.txt \
  --out_dir ../tokenized_datasets/eval/
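
One way to obtain a value for row_count is to count the CSV rows in a streaming fashion before running the conversion. The sketch below is an illustration and not part of the conversion script: the function name count_csv_rows is hypothetical, and it assumes a UTF-8 file with a single header row. It uses the csv module rather than counting newlines, so that multi-line quoted fields (common in clinical notes) are counted as one row.

```python
import csv


def count_csv_rows(path, has_header=True):
    """Count data rows in a CSV file without loading it into memory."""
    # Long free-text fields can exceed the csv module's default field size limit.
    csv.field_size_limit(2**31 - 1)
    with open(path, newline="", encoding="utf-8") as f:
        n_rows = sum(1 for _ in csv.reader(f))
    # Subtract the header row if there is one (never go below zero).
    return max(n_rows - 1, 0) if has_header else n_rows
```

The returned value can then be passed to --row_count in place of the placeholder used above.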

This will result in the following folder structure:

└── tokenized_datasets
    └── train
    │   └── dataset1
    │   │   └── index.json
    │   │   └── shard.00000.mds
    │   │   └── ...
    │   └── dataset2
    │   │   └── index.json
    │   │   └── shard.00000.mds
    │   │   └── ...
    └── eval
        └── dataset3
            └── index.json
            └── shard.00000.mds
            └── ...

As you can see, this results in one MDS folder per dataset. This means that you can experiment with different data mixtures without having to re-tokenize everything - simply move around the MDS folders.

Merge & index MDS datasets

The MDS folders now need to be merged to build one train and one eval MDS dataset. This operation creates an index.json file in tokenized_datasets/train/ and tokenized_datasets/eval/. It can be done with the streaming library; in a notebook or Python file, run:

from streaming.base.util import merge_index

# The second argument is keep_local: True keeps the merged index.json on local disk
merge_index("tokenized_datasets/train", True)
merge_index("tokenized_datasets/eval", True)

Which will result in the folder structure:

└── tokenized_datasets
    └── train
    │   └── index.json
    │   └── dataset1
    │   │   └── index.json
    │   │   └── shard.00000.mds
    │   │   └── ...
    │   └── dataset2
    │   │   └── index.json
    │   │   └── shard.00000.mds
    │   │   └── ...
    └── eval
        └── index.json
        └── dataset3
            └── index.json
            └── shard.00000.mds
            └── ...

Now your datasets are ready to be used for continued pre-training! Pre-tokenizing the data is especially helpful if you plan on running several training configurations, as it avoids redundancy over runs.
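
Before launching a long run, it can be worth checking that the merged index actually lists all the shards you expect. The following is a minimal sketch, assuming the MDS index.json layout in which a top-level "shards" list holds one entry per shard, each with a "samples" count; the function name summarize_mds_index is hypothetical.

```python
import json
from pathlib import Path


def summarize_mds_index(split_dir):
    """Return (number of shards, total samples) listed in <split_dir>/index.json."""
    index = json.loads((Path(split_dir) / "index.json").read_text())
    shards = index["shards"]
    return len(shards), sum(shard["samples"] for shard in shards)
```

Calling summarize_mds_index("tokenized_datasets/train") after the merge should report the shards of both dataset1 and dataset2.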

Configure training (YAML)

The Composer framework used by ModernBERT relies on YAML configuration files. Let's go step by step through the config file we used for phase 1 of BioClinical ModernBERT base:

We start by specifying the data path. data_local should point to the folder containing the train and eval splits, so in the example described above you would replace data/All_tokenized_mds with tokenized_datasets (or the path pointing to it). Because we are using a NoStreaming dataset, data_remote can be left blank.

data_local: data/All_tokenized_mds
data_remote: # If blank, files must be present in data_local

Nothing needs to change here for phase 1. Note, however, that in phase 2, when we specialize the model on a specific dataset, we will change the MLM probability.

max_seq_len: 8192
tokenizer_name: answerdotai/ModernBERT-base
mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance
count_padding_tokens: false
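
To make the role of mlm_probability concrete, here is a small sketch of token masking for Masked Language Modeling. It follows the standard BERT recipe (of the selected positions, 80% become the mask token, 10% a random token, 10% stay unchanged), which is an assumption: the collator in the ModernBERT repo may differ in its details, and this sketch does not treat special tokens separately. The function name and the mask_id argument are hypothetical.

```python
import random


def mask_tokens(token_ids, mask_id, vocab_size, mlm_probability=0.3, seed=0):
    """Return (inputs, labels) for MLM; labels are -100 where no loss is computed."""
    rng = random.Random(seed)
    inputs = list(token_ids)
    labels = [-100] * len(inputs)
    for i, tok in enumerate(token_ids):
        if rng.random() >= mlm_probability:
            continue  # position not selected for prediction
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id  # replace with the mask token
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # replace with a random token
        # otherwise keep the original token in the input
    return inputs, labels
```

With mlm_probability set to 0.3, roughly 30% of positions receive a label, compared with 15% in the evaluation loader further down.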

Feel free to change the run_name; the rest is essentially unchanged from the config file shared in the ModernBERT repo.

# Run Name
run_name: phase1_all_3ep_constant

# Model
model:
  name: flex_bert
  pretrained_model_name: bert-base-uncased
  tokenizer_name: ${tokenizer_name}
  disable_train_metrics: true
  # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object
  # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure
  # the model_config settings match the architecture of the existing model
  model_config:
    vocab_size: 50368
    init_method: full_megatron
    num_hidden_layers: 22
    hidden_size: 768
    intermediate_size: 1152
    num_attention_heads: 12 # to have head size of 64
    attention_layer: rope
    attention_probs_dropout_prob: 0.0
    attn_out_bias: false
    attn_out_dropout_prob: 0.1
    attn_qkv_bias: false
    bert_layer: prenorm
    embed_dropout_prob: 0.0
    embed_norm: true
    final_norm: true
    skip_first_prenorm: true
    embedding_layer: sans_pos
    loss_function: fa_cross_entropy
    loss_kwargs:
      reduction: mean
    mlp_dropout_prob: 0.0
    mlp_in_bias: false
    mlp_layer: glu
    mlp_out_bias: false
    normalization: layernorm
    norm_kwargs:
      eps: 1e-5
      bias: false
    hidden_act: gelu
    head_pred_act: gelu
    activation_function: gelu # better safe than sorry
    padding: unpadded
    rotary_emb_dim: null
    rotary_emb_base: 160000.0
    rotary_emb_scale_base: null
    rotary_emb_interleaved: false
    local_attn_rotary_emb_base: 10000.0
    local_attn_rotary_emb_dim: null
    allow_embedding_resizing: true
    sliding_window: 128
    global_attn_every_n_layers: 3
    unpad_embeddings: true
    compile_model: true
    masked_prediction: true

For the dataloaders, the name has to be set to text even though we are using tokenized data (probably a leftover from the Nomic codebase). The split field has to be set to the subfolder you want to use, so in our example from before either train or eval.

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
    shuffle_seed: 2998
  drop_last: true
  num_workers: 6
  sequence_packing: true


eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: eval
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: false
    mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison
    streaming: false
  drop_last: false
  num_workers: 3
  sequence_packing: false

You will need to adjust the scheduler depending on the phase you're in: this constant scheduler is useful to continue pre-training from a pre-decay checkpoint, and you can find an example of a $1 - \text{sqrt}$ decay in our phase 2 config file. We recommend keeping the optimizer parameters as they are for continued pre-training to avoid cold restarts.

# Optimization
scheduler:
  name: constant_with_warmup
  t_warmup: 0tok
  t_max: ${max_duration}
  
optimizer:
  name: decoupled_stableadamw
  lr: 3e-4 # keep same as modernbert
  betas:
  - 0.9
  - 0.98
  eps: 1.0e-06
  weight_decay: 1.0e-5 # Amount of weight decay regularization
  filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases
  log_grad_norm: true
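
The difference between the two schedules can be sketched as learning rate multipliers applied to the base lr. This is an illustration under assumptions, not Composer's implementation: the 1-sqrt form below decays all the way to zero, whereas the phase 2 config may use a different floor, and both function names are hypothetical.

```python
import math


def constant_with_warmup(step, warmup_steps):
    """LR multiplier: linear warmup, then constant at 1.0."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    return 1.0


def one_minus_sqrt(step, decay_steps):
    """LR multiplier for a 1-sqrt decay over decay_steps (assumed to end at 0)."""
    progress = min(max(step / decay_steps, 0.0), 1.0)
    return 1.0 - math.sqrt(progress)
```

With a warmup of zero, as in the config above, the constant schedule starts directly at the full learning rate, which is what allows training to resume from a stable-phase checkpoint without a cold restart. The 1-sqrt multiplier drops quickly at first: it is already at 0.5 after a quarter of the decay phase.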

The max_duration parameter simply corresponds to how long you want to train it for. Note that we hard-coded the number of tokens we wanted to train for, but MosaicML provides a variety of ways to represent and track training time which could be better suited for your projects. Similarly, you might want to change the eval_interval parameter to adapt it to the size of your dataset and can encode this value differently.

max_duration: 160_453_000_000tok ## note: 3x all data
eval_interval: 3000ba  ## note: was 4000, more granular evaluation to adjust for smaller dataset
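
To choose eval_interval and save_interval, it helps to estimate how many batches a token budget corresponds to. The sketch below is a rough estimate that assumes sequence packing fills every sample to max_seq_len; the real number of batches depends on packing efficiency. The function name is hypothetical.

```python
def approx_num_batches(total_tokens, global_batch_size, max_seq_len):
    """Rough number of optimizer steps for a token budget, assuming full packing."""
    tokens_per_batch = global_batch_size * max_seq_len
    return total_tokens // tokens_per_batch
```

For the values in this config, approx_num_batches(160_453_000_000, 576, 8192) gives 34004 batches, so an eval_interval of 3000ba corresponds to about 11 evaluations over the run.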

We trained BioClinical ModernBERT on a server equipped with 8 NVIDIA H100 SXM5 GPUs, and therefore did not need to modify the batch / microbatch sizes from ModernBERT's config files. If this configuration exceeds your available VRAM, you can decrease the microbatch size (device_train_microbatch_size): gradients are then accumulated over more microbatches, so the effective batch size is unchanged. The global batch size follows the formula $\text{global batch size} = n_{\text{GPUs}} \times \text{batch size}$. ModernBERT was trained with a batch size of 72 for base and 77 for large (during the stable training phase); with 8 GPUs, this gives $8 \times 72 = 576$ for the base config below. That means that if you use a different number of GPUs, you will have to adjust global_train_batch_size and global_eval_batch_size.

global_train_batch_size: 576
global_eval_batch_size: 1024

# System
seed: 17
device_eval_batch_size: 128
device_train_microbatch_size: 12
precision: amp_bf16
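
The batch size arithmetic above can be sketched as follows. The function name batch_plan is hypothetical, and the gradient accumulation count assumes the per-GPU batch is simply split into microbatches of device_train_microbatch_size.

```python
import math


def batch_plan(n_gpus, per_gpu_batch_size, microbatch_size):
    """Global batch size and gradient accumulation steps for a given setup."""
    return {
        "global_train_batch_size": n_gpus * per_gpu_batch_size,
        # Number of microbatches each GPU processes per optimizer step.
        "grad_accum_steps": math.ceil(per_gpu_batch_size / microbatch_size),
    }
```

With 8 GPUs, a per-GPU batch size of 72 and a microbatch size of 12, this gives a global batch size of 576 and 6 accumulation steps, matching the config above. Halving the microbatch size to 6 lowers peak memory and doubles the accumulation steps to 12 without changing the global batch size.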

In the last part of the configuration file, you can adjust console_log_interval and save_interval to suit the length of your training run. As noted in the config comments, make sure that the checkpoint you select for the load_path parameter corresponds to a stable learning rate phase. This avoids cold restart issues.

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 150ba

callbacks:
  speed_monitor:
    window_size: 50
  lr_monitor: {}
  scheduled_gc: {}
  log_grad_norm:
    batch_log_interval: 10
  packing_efficiency:
    log_interval: 10

# W&B logging
# loggers:
#   wandb:
#     project: ## TODO: add your wandb info if needed
#     entity: ## TODO: add your wandb info if needed

# Checkpoint to local filesystem or remote object store
save_interval: 3000ba
save_num_checkpoints_to_keep: -1  # Important, this cleans up checkpoints saved to DISK
save_folder: checkpoints/{run_name}

# Load from local filesystem or remote object store to
## note: This is the last checkpoint of the stable phase for ModernBERT. It's the one you will want to use if you perform continued pretraining from the ModernBERT model!
##        You can find them here: https://huggingface.co/answerdotai/ModernBERT-base-training-checkpoints
load_path: checkpoints/modernbert-base-context-extension/context-extension/ep0-ba52988-rank0.pt
autoresume: false
reset_time: true # restarts the scheduler, dataloaders, etc from step zero
restart_override: true # resets optimizer hyperparameters (LR, WD, etc), LR Scheduler, and training microbatch size from the checkpoint's values

Checkpoints (download & load)

To download the checkpoints from which to continue training (the load_path field specified just above), you can use the Hugging Face CLI tool:

pip install huggingface_hub

  • If you want to continue training from the ModernBERT checkpoints, the following script will download the last checkpoints of the stable learning rate phase for both the base and large versions of the model.

mkdir -p ~/ModernBERT/checkpoints/modernbert-base-context-extension

huggingface-cli download answerdotai/ModernBERT-base-training-checkpoints --include "context-extension/ep0-ba52988-rank0.pt" --local-dir ~/ModernBERT/checkpoints/modernbert-base-context-extension --local-dir-use-symlinks False

huggingface-cli download answerdotai/ModernBERT-large-training-checkpoints --include "context-extension/ep0-ba49552-rank0.pt" --local-dir ~/ModernBERT/checkpoints/modernbert-large-context-extension --local-dir-use-symlinks False

  • If you would like to continue training from the BioClinical ModernBERT checkpoints, this will download the last checkpoints of phase 1 for both the base and large models (phase 2 is where the learning rate decays).

mkdir -p ~/ModernBERT/checkpoints/bioclinical_modernbert_phase1

huggingface-cli download thomas-sounack/BioClinical-ModernBERT-checkpoints --include "base/phase1_base_last.pt" --local-dir ~/ModernBERT/checkpoints/bioclinical_modernbert_phase1 --local-dir-use-symlinks False

huggingface-cli download thomas-sounack/BioClinical-ModernBERT-checkpoints --include "large/phase1_large_last.pt" --local-dir ~/ModernBERT/checkpoints/bioclinical_modernbert_phase1 --local-dir-use-symlinks False
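
If you prefer to stay in Python, the same downloads can be sketched with hf_hub_download from the huggingface_hub package. The repository IDs and file names are the ones used in the commands above; the dictionary keys and function names are hypothetical. The import is done lazily so that the path helper can be used without the package installed.

```python
from pathlib import Path

# Repository and file name of the last stable-phase ModernBERT checkpoints.
CHECKPOINTS = {
    "modernbert-base": (
        "answerdotai/ModernBERT-base-training-checkpoints",
        "context-extension/ep0-ba52988-rank0.pt",
    ),
    "modernbert-large": (
        "answerdotai/ModernBERT-large-training-checkpoints",
        "context-extension/ep0-ba49552-rank0.pt",
    ),
}


def local_checkpoint_path(local_dir, filename):
    """Path where the file ends up: the repo-relative name is kept under local_dir."""
    return Path(local_dir).expanduser() / filename


def download_checkpoint(name, local_dir):
    """Download one checkpoint into local_dir and return its local path."""
    from huggingface_hub import hf_hub_download  # requires `pip install huggingface_hub`

    repo_id, filename = CHECKPOINTS[name]
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=str(Path(local_dir).expanduser()),
    )
```

Because the repo-relative file name is preserved, downloading the base checkpoint into checkpoints/modernbert-base-context-extension yields the path used for load_path in the config.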

Launch training

After saving the YAML config file to ModernBERT/yamls/, you can start training by running

composer main.py yamls/your_yaml_file.yml

from the ModernBERT/ folder.

Following this approach, our training is done in two steps. Please refer to our preprint for more information and to our YAML config files for the implementation details.

Adaptation & downstream tasks

In our testing, we found that BioClinical ModernBERT adapted very well to new medical domains. For instance, it largely surpassed other encoders on out-of-the-box Masked Language Modeling using clinical data from the Dana-Farber Cancer Institute (oncology data, while the training data is largely dominated by research papers and ICU notes). We observed the same behavior on a variety of downstream tasks. We believe that this generalizability comes from the diversity of text sources included in BioClinical ModernBERT's training data.

If you plan to use BioClinical ModernBERT on your own data and have enough text available for continued training, we recommend a similar two-step strategy to maximize performance:

  • Continue MLM training from the stable BioClinical ModernBERT checkpoints using a two-phase schedule:
    • Phase 1: maintain a stable learning rate for most of training.
    • Phase 2: apply a decaying learning rate to consolidate and specialize.
  • Fine-tune the resulting model on your downstream task (e.g., classification, NER, embeddings).

This setup also supports multiple Phase 2 “branches” for different subsets of data. For example, in a hospital setting you could first pre-train with a stable learning rate on all institutional text, then run separate Phase 2 schedules for each department. This way, models benefit from shared hospital-wide patterns (e.g., writing style, formatting conventions) before specializing on department-specific content.
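
As an illustration of the branching idea, the sketch below builds one set of config overrides per department, all starting from the same phase 1 checkpoint. The keys run_name, data_local and load_path are fields of the config shown earlier; everything else (the function name, department names, and folder layout) is hypothetical and would need to match your own setup.

```python
def phase2_branches(departments, phase1_checkpoint, data_root="data"):
    """Build per-department config overrides that share one phase 1 checkpoint."""
    return {
        dept: {
            "run_name": f"phase2_{dept}",
            "data_local": f"{data_root}/{dept}_tokenized_mds",
            "load_path": phase1_checkpoint,
        }
        for dept in departments
    }
```

Each entry can then be written into its own copy of the phase 2 config file, so that the expensive stable-phase training is done once and only the shorter decay phase is repeated per department.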

Citation

If you use the BioClinical ModernBERT models or methodology, please cite our preprint:

@misc{sounack2025bioclinicalmodernbertstateoftheartlongcontext,
      title={BioClinical ModernBERT: A State-of-the-Art Long-Context Encoder for Biomedical and Clinical NLP}, 
      author={Thomas Sounack and Joshua Davis and Brigitte Durieux and Antoine Chaffin and Tom J. Pollard and Eric Lehman and Alistair E. W. Johnson and Matthew McDermott and Tristan Naumann and Charlotta Lindvall},
      year={2025},
      eprint={2506.10896},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2506.10896}, 
}
