# BERT with Flash-Attention
## Installing dependencies
To run the model on GPU, you need to install Flash Attention. You can either install it from PyPI (which may not work with fused-dense) or build it from source. To install from source, clone the GitHub repository:
```bash
git clone git@github.com:Dao-AILab/flash-attention.git
```
The code provided here should work with commit `43950dd`.
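If you want to match that commit exactly, you can check it out before installing (optional; e.g. `git -C flash-attention checkout 43950dd`, run from the directory you cloned into).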
Change to the cloned repo and install:
```bash
cd flash-attention && python setup.py install
```
This will compile the flash-attention kernel, which will take some time.
If you would like to use fused MLPs (e.g., to use activation checkpointing), you can also install fused-dense from source:
```bash
cd csrc/fused_dense_lib && python setup.py install
```
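To sanity-check the build from Python (my own quick check, not part of the upstream instructions), you can try importing the compiled extensions:

```python
# Verify that flash-attention compiled and imports correctly.
import flash_attn
print(flash_attn.__version__)

# Only relevant if you also built fused-dense; this assumes the extension
# keeps its default module name and raises ImportError otherwise.
import fused_dense_lib  # noqa: F401
```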
## Configuration
The config adds some new parameters:
- `use_flash_attn`: If `True`, always use flash attention. If `None`, use flash attention when a GPU is available. If `False`, never use flash attention (works on CPU).
- `window_size`: Size (left and right) of the local attention window. If `(-1, -1)`, use global attention.
- `dense_seq_output`: If `True`, only the hidden states of the masked-out tokens (around 15%) are passed to the classifier heads. I set this to `True` for pretraining.
- `fused_mlp`: Whether to use fused-dense. Useful to reduce VRAM in combination with activation checkpointing.
- `mlp_checkpoint_lvl`: One of `{0, 1, 2}`. Increasing this increases the amount of activation checkpointing within the MLP. Keep this at 0 for pretraining and use gradient accumulation instead. For embedding training, increase this as much as needed.
- `last_layer_subset`: If `True`, only compute the last layer for a subset of tokens. I left this at `False`.
- `use_qk_norm`: Whether or not to use QK normalization.
- `num_loras`: Number of LoRAs to use when initializing a `BertLoRA` model. Has no effect on other models.
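As a rough illustration of how these options fit together, here is a minimal loading sketch. It assumes the model is loaded with `trust_remote_code=True` and uses a placeholder repo id; check the exact attribute names against this card's `config.json`.

```python
import torch
from transformers import AutoConfig, AutoModel

repo_id = "<org>/<this-model>"  # placeholder: replace with this card's repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# Options described in the list above.
config.use_flash_attn = None      # auto: flash attention only when a GPU is available
config.window_size = (-1, -1)     # (-1, -1) means global attention
config.fused_mlp = False          # set to True only if fused-dense is installed
config.mlp_checkpoint_lvl = 0     # keep at 0 for pretraining

model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
```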