BERT for Bias Detection in Text
This article is a walk-through of building a binary bias classification BERT model in 2024. It's also a sneak peek at some of the architecture we build on in the Ethical Spectacle Research GUS-Net paper, to be published in September ;).
Model🔬 | Try it out🧪 | Training ipynb 💻.
Start by Understanding BERT:
The introduction of the transformer architecture led to two particularly revolutionary models, on either side of the encoding-decoding process. One of them you're probably familiar with... OpenAI's GPT was born by stacking decoders, while Google's BERT stacks encoders.
Encodings
The "encodings" output by the BERT architecture capture information about the context of a token (tokens/words on either side), in the encoding. Wrap your head around this before reading any further; in effect, each BERT encoding contains information on each token's meaning and the context in which they're used. These encodings can be used for NLP tasks like sentiment analysis or entity recognition by adding the correct output layer(s) for your task. These encodings (i.e. numbers representing the meaning of the sentence) are exactly what we need for classification. Here's a diagram of how BERT's self attention mechanisms capture relationships between tokens in a sequence:
BERT for Binary Classification
Since we're classifying bias, we only need one neuron in our output layer (fair is 0, biased is 1). At a high level, each classification will follow this process (sketched in code after the list):
- Tokenize the input string with the BERT tokenizer. Outputs: token ids correlating to BERT's vocab (including [CLS], [SEP], [UNK], [PAD]), and an attention mask (mostly for ignoring [PAD] tokens).
- A fine-tuned BERT will create the encodings (a fixed-size 768-dimensional vector).
- An added linear layer will turn the encodings into a classification of "fair" or "biased" (768 -> 1 neuron).
- We'll activate the output with a sigmoid function to normalize it (0 to 1).
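Here's a minimal sketch of that pipeline using Hugging Face Transformers and PyTorch (the untrained classification head and the example sentence are just for illustration; the notebook has the real thing):

```python
import torch
from transformers import BertModel, BertTokenizer

# Assumed setup: plain bert-base-uncased; swap in your fine-tuned weights.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
classifier = torch.nn.Linear(768, 1)  # 768-dim encoding -> 1 bias logit

text = "The senator's ridiculous plan will obviously fail."
# Step 1: tokenize -> input ids ([CLS]/[SEP]/[PAD]) + attention mask
inputs = tokenizer(text, padding="max_length", truncation=True,
                   max_length=128, return_tensors="pt")

with torch.no_grad():
    # Step 2: BERT encodes the sequence into a fixed-size 768-dim vector
    encoding = bert(**inputs).pooler_output          # shape: (1, 768)
    # Steps 3-4: linear layer + sigmoid -> probability of "biased"
    prob = torch.sigmoid(classifier(encoding))       # shape: (1, 1)

print(f"P(biased) = {prob.item():.3f}")  # head is untrained here, so ~random
```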
There are many BERTs...
Lots of people have reimplemented BERT with different training corpora and architecture tweaks (ColBERT, RoBERTa, DistilBERT, etc.). In this article, I'll weigh the realistic challenges of modern bias-detection data science and compare results for the (vanilla) BERT architectures on binary classification, but the same experiments could be run with a variant model after simple changes to how the encodings are pooled before the output layer.
Understand Evaluation Metrics:
This is a deep-dive on the specific metrics and interpreted results of multiple architectures and training parameters, so let's start by going over which metrics matter in this process (there's a quick code sketch of computing them after the list).
Key: TP - True Positive, TN - True Negative, FP - False Positive, FN - False Negative
- Accuracy = (TP + TN) / Total Records - Proportion of correct predictions out of all classifications.
- Precision = TP / (TP + FP) - Proportion of positive predictions that were actually correct.
- Recall = TP / (TP + FN) - Proportion of true positives that our model predicted correctly.
- F1 Score = 2 x ((Precision x Recall) / (Precision + Recall)) - The harmonic mean of precision and recall; think of it as accuracy that stays honest when the dataset is skewed towards positive or negative examples.
- Confusion Matrix: These are helpful for visualizing the accuracy of the model and pointing us towards potential problems. As you can see in the diagram, we want to maximize the portion of predictions on a test set that fall in the TP and TN cells.
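If you want to compute these yourself, scikit-learn covers all of them. A quick sketch with toy labels (not our actual test-set outputs):

```python
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)

# Toy labels/predictions for illustration only (1 = biased, 0 = fair)
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]

print("Accuracy: ", accuracy_score(y_true, y_pred))
print("Precision:", precision_score(y_true, y_pred))
print("Recall:   ", recall_score(y_true, y_pred))
print("F1:       ", f1_score(y_true, y_pred))
print(confusion_matrix(y_true, y_pred))  # rows = true class, cols = predicted class
```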
Design the Neural Network:
Think back to the classification process I talked about earlier; this pipeline will be performed by our PyTorch forward method. Our neural network, at its core, will look like this. Remember, we'll use a tokenizer to turn our initial text sequence into the input ids and attention mask required for the forward pass, starting with the BERT layers.
Our model's forward pass, composed of the BERT layers, our output layer, and the activation function, will turn the input ids and attention mask from the tokenizer into a 0-to-1 bias prediction (as a probability).
Good to know: BERT-base has 12 transformer layers, with 12 attention heads per layer. These attention heads are essentially 12 self-attention mechanisms for weighing the importance of different tokens in a sequence relative to each other. Also, BERT (base) maintains 768 dimensions through all of the layers.
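A minimal sketch of that module (the class and attribute names are mine, not necessarily what the notebook uses):

```python
import torch.nn as nn
from transformers import BertModel

class BiasClassifier(nn.Module):
    """BERT encoder + single-neuron head for fair (0) / biased (1)."""

    def __init__(self, model_name: str = "bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)              # 12 layers, 768 dims
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)   # 768 -> 1 neuron
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output        # (batch, 768) [CLS]-based sequence encoding
        logit = self.classifier(pooled)       # (batch, 1)
        return self.sigmoid(logit)            # probability of "biased"
```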
Training Our Model
To update the weights of our model for our classification task, during training we measure how close our predictions are to our dataset's true labels (i.e. the loss), use that loss in backpropagation to calculate gradients (i.e. the changes that should be made to the weights), and finally apply those changes with an optimizer.
Loss Function (Binary Cross Entropy)
We can use a loss function to calculate how close our model's prediction was to being right, by comparing the predicted label and the true label. There are many loss functions in the torch.nn module you may want to check out, and you can write your own, but for the experiments in this article we'll use trusty old binary cross-entropy loss.
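Since our forward pass already applies a sigmoid, BCELoss lines up directly (a small sketch with made-up probabilities; if you keep raw logits instead, BCEWithLogitsLoss is the more numerically stable choice):

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # expects probabilities in (0, 1) and float labels

probs = torch.tensor([0.91, 0.12, 0.65])   # model outputs after sigmoid (made up)
labels = torch.tensor([1.0, 0.0, 1.0])     # true labels: 1 = biased, 0 = fair

loss = criterion(probs, labels)
print(loss.item())  # lower = predictions closer to the true labels
```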
Backpropagation
This is a fundamental part of machine learning where we walk backwards through the pipeline and compute the gradients. These gradients capture information on what effect changes to weights will have on the output, and can therefore be used to later optimize the layers.
Optimizer (AdamW)
The optimizer applies the learning rate to the gradients, giving us the new weights. Learning rate defines how much to update the weights during each optimizer step.
We'll use a common optimizer called AdamW, which handles helpful logic under-the-hood, such as momentum smoothing and weight decay.
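Here's how backpropagation and AdamW fit together in a single training step. This sketch swaps the real classifier for a toy linear layer and random encodings just to stay self-contained; the learning rate and weight decay values are illustrative:

```python
import torch
import torch.nn as nn
from torch.optim import AdamW

model = nn.Linear(768, 1)                    # toy stand-in for the full classifier
criterion = nn.BCELoss()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

encodings = torch.randn(16, 768)             # a batch of 16 pooled "BERT" encodings
labels = torch.randint(0, 2, (16, 1)).float()

optimizer.zero_grad()                        # clear gradients from the previous step
probs = torch.sigmoid(model(encodings))      # forward pass
loss = criterion(probs, labels)              # binary cross-entropy
loss.backward()                              # backpropagation: compute gradients
optimizer.step()                             # AdamW applies LR, momentum, weight decay
```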
Scheduler (Linear with Warmup)
To squeeze a little extra juice out of our training process, we'll use a learning rate scheduler. This updates the learning rate to be better suited to the different parts of the training.
We'll use a linear scheduler with a warmup. This means that during the first x steps, the learning rate will grow from 0 to our defined learning rate, then slowly decline. This helps avoid making big changes before the model has seen the data, and improves the loss by making more precise changes later in training. Our learning rate schedule will look something like this:
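The transformers library ships this schedule. A sketch with the step counts implied by our setup (~195 steps per epoch at batch size 16, 7 warmup epochs out of 20 total); the toy linear model is just there to give the optimizer something to hold:

```python
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = nn.Linear(768, 1)                        # stand-in for the full classifier
optimizer = AdamW(model.parameters(), lr=3e-4)

steps_per_epoch = 195                            # ~3.12k records / batch size 16
warmup_steps = 7 * steps_per_epoch               # LR climbs 0 -> 3e-4 over 7 epochs
total_steps = 20 * steps_per_epoch               # then declines linearly toward 0

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

# Inside the training loop, call scheduler.step() once per optimizer step.
```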
Evaluation
Our training dataset, BABEv3, comes with a test split. We can use it to check our compiled model's accuracy on a curated 1k record dataset, rather than splitting off a chunk of our training dataset.
It's also helpful because I'm pretty sure Dbias hasn't trained on this test split. If we were to split off some of the training dataset, Dbias would have a slight advantage because it trained on some of those exact examples (MBAD is ~1.7k records of BABE).
During training, you'll want to plot training loss and validation loss. The notebook uses TensorBoard, and I'm using PyTorch Lightning to log some useful metrics you should check out.
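If you're not using the notebook's Lightning setup, plain TensorBoard works too. A bare-bones sketch with made-up per-epoch averages standing in for your real training and validation loops:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/bias-bert")  # view with: tensorboard --logdir runs

# Made-up per-epoch losses; in practice these come from your train/validation loops.
epoch_losses = [(0.62, 0.58), (0.41, 0.45), (0.28, 0.43)]

for epoch, (avg_train_loss, avg_val_loss) in enumerate(epoch_losses):
    writer.add_scalar("loss/train", avg_train_loss, epoch)
    writer.add_scalar("loss/val", avg_val_loss, epoch)

writer.close()
```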
Training the Model:
As we set out to find the best-fitting architecture and training parameters, we have a couple of things to consider:
Training Set Size: 3.12k records
I'm using the train split of the BABEv3 dataset to run these experiments and training.
Bias/social-bias datasets have quality and size constraints due to a time- and expertise-intensive annotation process. The dataset we're using is about double the size of the MBAD/MBIC dataset used in Dbias, but it's still on the small side for a model as complex as BERT. Because of this, we need to be careful of overfitting (a model that is too complex for the data, in effect memorizing the answers to the training set and generalizing poorly to other inputs).
Training Parameters:
If you're new to tuning models or hyperparameters, follow this guide by Andrej Karpathy. I use it all the time. These are the ones we'll focus on:
- Learning Rate - The degree of changes made to the weights in each optimizer step.
- Batch Size - The number of datapoints used in each training step (of the epoch). Typically aim between 4 and 64 (use powers of 2 for GPU core efficiency).
- Epochs - The number of "rounds" the model trains on the full dataset. We're using early stopping to prevent overfitting, and checkpoint callbacks to select the best model version, but it's an important thing to keep an eye on.
- Max Token Length - BERT processes fixed-size inputs, meaning we need to truncate/pad our input id list and attention mask to a fixed length (sketched after this list).
Good to know - Dropout rate, activation functions, gradient clipping.
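The tokenizer handles the padding and truncation for max token length. A quick sketch at 128 tokens, the length we'll use for our reimplementation (the example sentences are made up):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

batch = tokenizer(
    ["Some people say the mayor's plan is reckless.", "The vote passed 7-2."],
    max_length=128,            # fixed length every sequence is padded/truncated to
    padding="max_length",      # shorter sequences get [PAD] tokens
    truncation=True,           # longer sequences get cut off
    return_tensors="pt",
)

print(batch["input_ids"].shape)         # torch.Size([2, 128])
print(batch["attention_mask"][0][:10])  # 1s for real tokens, 0s for padding
```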
Benchmark
We'll start off by reimplementing the Dbias model with our dataset. Raza et al. trained distilbert-base-uncased for 30 epochs with a batch size of 16, a learning rate of 5e-5, and a max token length of 512. Below, on the left are the results of running Dbias on the evaluation set, and on the right is our reimplementation with bert-base-uncased at a max length of 128. These will be the baseline results we compare our experiments against.
Pick a BERT
We should start by testing out a few different versions of BERT to get a general sense of how much model complexity we need. Dbias used DistilBERT, a 6-layer student model of BERT, but we'll stick to the vanilla BERT models, which are more easily interchangeable ;).
BERT has 4 standard variants: base (12 layers, 768 features) and large (24 layers, 1024 features), each cased and uncased. Since capitalization might be a feature relevant to bias, I have a feeling cased will see some improvements.
Interpreting Results: Not exactly what I was hoping for, since BERT-large is computationally more expensive. But we can see that using a cased model can reduce false positives and increase false negatives (with our data). Good to know, but that shift isn't responsible for an improvement in accuracy on its own, because overall accuracy actually drops going from base-uncased to base-cased.
Large-cased hit the best accuracy, maybe capturing the nuance behind the increased negative bias with its greater number of features. However, because we're overfitting so quickly (after just 3 epochs), we clearly have enough complexity with bert-base. I believe we can capture these nuances by optimizing our training rather than selecting a larger model, so we'll continue on with bert-base-uncased.
Learning Rate
Next, we'll find a base learning rate that works best for us. Learning rate tuning is often dataset-specific, but here's my very sophisticated process for finding one that works:
- Find a similar model, trained on a similarly sized dataset. Try the learning rate they used.
- For transfer learning, check if the original creators suggest a LR.
- Try 3e-4 ;)
- Take a guess based on the results of steps 1-3.
Interpreting Results: We're still going fast. We're early-stopping within just a few epochs because the validation loss started rising instead of falling like we want it to. I tried a super low learning rate (3e-6) and it showed the validation loss reaching the same point (~0.4) over more epochs, but with lower accuracy. I take this to mean that a lower learning rate is just more carefully learning training-set-specific details. We can speed-run it.
We'll use a 3e-4 LR, with 7 epoch warmup for the next few steps. It looks wonky on the graphs but it works, although you should slow it down for testing.
Batch Size
For each batch, we'll compute the gradients and new weights once.
I like to think of batch sizes like pixels in an image: if you're only looking at 4 pixels, it might be harder to make out a pattern than with 32 pixels. But in usage, our model will run with a batch size of 1, so we want to keep our batch sizes small without sacrificing accuracy. To find the best batch size, keep an eye on training loss. I like to use the batch size with the lowest average training loss (before validation loss starts to rise).
Higher batch sizes will smooth gradients while lower batch sizes will retain more noise:
Overfit: For the examples above, I let the model overfit (no early stopping). This is an important test to see the lowest possible training loss our model can achieve. For both batch sizes, the model hits a relatively low training loss of <0.1, meaning our model is theoretically capable of a validation loss this low. In reality, the validation loss will be higher than the training loss; the difference is caused by an inability to generalize to new data. Minimizing this gap between training and validation losses will be our objective moving forward.
Picking Batch Size: We'll try a few different batch sizes, and pick the one that had the best accuracy before overfitting. This is typically the batch size that gives us the smallest gap between training and validation loss, to start optimizing.
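In code, batch size is just a DataLoader argument. A sketch assuming a pre-tokenized TensorDataset (the tensors here are random placeholders for the real BABE features):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder tensors shaped like our tokenized training split: (3120, 128)
input_ids = torch.randint(0, 30522, (3120, 128))
attention_masks = torch.ones(3120, 128, dtype=torch.long)
labels = torch.randint(0, 2, (3120,)).float()

dataset = TensorDataset(input_ids, attention_masks, labels)

# Larger batches -> smoother gradients; smaller batches -> noisier, more updates per epoch
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
print(len(train_loader))  # optimizer steps per epoch (~195 at batch size 16)
```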
Regularize
Now that we've got a better understanding of the hyperparameters to use, we can explore the training process and architecture. We can apply regularization techniques such as freezing BERT layers or adding a dropout layer before the linear layer.
Freezing layers of BERT would prevent their weights from being updated, perhaps reducing the precise feature extraction that could cause overfitting to the training set.
Adding a dropout layer would randomly replace a given percentage of values from BERT's output tensor with 0. This creates "noise" which prevents the model from memorizing details.
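A sketch of both techniques applied to the encoder and head from earlier (the number of frozen layers and the dropout probability are just example values):

```python
import torch.nn as nn
from transformers import BertModel

bert = BertModel.from_pretrained("bert-base-uncased")

# Freezing: stop gradient updates for the embeddings and the first 8 encoder layers
for param in bert.embeddings.parameters():
    param.requires_grad = False
for layer in bert.encoder.layer[:8]:
    for param in layer.parameters():
        param.requires_grad = False

# Dropout: randomly zero 30% of the pooled encoding before the linear head
head = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(bert.config.hidden_size, 1),
    nn.Sigmoid(),
)
```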
Won't waste your time: None of the regularization techniques I described above actually improved accuracy on the test set in my (many) experiments. They're still in the notebook if you want to play around with them.
Here's the thing about bias classification: we're approaching a nuanced definition with a small dataset. What would really solve our generalization problem is more information, either in breadth or in depth. Our ESR research paper (GUS-Net) will implement approaches for improving both, but in this article we're sticking with what already exists...
Mean Pooling vs Pooler Output
Our simplest training pipeline and scheduler turned out to be the most accurate, as is often the case. Oh well, can't over engineer everything.
I have one more thing worth exploring.
Pooler Output: BERT outputs a fixed-size 768-dimensional vector called pooler_output. It builds this from the [CLS] encoding at the beginning of the sequence, which (after passing through the layers) should contain information about the entire sequence. The pooler passes the [CLS] encoding through a linear layer and an activation function to produce the fixed-size vector. This is what we're currently passing into our output layer for classification.
Mean Pooling: Instead of using the built-in pooler, we can take an average of the token encodings, which represents the sequence in a different way. It's still a fixed-size 768-dimensional vector, so we can train our model with the same process. After implementing it, the results are underwhelming but show some merit; it'd be interesting to see on other datasets.
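A sketch of mean pooling that masks out [PAD] positions before averaging, shown next to the built-in pooler_output for comparison:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("An example sentence.", padding="max_length",
                   max_length=128, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = bert(**inputs)

token_encodings = outputs.last_hidden_state            # (1, 128, 768) per-token encodings
mask = inputs["attention_mask"].unsqueeze(-1).float()  # (1, 128, 1), 0 at [PAD] positions

# Mean pooling: average only the real token encodings, ignoring padding
mean_pooled = (token_encodings * mask).sum(dim=1) / mask.sum(dim=1)  # (1, 768)

cls_pooled = outputs.pooler_output                      # (1, 768) built-in pooler
print(mean_pooled.shape, cls_pooled.shape)
```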
Conclusion
So... Turns out we got kinda lucky with our initial guesses at training params, but now that we've tested many training parameters, our sanity can rest easy.
Here's the takeaway: the dataset size is the bottleneck.
But with a dataset comparable to the BABEv3 dataset, you may find luck training BERT for sequence classification with these parameters:
- Batch Size: 16
- Learning Rate: 3e-4 (with 7ep warmup, 13ep decline)
- Epochs: 2
For testing, you should slow it down. Use a learning rate of 1.5e-6, which will give you smoother validation and training loss curves. Then tune your parameters like we did in the experiments section to get a low validation loss, and then crank the learning rate back up.
The model I trained on those params reached 81.7% accuracy on the test set, a 13.6% improvement from the bias detection model presented in the Dbias paper.
What's Next:
This is just the tip of the iceberg for the kind of work we've been doing in our open source research group, Ethical Spectacle Research. Our research paper on bias detection combines a named entity recognition model with the architecture described in this article, for more advanced encodings and more accurate predictions. I'll also drop a blog on that.
Paper/Hackathon: We'll be publishing our ArXiv paper in September, when we'll also be hosting a hackathon to see what devs can build with a new state-of-the-art model. We'll publish our conference paper in January (fingers crossed).
Workshops: We just got a cool new venue for workshops, thanks to PHX Ventures, so I'm excited for the upcoming workshops on bias detection, synthetic data, networks of agents, and more. Check our meetup for events.