Model Card for MrT5 Small

MrT5: Dynamic Token Merging for Efficient Byte-level Language Models
(Kallini et al., 2024)

MrT5 (MergeT5) is a more efficient variant of ByT5 (Xue et al., 2022) that integrates a token deletion mechanism in its encoder to dynamically shorten the input sequence length. After a fixed number of encoder layers, a learned delete gate determines which tokens are removed and which are retained for subsequent layers. By effectively "merging" critical information from deleted tokens into a shorter sequence, MrT5 addresses the key practical limitation of existing byte-level models: their much longer sequences and the compute cost that comes with them.

Citation

If you use this model, please cite the MrT5 paper:

@inproceedings{
    kallini2025mrt,
    title={MrT5: Dynamic Token Merging for Efficient Byte-level Language Models},
    author={Julie Kallini and Shikhar Murty and Christopher D Manning and Christopher Potts and R{\'o}bert Csord{\'a}s},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=VYWBMq1L7H}
}

Also cite the ByT5 paper:

@article{xue-etal-2022-byt5,
    title = "{B}y{T}5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models",
    author = "Xue, Linting  and
      Barua, Aditya  and
      Constant, Noah  and
      Al-Rfou, Rami  and
      Narang, Sharan  and
      Kale, Mihir  and
      Roberts, Adam  and
      Raffel, Colin",
    editor = "Roark, Brian  and
      Nenkova, Ani",
    journal = "Transactions of the Association for Computational Linguistics",
    volume = "10",
    year = "2022",
    address = "Cambridge, MA",
    publisher = "MIT Press",
    url = "https://aclanthology.org/2022.tacl-1.17",
    doi = "10.1162/tacl_a_00461",
    pages = "291--306",
}

Model Details

This is the model card for the 300M-parameter MrT5 Small (mrt5-small), a more efficient variant of ByT5 Small (google/byt5-small). This model is trained to reduce its encoder sequence length by ~50% on average.

  • Developed by: Julie Kallini, Shikhar Murty, Christopher D. Manning, Christopher Potts, Róbert Csordás
  • Model type: MrT5
  • Languages: English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu
  • Fine-tuned from model: google/byt5-small
  • Sources for more information: MrT5 paper (https://openreview.net/forum?id=VYWBMq1L7H)

Model Architecture

MrT5 Small uses the model configuration of the standard ByT5 Small, which has a feed-forward dimensionality of 3584, a model dimensionality of 1472, 12 encoder layers, 4 decoder layers, 6 attention heads in each layer, and 300M total parameters.

MrT5 has an additional delete gate, which dynamically reduces the encoder sequence length. In this model, it is placed after the third encoder layer, and all subsequent layers operate on the reduced sequence. This model was trained with a deletion rate of δ=0.5, which means that it shortens its encoder sequence length by ~50% after the third layer. MrT5’s gating mechanism introduces only about 3,000 additional parameters.
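To make the mechanism concrete, here is a minimal sketch of a learned delete gate: a per-token score is computed from the hidden states, and positions scoring below a threshold are dropped so that later layers see a shorter sequence. The module and variable names, the single linear projection, and the 0.5 threshold are illustrative assumptions, not the exact MrT5 parameterization.

import torch
import torch.nn as nn

class DeleteGate(nn.Module):
    """Illustrative delete gate: one scalar gate value per token."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)  # hidden state -> gate logit

    def forward(self, hidden_states: torch.Tensor, threshold: float = 0.5):
        # Gate values in (0, 1); positions below the threshold are marked for deletion.
        gate = torch.sigmoid(self.proj(hidden_states)).squeeze(-1)  # (batch, seq_len)
        keep_mask = gate >= threshold
        return gate, keep_mask

# "Hard deletion": drop the marked positions so subsequent layers run on a shorter sequence.
d_model = 1472
gate_module = DeleteGate(d_model)
hidden = torch.randn(1, 16, d_model)   # e.g., output of the third encoder layer
_, keep = gate_module(hidden)
shortened = hidden[:, keep[0], :]      # shape (1, <=16, d_model)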

MrT5 Small is initialized from ByT5 Small and fine-tuned on the same training objective; only MrT5's delete gate is randomly initialized before training. The other distinguishing feature of MrT5 is its use of softmax1 in the attention mechanism: a variant of softmax with an extra 1 added to the denominator, which allows an attention head to assign close to zero total attention.
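For reference, a numerically stable softmax1 can be written as follows; this is the standard formulation of softmax1, sketched independently rather than taken from the MrT5 code:

import torch

def softmax1(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """softmax1: exp(x_i) / (1 + sum_j exp(x_j)).

    Unlike ordinary softmax, the outputs need not sum to 1, so a row of very
    negative scores yields near-zero attention everywhere.
    """
    # Shift by the (non-negative) max for numerical stability; the extra "+1"
    # in the denominator becomes exp(-m) after the shift.
    m = scores.max(dim=dim, keepdim=True).values.clamp(min=0.0)
    exp_scores = torch.exp(scores - m)
    return exp_scores / (torch.exp(-m) + exp_scores.sum(dim=dim, keepdim=True))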

Uses

This model is an encoder-decoder architecture designed primarily for sequence-to-sequence tasks. While it can be used as-is for exploratory or academic purposes, fine-tuning is recommended to achieve optimal performance on specific downstream tasks.

To leverage the model’s deletion feature, please use the custom MrT5Trainer available in the accompanying repository. This specialized trainer ensures that the deletion mechanism is properly maintained and integrated during fine-tuning.

Because this is a base model built for academic and research explorations, it is not intended for production-grade deployments. Users should carefully evaluate the model’s outputs, especially in any setting where reliability and robustness are critical.

Bias, Risks, and Limitations

Language models are known to exhibit various forms of social bias and may produce harmful or offensive content (Bender et al., 2021; Bommasani et al., 2022; Liang et al., 2022). Like other language models, this model may produce biased or harmful outputs. It has not been fine-tuned for safety and should be used with caution, especially in sensitive contexts.

How to Get Started with the Model

Like ByT5, MrT5 works on raw UTF-8 bytes and can be used without a tokenizer. Make sure to set trust_remote_code=True to load the MrT5 code:

from transformers import AutoModelForSeq2SeqLM
import torch

model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-small', trust_remote_code=True)

input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3  # add 3 for special tokens
labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3  # add 3 for special tokens

# Forward pass with hard deletion
loss = model(input_ids, labels=labels, hard_delete=True).loss

For batched inference and training, you can use ByT5's tokenizer class:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-small', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-small')

model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
labels = tokenizer(["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt").input_ids

# Forward pass with hard deletion
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
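For generation, the standard Hugging Face seq2seq generate() API applies. Whether hard_delete can also be forwarded through generate() depends on the custom MrT5 code, so it is omitted in this sketch:

# Generate byte-level outputs and decode them with the ByT5 tokenizer.
generated_ids = model.generate(**model_inputs, max_new_tokens=64)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))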

Training Details

Training Data

For continued pre-training, we use the multilingual C4 (mC4) corpus (Raffel et al., 2020; Xue et al., 2021). MrT5 is trained on 15 typologically diverse languages: English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu. To avoid training models for multiple epochs, we ensure that the samples drawn from the mC4 corpus are sufficiently large. Additionally, we extract equal-sized samples for each language (in terms of bytes) from the mC4 training split.

Training Procedure

MrT5 is trained on the ByT5 span corruption pre-training objective. In this task, spans of tokens in unlabeled text are replaced with a single sentinel token ID per span, and the model must fill in the missing tokens. For ByT5 and MrT5, these are spans of bytes, and the masked spans can cut across word boundaries.
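As a toy, single-span illustration of the objective (the span indices and the choice of <extra_id_0> are arbitrary, and this is not the actual preprocessing pipeline):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google/byt5-small')

text = "Life is like a box of chocolates."
byte_ids = tokenizer(text).input_ids                         # UTF-8 bytes (+3 offset) plus </s>
sentinel = tokenizer.convert_tokens_to_ids("<extra_id_0>")   # sentinel ID for the corrupted span

span_start, span_end = 8, 14                                 # the bytes of "like a"
corrupted_input = byte_ids[:span_start] + [sentinel] + byte_ids[span_end:]
target = [sentinel] + byte_ids[span_start:span_end] + [tokenizer.eos_token_id]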

Preprocessing

When training on the span corruption objective, we calculate the corrupted spans such that the average masked span length is 20 tokens and the noise density is 15%; that is, 15% of the tokens in each sequence are masked out, following the specification outlined in the ByT5 paper.
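As a rough back-of-the-envelope check of what these parameters imply for a 1024-byte encoder sequence (approximate, assuming the stated noise density and mean span length):

# Expected corruption statistics for one 1024-byte sequence.
seq_len, noise_density, mean_span_len = 1024, 0.15, 20
masked_bytes = seq_len * noise_density      # 153.6 bytes masked on average
num_spans = masked_bytes / mean_span_len    # ~7.7 corrupted spans per sequence
print(masked_bytes, num_spans)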

Optimization

MrT5 is trained for 5,000 gradient steps on batches of 2^20 tokens (i.e., an encoder sequence length of 1024 with an effective batch size of 1024). We use the AdamW optimizer with an initial learning rate of 1e-4, linear decay, and no warmup.
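An equivalent optimizer and schedule can be set up with standard PyTorch and Transformers utilities; this is a sketch under the stated hyperparameters, not the authors' training script (it reuses the model object loaded above):

import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,        # no warmup
    num_training_steps=5_000,  # 5,000 gradient steps over ~2^20-token batches
)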

To achieve a specific sequence length reduction rate, we use a PI controller with a target deletion ratio of δ=0.5, as described in Section 3.2 of the paper. We also use attention score regularization, as described in Appendix D of the paper.
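For intuition, a generic PI controller for this purpose looks roughly like the sketch below: the gap between the target and the observed deletion ratio adjusts the weight on the deletion regularizer. The gains, names, and clamping are illustrative assumptions, not the paper's exact formulation (see Section 3.2).

# Generic PI-controller sketch: steer the observed deletion ratio toward the target delta.
class PIController:
    def __init__(self, target: float, k_p: float = 0.5, k_i: float = 0.05):
        self.target, self.k_p, self.k_i = target, k_p, k_i
        self.integral = 0.0

    def update(self, observed_deletion_ratio: float) -> float:
        error = self.target - observed_deletion_ratio  # positive -> push for more deletion
        self.integral += error
        # Used as the coefficient on the deletion regularization term.
        return max(0.0, self.k_p * error + self.k_i * self.integral)

controller = PIController(target=0.5)
alpha = controller.update(observed_deletion_ratio=0.32)  # deleting too little -> larger alpha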

Environmental Impact

  • Hardware Type: NVIDIA RTX 6000 Ada Generation
  • GPU Count: 1
  • Hours used: ~63
  • Cloud Provider: Stanford NLP Cluster

Model Card Authors

Julie Kallini
[email protected]
