Preference Optimization for Vision Language Models with TRL
Training models to understand and predict human preferences can be incredibly complex. Traditional methods, like supervised fine-tuning, often require assigning specific labels to data, which is not cost-efficient, especially for nuanced tasks. Preference optimization is an alternative approach that can simplify this process and yield more accurate results. By focusing on comparing and ranking candidate answers rather than assigning fixed labels, preference optimization allows models to capture the subtleties of human judgment more effectively.
Preference optimization is widely used for fine-tuning language models, but it can also be applied to vision language models (VLMs). We are excited to announce that the TRL library now supports direct preference optimization (DPO) for VLMs. This article will guide you through the process of training VLMs using TRL and DPO.
Preference dataset
Preference optimization requires data that captures user preferences. In the binary choice setting, each example consists of a prompt and two candidate answers: one that is chosen and one that is rejected. The model's goal is to learn to prefer the chosen answer over the rejected one. For example, you need samples like the following:
❔ Question: How many families?
- ❌ Rejected: The image does not provide any information about families.
- ✅ Chosen: The image shows a Union Organization table setup with 18,000 families.
Note that the chosen message is not necessarily correct. For example, the chosen response that says 18,000 families is still wrong, but it's less wrong compared to the rejected response.
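Stripped down, a preference sample for a VLM is just an image, a question, and the two candidate answers. For the example above it could be represented as follows (a minimal sketch, the image placeholder standing in for the actual picture):

sample = {
    "image": ...,  # the image the question refers to
    "question": "How many families?",
    "chosen": "The image shows a Union Organization table setup with 18,000 families.",
    "rejected": "The image does not provide any information about families.",
}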
For this blog post, we'll be using the openbmb/RLAIF-V-Dataset, which includes over 83,000 annotated rows. Let's take a closer look at the dataset:
>>> from datasets import load_dataset
>>> dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train[:1%]")
>>> sample = dataset[1]
>>> sample["image"].show()
>>> sample["question"]
'how many families?'
>>> sample["rejected"]
'The image does not provide any information about families.'
>>> sample["chosen"]
'The image shows a Union Organization table setup with 18,000 families.'
Our model requires both text and images as input, so the first step is to format the dataset to fit this requirement. The data should be structured to simulate a conversation between a user and an assistant. The user provides a prompt that includes an image and a question, while the assistant responds with an answer. Here's how this formatting is done:
from datasets import features
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)
def format(example):
    # Prepare the input for the chat template
    prompt = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": example["question"]}],
        },
    ]
    chosen = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["chosen"]}],
        },
    ]
    rejected = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["rejected"]}],
        },
    ]
    # Apply the chat template
    prompt = processor.apply_chat_template(prompt, tokenize=False)
    chosen = processor.apply_chat_template(chosen, tokenize=False)
    rejected = processor.apply_chat_template(rejected, tokenize=False)
    # Resize the image to ensure it fits within the maximum allowable
    # size of the processor to prevent OOM errors.
    max_size = processor.image_processor.size["longest_edge"]
    example["image"].thumbnail((max_size, max_size))
    return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}
# Apply the formatting function to the dataset,
# remove columns to end up with only "images", "prompt", "chosen", "rejected" columns
dataset = dataset.map(format, remove_columns=dataset.column_names)
# Make sure that the images are decoded, it prevents from storing bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True)) # to avoid bytes
dataset = dataset.cast(f)
Our dataset is now formatted. Let's have a look at the first example:
>>> dataset[1]
{'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=L size=980x812 at 0x154505570>],
'prompt': 'User:<image>how many families?<end_of_utterance>\n',
'rejected': 'Assistant: The image does not provide any information about families.<end_of_utterance>\n',
'chosen': 'Assistant: The image shows a Union Organization table setup with 18,000 families.<end_of_utterance>\n'}
Warm up your GPUs, the dataset is ready for training!
Training
For the sake of the example, we'll be training the Idefics2-8b model, but note that the DPO implementation in TRL supports other models like Llava 1.5 and PaliGemma; see the section Finetuning Llava 1.5, PaliGemma and others below. Before diving into the training process, we'll first make sure everything fits into memory.
How much memory do I need?
I have a GPU with 80GB of VRAM. Is it enough to train my Idefics2-8b model? Here are the calculation steps to get a rough estimate of the memory needed.
Let $N$ be the number of parameters and $P$ the precision (in bytes per parameter). The following components will have to fit together in memory:
- Model to train: $N \times P$
- Reference model: the reference model is the same as the model to train, so it also requires $N \times P$
- Gradients: we train the whole model, and each parameter requires a gradient, so it requires $N \times P$
- Optimizer states: we use AdamW, which requires two states per parameter, so it requires $2 \times N \times P$
Idefics2-8b has 8 billion parameters, and we use float32 precision, which requires 4 bytes per float. So the total memory required is:
| Component | Calculation | Memory |
|---|---|---|
| Model to train | $8 \times 10^9 \times 4$ bytes | 32 GB |
| Reference model | $8 \times 10^9 \times 4$ bytes | 32 GB |
| Gradients | $8 \times 10^9 \times 4$ bytes | 32 GB |
| Optimizer states | $2 \times 8 \times 10^9 \times 4$ bytes | 64 GB |
| Total | | 160 GB |
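As a quick back-of-the-envelope check, here is the same arithmetic in Python:

# Rough full fine-tuning memory estimate for Idefics2-8b in float32
n_params = 8e9        # number of parameters
bytes_per_param = 4   # float32
model = ref_model = gradients = n_params * bytes_per_param
optimizer_states = 2 * n_params * bytes_per_param  # AdamW keeps two states per parameter
total_gb = (model + ref_model + gradients + optimizer_states) / 1e9
print(f"{total_gb:.0f} GB")  # 160 GB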
This is way above my GPU's memory capacity. Fortunately, by applying techniques such as quantization and LoRA, we can significantly reduce the memory requirements and make the training feasible. Let's see how to do this.
Quantization
Quantization is a technique that reduces the precision of the model's weights and activations. Switching from float32 to bfloat16 precision halves the storage requirement per parameter from 4 bytes to 2 bytes. This optimization conserves memory while also accelerating computations, ensuring high performance with minimal compromise.
To implement bfloat16 precision in the model:
import torch
from transformers import AutoModelForVision2Seq
model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16)
bfloat16 precision can also be applied to the optimizer by setting bf16=True in the training arguments:
from transformers import TrainingArguments
training_args = TrainingArguments(..., bf16=True)
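To double-check that the weights were actually loaded in half precision, a quick sanity check is to inspect the dtype of one of the model's parameters:

>>> next(model.parameters()).dtype
torch.bfloat16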
LoRA
LoRA is a method that reduces the number of trainable parameters by learning pairs of rank-decomposition matrices while keeping the original weights frozen. This significantly decreases the storage needs for large models adapted to specific tasks. LoRA is integrated into PEFT, and you can set it up in no time:
from transformers import AutoModelForVision2Seq
+ from peft import get_peft_model, LoraConfig
model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b")
+ peft_config = LoraConfig(target_modules="all-linear")
+ model = get_peft_model(model, peft_config)
PEFT acts as a wrapper (called an adapter) around the model. It is this adapter that gets trained while the inner model is kept frozen. How much does LoRA reduce the number of trainable parameters?
>>> model.print_trainable_parameters()
trainable params: 55,348,736 || all params: 8,458,116,848 || trainable%: 0.6543860411799315
It reduces the number of trainable parameters from 8 billion to 55 million, a huge reduction that will significantly lower the memory requirements.
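The config above relies on PEFT's defaults for everything except target_modules. If you want to control the adapter size yourself, a LoraConfig could look like the following; the rank, scaling, and dropout values here are illustrative assumptions, not the settings used in this post:

from peft import LoraConfig

peft_config = LoraConfig(
    r=16,               # rank of the low-rank update matrices (illustrative value)
    lora_alpha=16,      # scaling factor applied to the update (illustrative value)
    lora_dropout=0.05,  # dropout on the adapter input (illustrative value)
    target_modules="all-linear",  # same as above: adapt every linear layer
)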
The new memory requirements after quantization and LoRA
Now that we have reduced the memory requirements, let's recalculate the memory needed:
| Component | Calculation | Memory |
|---|---|---|
| Model to train | $8 \times 10^9 \times 2$ bytes | 16 GB |
| Reference model | $8 \times 10^9 \times 2$ bytes | 16 GB |
| Gradients | $55 \times 10^6 \times 2$ bytes | 0.1 GB |
| Optimizer states | $2 \times 55 \times 10^6 \times 2$ bytes | 0.2 GB |
| Total | | 32.3 GB |
This time, we need around 32GB of memory to finetune our Idefics2-8b model, which is much more reasonable and fits within my GPU!
For additional information on optimizing memory usage with LoRA and QLoRA, refer to the PEFT documentation or to Google's LoRA and QLoRA recommendations for LLMs.
What about the batch size?
Our memory calculation isn't exact as it doesn't account for activations. Activations are the intermediate outputs of the network layers and their memory requirements depend on the model structure and batch size. Precisely calculating the memory needed for activations is challenging, so we'll rely on empirical observations.
To choose an appropriate training batch size (per_device_train_batch_size), start with your desired batch size (e.g., 64). This will likely result in an out-of-memory (OOM) error. If it does, reduce the batch size by half and double the gradient accumulation steps (gradient_accumulation_steps) to maintain the same effective batch size. Repeat this process until everything fits in your GPU's memory. In our case, we end up with a batch size of 2 and gradient accumulation steps of 32.
An additional optimization is to use gradient checkpointing (gradient_checkpointing) to reduce the memory needed for activations. This technique trades compute for memory by recomputing parts of the network during the backward pass. It can be enabled by setting gradient_checkpointing=True in the training arguments.
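Putting these choices together, the memory-related part of the training arguments looks like this (a sketch; the full script below uses DPOConfig, which inherits from TrainingArguments):

from trl import DPOConfig

training_args = DPOConfig(
    output_dir="idefics2-8b-dpo",
    bf16=True,                       # bfloat16 training
    per_device_train_batch_size=2,   # largest batch size that fits in memory
    gradient_accumulation_steps=32,  # 2 x 32 = effective batch size of 64
    gradient_checkpointing=True,     # recompute activations during the backward pass
)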
Summary: complete training script
Now that we've set up the model, dataset, and training parameters, we're ready to train. Here's how to put everything together in a script, including some additional elements to speed up processing, like dataset_num_proc and dataloader_num_workers:
# dpo_idefics2-8b.py
from datasets import features, load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
from trl import DPOConfig, DPOTrainer
from peft import LoraConfig
def main():
    # Load the model and processor
    model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16)
    processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)

    # Load the dataset
    dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")

    def format(example):
        # Prepare the input for the chat template
        prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
        chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
        rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
        # Apply the chat template
        prompt = processor.apply_chat_template(prompt, tokenize=False)
        chosen = processor.apply_chat_template(chosen, tokenize=False)
        rejected = processor.apply_chat_template(rejected, tokenize=False)
        # Resize the image to ensure it fits within the maximum allowable
        # size of the processor to prevent OOM errors.
        max_size = processor.image_processor.size["longest_edge"] // 2
        example["image"].thumbnail((max_size, max_size))
        return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}

    # Apply the formatting function to the dataset
    dataset = dataset.map(format, remove_columns=dataset.column_names, num_proc=32)

    # Make sure that the images are decoded, it prevents from storing bytes.
    # More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
    f = dataset.features
    f["images"] = features.Sequence(features.Image(decode=True))
    dataset = dataset.cast(f)

    # Train the model
    training_args = DPOConfig(
        output_dir="idefics2-8b-dpo",
        bf16=True,
        gradient_checkpointing=True,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=32,
        num_train_epochs=1,
        dataset_num_proc=32,  # tokenization will use 32 processes
        dataloader_num_workers=32,  # data loading will use 32 workers
        logging_steps=10,
    )
    trainer = DPOTrainer(
        model,
        ref_model=None,  # not needed when using peft
        args=training_args,
        train_dataset=dataset,
        tokenizer=processor,
        peft_config=LoraConfig(target_modules="all-linear"),
    )
    trainer.train()


if __name__ == "__main__":
    main()
Let's run and wait... 🚀
accelerate launch dpo_idefics2-8b.py
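accelerate launch will follow whatever you set up with accelerate config; if you want to be explicit about the number of GPUs, you can also pass it on the command line (illustrative, assuming 8 GPUs):

accelerate launch --num_processes 8 dpo_idefics2-8b.py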
Results
A few hours later, the training is complete. Let's take a look at the training curves:
In DPO, we focus on several metrics to assess the quality of the training:
- Accuracy: This metric indicates the percentage of training samples where the model is more likely to output the chosen answer rather than the rejected answer. We can see an increase in accuracy, which is a positive sign.
- Rewards: Rewards are related to the probability of an answer being chosen. For more details, refer to the DPO paper, Section 5. We expect the reward for the chosen answer to be higher than for the rejected answer. To verify this, we look at the reward margin, which is the difference between the rewards for the chosen and rejected answers (see the formulas just below). An increasing reward margin, as observed here, is also a good sign.
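Concretely, following the DPO paper, the implicit reward of a completion $y$ for a prompt $x$ is

$$ r_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} $$

where $\beta$ is the DPO temperature hyperparameter, $\pi_\theta$ is the model being trained and $\pi_{\mathrm{ref}}$ is the frozen reference model. The reward margin is $r_\theta(x, y_{\text{chosen}}) - r_\theta(x, y_{\text{rejected}})$, and the accuracy is the fraction of samples for which this margin is positive.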
Evaluation
Inference
With the model training complete, the next step is to evaluate its performance on some examples. This will give us a sense of how well the model has learned and how effectively it can make predictions. Here’s a script to help you evaluate the model and analyze its performance on a set of test examples:
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b").to("cuda")
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)
model.load_adapter("HuggingFaceH4/idefics2-8b-dpo-rlaif-v-v0.3") # <-- Load the adapter we've just trained
# Process
user_message = ...
image_path = ...
data = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": user_message}]}]
prompts = processor.apply_chat_template(data, add_generation_prompt=True) # add_generation_prompt=True to end the prompt with "Assistant:"
images = [Image.open(image_path)]
inputs = processor(prompts, images, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Generate
generated_ids = model.generate(**inputs, max_new_tokens=500)
response_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response_text)
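To try it on a concrete case, you could for instance reuse a question from the dataset and point image_path to a local image before the # Process block; both values below are hypothetical placeholders:

user_message = "How many families?"          # hypothetical example question
image_path = "/path/to/your/image.jpg"       # hypothetical placeholder, replace with your own image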
As mentioned above, the openbmb/RLAIF-V-Dataset is designed to reduce hallucinations. But has the fine-tuning actually reduced hallucinations? To find out, we can use the AMBER benchmark, a dataset specifically created to evaluate hallucinations in VLMs. We will report the results for Idefics2 and Idefics2+DPO on the discriminative task and compare them with other models for reference.
| | Accuracy | F1 |
|---|---|---|
| GPT-4o | 88.8 | 91.6 |
| Idefics2+DPO | 85.9 | 89.4 |
| Idefics2 | 85.8 | 89.1 |
| GPT-4v | 83.4 | 87.4 |
| MiniGemini | 82.6 | 87.6 |
| LLaVA-NeXT | 81.4 | 85.4 |
| QWEN-VL | 81.9 | 86.4 |
| LURE | 73.5 | 77.7 |
| OPERA | 75.2 | 78.3 |
| Less-is-more | 72.4 | 75.8 |
| VCD | 71.8 | 74.9 |
Overall, the fine-tuned model seems to hallucinate a bit less. The training seems to have been successful!
Here are some cherry-picked examples to illustrate the model's performance:
| Image | Question | Idefics2 | Idefics2+DPO |
|---|---|---|---|
| | Are there two ships in this image? | Yes | No |
| | Is the ground uneven in this image? | No | Yes |
| | Is there one shovel in this image? | Yes | No |
Try it yourself and see how the model performs on your own examples!
Finetuning Llava 1.5, PaliGemma and others
At the time of writing, the DPO implementation in TRL supports Idefics2, Llava 1.5, and PaliGemma, with ongoing efforts to add support for more models. The easiest way to fine-tune these models is to use the example script provided in the TRL repository. For example, to finetune PaliGemma, you can use the following command:
accelerate launch examples/scripts/dpo_visual.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path google/paligemma-3b-pt-224 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 32 \
--dataset_num_proc 32 \
--output_dir dpo_paligemma_rlaif-v \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--lora_target_modules=all-linear
You can find a detailed focus on PaliGemma finetuning in the smol-vision project.
🚀🚀 Now you have everything you need to start fine-tuning your own VLMs with DPO. Share your findings, models, and datasets with the community!