
XPO Trainer

Overview

Exploratory Preference Optimization (XPO) was proposed in the paper Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, Corby Rosset, Ahmed Awadallah, and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus, allowing the method to explore outside the support of the initial model and the human feedback data.

The abstract from the paper is the following:

Reinforcement learning from human feedback (RLHF) has emerged as a central tool for language model alignment. We consider online exploration in RLHF, which exploits interactive access to human or AI feedback by deliberately encouraging the model to produce diverse, maximally informative responses. By allowing RLHF to confidently stray from the pre-trained model, online exploration offers the possibility of novel, potentially super-human capabilities, but its full potential as a paradigm for language model training has yet to be realized, owing to computational and statistical bottlenecks in directly adapting existing reinforcement learning techniques. We propose a new algorithm for online exploration in RLHF, Exploratory Preference Optimization (XPO), which is simple and practical — a one-line change to (online) Direct Preference Optimization (DPO; Rafailov et al., 2023) — yet enjoys the strongest known provable guarantees and promising empirical performance. XPO augments the DPO objective with a novel and principled exploration bonus, empowering the algorithm to explore outside the support of the initial model and human feedback data. In theory, we show that XPO is provably sample-efficient and converges to a near-optimal language model policy under natural exploration conditions, irrespective of whether the initial model has good coverage. Our analysis, which builds on the observation that DPO implicitly performs a form of Q*-approximation (or, Bellman error minimization), combines previously disparate techniques from language modeling and theoretical reinforcement learning in a serendipitous fashion through the perspective of KL-regularized Markov decision processes. Empirically, we find that XPO is more sample-efficient than non-exploratory DPO variants in a preliminary evaluation.

This post-training method was contributed by Kashif Rasul, Quentin Gallouédec and Lewis Tunstall.

Quick start

This example demonstrates how to train a model using the XPO method. We use the Qwen 0.5B model as the base model and PairRMJudge as a judge. We use the prompts from the UltraFeedback dataset, available on the Hub as trl-lib/ultrafeedback-prompt.

Below is the script to train the model:

# train_xpo.py
from datasets import load_dataset
from trl import PairRMJudge, XPOConfig, XPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", logging_steps=10)
trainer = XPOTrainer(
    model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()

Execute the script using the following command:

accelerate launch train_xpo.py

Distributed across 8 GPUs, the training takes approximately 1 hour.
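
If you are not using a saved accelerate configuration, the number of processes can be set directly on the command line; for example, assuming 8 GPUs are available:

accelerate launch --num_processes 8 train_xpo.py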

To see how the trained model performs, you can use the TRL Chat CLI.

$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
<quentin_gallouedec>:
What is the best programming language?

<trl-lib/Qwen2-0.5B-XPO>:
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript. 

Expected dataset type

XPO requires a prompt-only dataset. The XPOTrainer supports both conversational and standard dataset formats, as shown below. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
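
For reference, a prompt-only example looks like this in each format (the prompt text below is illustrative):

# Standard format: the prompt is a plain string
standard_example = {"prompt": "What is the best programming language?"}

# Conversational format: the prompt is a list of chat messages;
# the trainer applies the tokenizer's chat template automatically
conversational_example = {"prompt": [{"role": "user", "content": "What is the best programming language?"}]}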

Usage tips

Use a reward model

Instead of a judge, you can choose to use a reward model — see Reward Bench for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the trl-lib/Qwen2-0.5B-Reward model:

- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)

  trainer = XPOTrainer(
      ...
-     judge=judge,
+     reward_model=reward_model,
  )

Make sure that the SFT model and the reward model use the same chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training; a quick way to check this is sketched below.
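
As a quick sanity check, you can compare the chat templates of the two tokenizers before training. This is a minimal sketch, assuming both repositories ship a tokenizer that defines a chat_template:

from transformers import AutoTokenizer

sft_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")

# The templates should be identical; otherwise completions may be scored incorrectly
assert sft_tokenizer.chat_template == reward_tokenizer.chat_template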

Encourage EOS token generation

When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the max_new_tokens argument of XPOConfig. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the missing_eos_penalty argument of XPOConfig:

training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)

Logging Completions

To better understand your model’s behavior during training, you can log sample completions periodically using the LogCompletionsCallback.

from trl import LogCompletionsCallback

trainer = XPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)

This callback logs the model’s generated completions directly to Weights & Biases.

Example of completions logged to Weights & Biases during training.
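
Completions are logged through the trainer’s Weights & Biases integration, so make sure reporting to wandb is enabled in the config; a minimal sketch, assuming wandb is installed and you are logged in:

training_args = XPOConfig(..., report_to="wandb")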

Example script

We provide an example script to train a model using the XPO method. The script is available in examples/scripts/xpo.py.

To test the XPO script with the Qwen2.5 0.5B model on the UltraFeedback dataset, run the following command:

python examples/scripts/xpo.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --judge pair_rm \
    --dataset_name trl-lib/ultrafeedback-prompt \
    --learning_rate 5.0e-7 \
    --logging_steps 25 \
    --output_dir Qwen2.5-0.5B-XPO-PairRM \
    --warmup_ratio 0.1 \
    --push_to_hub

Logged metrics

The logged metrics are as follows:

  • loss/xpo: The mean XPO part of the full loss.
  • loss/dpo: The mean DPO part of the full loss.
  • objective/kl: The mean KL divergence between the model and the reference model.
  • objective/entropy: The mean entropy of the model and reference completions.
  • objective/model_scores: The mean scores (according to the reward model) of the model completions.
  • objective/ref_scores: The mean scores (according to the reward model) of the reference completions.
  • objective/scores_margin: The mean score margin (according to the external reward model) between the chosen and rejected completions.
  • rewards/chosen: The mean reward (according to XPO’s DPO implicit reward model) of the chosen completions.
  • rewards/rejected: The mean reward (according to XPO’s DPO implicit reward model) of the rejected completions.
  • rewards/accuracies: The accuracies of the XPO’s implicit reward model.
  • rewards/margins: The mean reward margin (according to XPO’s DPO implicit reward model) between the chosen and rejected completions.
  • logps/chosen: The mean log probabilities of the chosen completions.
  • logps/rejected: The mean log probabilities of the rejected completions.
  • val/model_contain_eos_token: The number of times the model’s output contains the EOS token.
  • val/ref_contain_eos_token: The number of times the reference’s output contains the EOS token.
  • alpha: The weight of the XPO loss term. Typically fixed, but it can be made dynamic by passing a list of values to XPOConfig (see the sketch after this list).
  • beta: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but it can likewise be made dynamic by passing a list of values to XPOConfig.
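
Both schedules are set directly on the config. Below is a minimal sketch with illustrative values; per the parameter documentation, a single float keeps the value fixed, while a list provides one value per epoch, with the last value reused for any remaining epochs:

# Fixed alpha and beta for the whole run
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", alpha=1e-5, beta=0.1)

# One value per epoch; the last value is reused for the remaining epochs
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", alpha=[1e-5, 5e-6, 1e-6], beta=[0.1, 0.05, 0.05])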

XPOTrainer

class trl.XPOTrainer


( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None, ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None, reward_model: typing.Optional[torch.nn.modules.module.Module] = None, judge: typing.Optional[trl.trainer.judges.BasePairwiseJudge] = None, args: typing.Optional[trl.trainer.xpo_config.XPOConfig] = None, data_collator: typing.Optional[typing.Callable] = None, train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None, eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None, processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None, peft_config: typing.Optional[dict] = None, compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None, callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None, optimizers: tuple = (None, None), preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

Parameters

  • model (transformers.PreTrainedModel) — The model to train, preferably an AutoModelForCausalLM.
  • ref_model (PreTrainedModelWrapper) — Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
  • reward_model (transformers.PreTrainedModel) — The reward model to score completions with, preferably an AutoModelForSequenceClassification.
  • judge (BasePairwiseJudge) — The judge to use for pairwise comparison of model completions.
  • args (XPOConfig) — The XPO config arguments to use for training.
  • data_collator (transformers.DataCollator) — The data collator to use for training. If None is specified, the default data collator (DPODataCollatorWithPadding) will be used which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
  • train_dataset (datasets.Dataset) — The dataset to use for training.
  • eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
  • processing_class (PreTrainedTokenizerBase or BaseImageProcessor or FeatureExtractionMixin or ProcessorMixin, optional) — Processing class used to process the data. If provided, will be used to automatically process the inputs for the model, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model.
  • peft_config (dict) — The peft config to use for training.
  • compute_metrics (Callable[[EvalPrediction], dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
  • callbacks (list[transformers.TrainerCallback]) — The callbacks to use for training.
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.

Initialize XPOTrainer as a subclass of OnlineDPOTrainer.

XPOConfig

class trl.XPOConfig


( output_dir: str, overwrite_output_dir: bool = False, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no', prediction_loss_only: bool = False, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 8, per_gpu_train_batch_size: typing.Optional[int] = None, per_gpu_eval_batch_size: typing.Optional[int] = None, gradient_accumulation_steps: int = 1, eval_accumulation_steps: typing.Optional[int] = None, eval_delay: typing.Optional[float] = 0, torch_empty_cache_steps: typing.Optional[int] = None, learning_rate: float = 5e-07, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, num_train_epochs: float = 3.0, max_steps: int = -1, lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear', lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory>, warmup_ratio: float = 0.0, warmup_steps: int = 0, log_level: typing.Optional[str] = 'passive', log_level_replica: typing.Optional[str] = 'warning', log_on_each_node: bool = True, logging_dir: typing.Optional[str] = None, logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps', logging_first_step: bool = False, logging_steps: float = 500, logging_nan_inf_filter: bool = True, save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps', save_steps: float = 500, save_total_limit: typing.Optional[int] = None, save_safetensors: typing.Optional[bool] = True, save_on_each_node: bool = False, save_only_model: bool = False, restore_callback_states_from_checkpoint: bool = False, no_cuda: bool = False, use_cpu: bool = False, use_mps_device: bool = False, seed: int = 42, data_seed: typing.Optional[int] = None, jit_mode_eval: bool = False, use_ipex: bool = False, bf16: bool = False, fp16: bool = False, fp16_opt_level: str = 'O1', half_precision_backend: str = 'auto', bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: typing.Optional[bool] = None, local_rank: int = -1, ddp_backend: typing.Optional[str] = None, tpu_num_cores: typing.Optional[int] = None, tpu_metrics_debug: bool = False, debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '', dataloader_drop_last: bool = False, eval_steps: typing.Optional[float] = None, dataloader_num_workers: int = 0, dataloader_prefetch_factor: typing.Optional[int] = None, past_index: int = -1, run_name: typing.Optional[str] = None, disable_tqdm: typing.Optional[bool] = None, remove_unused_columns: typing.Optional[bool] = True, label_names: typing.Optional[typing.List[str]] = None, load_best_model_at_end: typing.Optional[bool] = False, metric_for_best_model: typing.Optional[str] = None, greater_is_better: typing.Optional[bool] = None, ignore_data_skip: bool = False, fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '', fsdp_min_num_params: int = 0, fsdp_config: typing.Union[dict, str, NoneType] = None, fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None, accelerator_config: typing.Union[dict, str, NoneType] = None, deepspeed: typing.Union[dict, str, NoneType] = None, label_smoothing_factor: float = 0.0, optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch', optim_args: typing.Optional[str] = None, adafactor: bool = False, group_by_length: bool = False, length_column_name: typing.Optional[str] = 'length', report_to: typing.Union[NoneType, str, typing.List[str]] = None, ddp_find_unused_parameters: typing.Optional[bool] = None, ddp_bucket_cap_mb: typing.Optional[int] = None, ddp_broadcast_buffers: typing.Optional[bool] = None, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, skip_memory_metrics: bool = True, use_legacy_prediction_loop: bool = False, push_to_hub: bool = False, resume_from_checkpoint: typing.Optional[str] = None, hub_model_id: typing.Optional[str] = None, hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save', hub_token: typing.Optional[str] = None, hub_private_repo: typing.Optional[bool] = None, hub_always_push: bool = False, gradient_checkpointing: bool = False, gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None, include_inputs_for_metrics: bool = False, include_for_metrics: typing.List[str] = <factory>, eval_do_concat_batches: bool = True, fp16_backend: str = 'auto', evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = None, push_to_hub_model_id: typing.Optional[str] = None, push_to_hub_organization: typing.Optional[str] = None, push_to_hub_token: typing.Optional[str] = None, mp_parameters: str = '', auto_find_batch_size: bool = False, full_determinism: bool = False, torchdynamo: typing.Optional[str] = None, ray_scope: typing.Optional[str] = 'last', ddp_timeout: typing.Optional[int] = 1800, torch_compile: bool = False, torch_compile_backend: typing.Optional[str] = None, torch_compile_mode: typing.Optional[str] = None, dispatch_batches: typing.Optional[bool] = None, split_batches: typing.Optional[bool] = None, include_tokens_per_second: typing.Optional[bool] = False, include_num_input_tokens_seen: typing.Optional[bool] = False, neftune_noise_alpha: typing.Optional[float] = None, optim_target_modules: typing.Union[NoneType, str, typing.List[str]] = None, batch_eval_metrics: bool = False, eval_on_start: bool = False, use_liger_kernel: typing.Optional[bool] = False, eval_use_gather_object: typing.Optional[bool] = False, average_tokens_across_devices: typing.Optional[bool] = False, reward_model_path: typing.Optional[str] = None, judge: typing.Optional[str] = None, max_new_tokens: int = 64, max_length: int = 512, temperature: float = 0.9, missing_eos_penalty: typing.Optional[float] = None, beta: list = <factory>, loss_type: str = 'sigmoid', dataset_num_proc: typing.Optional[int] = None, disable_dropout: bool = True, use_vllm: bool = False, ds3_gather_for_generation: bool = True, alpha: list = <factory> )

Parameters

  • alpha (float or list[float], optional, defaults to 1e-5) — Weight of the XPO loss term. If a list of floats is provided, then alpha is selected for each new epoch, and the last alpha is used for the remaining epochs.

Configuration class for the XPOTrainer.

XPOConfig is a subclass of OnlineDPOConfig: all of its arguments can be used, and it adds the alpha parameter described above.
