Model Card for MrT5 Large
MrT5: Dynamic Token Merging for Efficient Byte-level Language Models
(Kallini et al., 2024)
MrT5 (MergeT5) is a more efficient variant of ByT5 (Xue et al., 2022) that integrates a token deletion mechanism in its encoder to dynamically shorten the input sequence length. After processing through a fixed number of encoder layers, a learned delete gate determines which tokens are to be removed and which are to be retained for subsequent layers. By effectively "merging" critical information from deleted tokens into a more compact sequence, MrT5 presents a solution to the practical limitations of existing byte-level models.
Citation
If you use this model, please cite the MrT5 paper:
@inproceedings{kallini2025mrt,
title={MrT5: Dynamic Token Merging for Efficient Byte-level Language Models},
author={Julie Kallini and Shikhar Murty and Christopher D Manning and Christopher Potts and R{\'o}bert Csord{\'a}s},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=VYWBMq1L7H}
}
Also cite the ByT5 paper:
@article{xue-etal-2022-byt5,
title = "{B}y{T}5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models",
author = "Xue, Linting and
Barua, Aditya and
Constant, Noah and
Al-Rfou, Rami and
Narang, Sharan and
Kale, Mihir and
Roberts, Adam and
Raffel, Colin",
editor = "Roark, Brian and
Nenkova, Ani",
journal = "Transactions of the Association for Computational Linguistics",
volume = "10",
year = "2022",
address = "Cambridge, MA",
publisher = "MIT Press",
url = "https://aclanthology.org/2022.tacl-1.17",
doi = "10.1162/tacl_a_00461",
pages = "291--306",
}
Model Details
This is the model card for the 1.23B-parameter MrT5 Large (mrt5-large), a more efficient variant of ByT5 Large (google/byt5-large). This model is trained to reduce encoder sequence lengths by ~50% on average.
- Developed by: Julie Kallini, Shikhar Murty, Christopher D. Manning, Christopher Potts, Róbert Csordás
- Model type: MrT5
- Languages: English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu
- Fine-tuned from model: google/byt5-large
Model Architecture
MrT5 Large uses the model configuration of the standard ByT5 Large, which has a feed-forward dimensionality of 3840, a model dimensionality of 1536, 36 encoder layers, 12 decoder layers, 16 attention heads in each layer, and 1.23B total parameters.
MrT5 has an additional delete gate, which dynamically reduces the encoder sequence length. In this model, it is placed after the third encoder layer, and all subsequent layers operate on a reduced sequence. This model was trained with a deletion rate of δ=0.5, meaning that it reduces its encoder sequence length by ~50% after the third layer. MrT5's gating mechanism introduces only about 3,000 additional parameters.
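For intuition, here is a minimal, purely illustrative sketch of gate-based token dropping; the class name, scoring function, and threshold are assumptions made for exposition and not the exact MrT5 gate (see the paper and repository for the real mechanism):
import torch
import torch.nn as nn

class DeleteGateSketch(nn.Module):
    """Illustrative only: score each encoder position and keep those above a threshold."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)

    def forward(self, hidden_states: torch.Tensor, threshold: float = 0.5):
        keep_prob = torch.sigmoid(self.proj(hidden_states)).squeeze(-1)  # (batch, seq_len)
        return keep_prob, keep_prob > threshold

gate = DeleteGateSketch(d_model=1536)
hidden = torch.randn(1, 16, 1536)        # toy encoder states after the third layer
keep_prob, keep_mask = gate(hidden)
shortened = hidden[:, keep_mask[0], :]   # subsequent layers operate on the shorter sequence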
MrT5 Large is initialized from ByT5 Large and then further trained on the same span corruption objective; only MrT5's delete gate is randomly initialized before training. The other distinguishing feature of MrT5 is that it uses softmax1 in place of the standard softmax in its attention mechanism.
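For reference, softmax1 normalizes attention scores with an extra constant 1 in the denominator, so a head can assign near-zero total weight to deleted or uninformative positions. A small sketch of the function (not the model's internal implementation):
import torch

def softmax1(logits: torch.Tensor) -> torch.Tensor:
    # Equivalent to appending an implicit zero logit: exp(x_i) / (1 + sum_j exp(x_j)).
    zeros = torch.zeros_like(logits[..., :1])
    scores = torch.softmax(torch.cat([logits, zeros], dim=-1), dim=-1)
    return scores[..., :-1]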
Uses
This model is an encoder-decoder architecture designed primarily for sequence-to-sequence tasks. While it can be used as-is for exploratory or academic purposes, fine-tuning is recommended to achieve optimal performance on specific downstream tasks.
To leverage the model’s deletion feature, please use the custom MrT5Trainer available in the accompanying repository. This specialized trainer ensures that the deletion mechanism is properly maintained and integrated during fine-tuning.
Because this is a base model built for academic and research explorations, it is not intended for production-grade deployments. Users should carefully evaluate the model’s outputs, especially in any setting where reliability and robustness are critical.
Bias, Risks, and Limitations
Language models are known to exhibit various forms of social bias and may produce harmful or offensive content (Bender et al., 2021; Bommasani et al., 2022; Liang et al., 2022). Like other language models, this model may produce biased or harmful outputs. It has not been fine-tuned for safety and should be used with caution, especially in sensitive contexts.
How to Get Started with the Model
Like ByT5, MrT5 works on raw UTF-8 bytes and can be used without a tokenizer. Make sure to set trust_remote_code=True to load the MrT5 code:
from transformers import AutoModelForSeq2SeqLM
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens
labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3 # add 3 for special tokens
# Forward pass with hard deletion
loss = model(input_ids, labels=labels, hard_delete=True).loss
For batched inference and training, you can use ByT5's tokenizer class:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-large')
model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
labels = tokenizer(["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt").input_ids
# Forward pass with hard deletion
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
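For inference, the standard generate API and the ByT5 tokenizer can be used to decode outputs back into text. A brief sketch; whether and how deletion-specific arguments are forwarded through generate is determined by the custom model code, so consult the repository:
# Generate and decode byte-level outputs
generated = model.generate(**model_inputs, max_new_tokens=64)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))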
Training Details
Training Data
For continued pre-training, we use the multilingual C4 (mC4) corpus (Raffel et al., 2020; Xue et al., 2021). MrT5 is trained on 15 typologically diverse languages: English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu. To avoid training models for multiple epochs, we ensure that the samples drawn from the mC4 corpus are sufficiently large. Additionally, we extract equal-sized samples for each language (in terms of bytes) from the mC4 training split.
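As an illustration, equal-sized per-language samples can be drawn by streaming each language's subset of the corpus. The dataset name and language configs below (allenai/c4 with per-language configurations) are an assumption about the public Hugging Face mirror of mC4, not the exact data pipeline used for training:
from datasets import load_dataset

# Stream two of the 15 language subsets (config names assumed)
french = load_dataset("allenai/c4", "fr", split="train", streaming=True)
swahili = load_dataset("allenai/c4", "sw", split="train", streaming=True)

for example in french.take(2):
    print(example["text"][:80])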
Training Procedure
MrT5 is trained on the ByT5 span corruption pre-training objective. In this task, spans of tokens in unlabeled text data are replaced with a single sentinel token ID per span, and the model must fill in the missing tokens. For ByT5 and MrT5, these are spans of bytes, and the masks can potentially interfere with word boundaries.
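A schematic example of the objective, with sentinels shown symbolically rather than as real byte IDs; because spans are drawn over bytes, the masks can split words:
original       = "The quick brown fox jumps over the lazy dog."
encoder_input  = "The qu<X>wn fox jumps ov<Y> lazy dog."   # <X> masks "ick bro", <Y> masks "er the"
decoder_target = "<X>ick bro<Y>er the<Z>"                  # <Z> marks the end of the targets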
Preprocessing
When training on the span corruption objective, we calculate the corrupted spans such that the average masked span length is 20 tokens with a noise density of 15%—that is, 15% of tokens in the sequence are masked out, following the specification outlined in the ByT5 paper.
Optimization
MrT5 is trained for 5,000 gradient steps over batches of 2^20 tokens (i.e., an encoder sequence length of 1024 with an effective batch size of 1024). We use the AdamW optimizer with an initial learning rate of 1e-4 with linear decay and no warmup.
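A minimal sketch of an equivalent optimizer and schedule setup (the actual runs used the training code in the MrT5 repository):
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=5000
)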
To achieve a specific sequence length reduction rate, we use a PI controller with a target deletion ratio of δ=0.5, as described in Section 3.2 of the paper. We also use attention score regularization, as described in Appendix D of the paper.
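For intuition, a generic PI-controller sketch is shown below; the gains and update rule are illustrative assumptions, not the constants used in the paper:
class PIControllerSketch:
    """Adjusts the weight on the deletion regularizer toward a target deletion ratio."""
    def __init__(self, target: float = 0.5, kp: float = 1e-2, ki: float = 1e-4):
        self.target, self.kp, self.ki = target, kp, ki
        self.integral = 0.0

    def update(self, observed_delete_ratio: float) -> float:
        # Positive error (too little deletion) increases the regularization weight.
        error = self.target - observed_delete_ratio
        self.integral += error
        return max(0.0, self.kp * error + self.ki * self.integral)

controller = PIControllerSketch(target=0.5)
alpha = controller.update(observed_delete_ratio=0.31)
# total_loss = span_corruption_loss + alpha * deletion_loss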
Environmental Impact
- Hardware Type: NVIDIA A100-SXM4-80GB
- GPU Count: 4
- Hours used: ~73
- Cloud Provider: Stanford NLP Cluster
Model Card Authors
Julie Kallini ([email protected])