# open-human-feedback-chat / kto_pipeline.py
import torch
from dataclasses import dataclass
from accelerate import PartialState
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
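# maybe_unpair_preference_dataset converts paired (prompt, chosen, rejected) data into KTO's unpaired (prompt, completion, label) format;
# setup_chat_format adds a ChatML chat template (and its special tokens) to models that ship without one.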
from dataloaders.data_loader import get_oasst
import wandb
####################################
# CONFIGURATION
####################################
@dataclass
class ScriptArguments:
"""
Configuration for the script.
"""
dataset_name: str = "OpenAssistant/oasst1" # Dataset name or path
output_dir: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs" # Output directory
pretrained_model_name: str = "mistralai/Mistral-7B-v0.1" # Pretrained model name or path
checkpoint_path: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs" # Checkpoint path
push_to_hub: bool = False # Whether to push the model to the Hugging Face hub
@dataclass
class TrainingArguments(KTOConfig):
"""
Configuration for the KTO trainer.
"""
output_dir: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs"
num_train_epochs: int = 2 # 1 epoch was tried first; 2 epochs is the current setting
per_device_train_batch_size: int = 4 # 4 is the highest batch size that runs reliably
learning_rate: float = 5e-7
lr_scheduler_type: str = "cosine"
gradient_accumulation_steps: int = 1
logging_steps: int = 10
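# Note: eval_steps only takes effect when eval_strategy='steps' (evaluation_strategy in older transformers releases) is also set; the default is no evaluation during training.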
eval_steps: int = 500
warmup_ratio: float = 0.1
bf16: bool = True
logging_first_step: bool = True
@dataclass
class ModelArguments(ModelConfig):
"""
Configuration for the model.
"""
model_name_or_path: str = "mistralai/Mistral-7B-v0.1"
use_peft: bool = True
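# 'all-linear' applies LoRA to every linear layer; lora_r is the adapter rank and lora_alpha the scaling factor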
lora_target_modules: str = "all-linear"
lora_r: int = 16
lora_alpha: int = 16
# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments(output_dir=script_args.output_dir)
model_args = ModelArguments(model_name_or_path=script_args.pretrained_model_name)
####################################
# HELPER FUNCTIONS
####################################
def load_model_and_tokenizer(model_args):
"""
Load a model and tokenizer from a specified path.
"""
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code
)
# Set pad token if missing
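# (e.g. Mistral-7B ships without a pad token; reusing EOS for padding is a common workaround)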
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Setup chat format if not present
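# (setup_chat_format installs a ChatML template, adds its special tokens, and resizes the token embeddings to match)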
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)
return model, tokenizer
def load_and_format_oasst_dataset(tokenizer):
"""
Load, process, and format the OpenAssistant dataset into DPO-compatible format.
Args:
split (str): The dataset split to load ('train' or 'test').
tokenizer (AutoTokenizer): Tokenizer to apply chat templates.
num_proc (int, optional): Number of processes for parallel processing.
Returns:
Dataset: Processed and formatted dataset.
"""
# Load oasst dataset
train_dataset = get_oasst(split='train')
# Initialize lists for DPO dataset
dpo_train_data = {
"prompt": [],
"chosen": [],
"rejected": []
}
# Process the dataset
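# Each (i, j) in key.pairs indexes a preferred (i) and a dispreferred (j) completion in key.generations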
for prompt, key in train_dataset.data.items(): # Iterate over dataset
if hasattr(key, "pairs") and key.pairs: # Check if pairs exist
for i, j in key.pairs: # Process each preference pair
# Add prompt and corresponding chosen/rejected completions
dpo_train_data["prompt"].append(key.prompt)
dpo_train_data["chosen"].append(key.generations[i]) # Chosen generation
dpo_train_data["rejected"].append(key.generations[j]) # Rejected generation
# Convert DPO data into a Dataset
dpo_train_dataset = Dataset.from_dict(dpo_train_data)
# Wrap it in a DatasetDict
dataset_dict = DatasetDict({
"train": dpo_train_dataset
})
test_dataset = get_oasst(split='test')
dpo_test_data = {
"prompt": [],
"chosen": [],
"rejected": []
}
for prompt, key in test_dataset.data.items(): # Iterate over dataset
if hasattr(key, "pairs") and key.pairs: # Check if pairs exist
for i, j in key.pairs: # Process each preference pair
# Add prompt and corresponding chosen/rejected completions
dpo_test_data["prompt"].append(key.prompt)
dpo_test_data["chosen"].append(key.generations[i]) # Chosen generation
dpo_test_data["rejected"].append(key.generations[j]) # Rejected generation
dpo_test_dataset = Dataset.from_dict(dpo_test_data)
dataset_dict["test"] = dpo_test_dataset
# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
dataset_dict = maybe_unpair_preference_dataset(dataset_dict, num_proc=training_args.dataset_num_proc)
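# After unpairing, each row has 'prompt', 'completion', and a boolean 'label' (True marks a desirable completion)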
print("Dataset loaded and converted to KTO format.")
# Apply chat template
def format_dataset(example):
# Ensure prompt is in the correct structure
if isinstance(example["prompt"], str):
example["prompt"] = [{"role": "user", "content": example["prompt"]}]
elif isinstance(example["prompt"], list):
# If it's already a list, ensure each element has the "role" and "content" keys
for item in example["prompt"]:
if "role" not in item or "content" not in item:
raise ValueError(f"Each item in 'prompt' must have 'role' and 'content': {item}")
# Ensure completion is in the correct structure
if isinstance(example["completion"], str):
example["completion"] = [{"role": "assistant", "content": example["completion"]}] # Wrap as a list of dictionaries
elif isinstance(example["completion"], list):
# If it's already a list, ensure each element has the "role" and "content" keys
for item in example["completion"]:
if "role" not in item or "content" not in item:
raise ValueError(f"Each item in 'completion' must have 'role' and 'content': {item}")
# Now apply the chat template
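# (tokenize=False returns the chat-formatted string; KTOTrainer tokenizes the text itself later)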
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
return example
# Run the map on the local main process first so the other processes can reuse its cache.
# See: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = dataset_dict.map(format_dataset, num_proc=training_args.dataset_num_proc)
return dataset
####################################
# MAIN LOGIC
####################################
def main():
# Initialize wandb
wandb.init(project="kto")
# Load models and tokenizer
print("Loading models and tokenizer...")
model, tokenizer = load_model_and_tokenizer(model_args)
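# KTO uses a frozen reference model for its reward/KL terms; here it is a second copy of the base model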
ref_model, _ = load_model_and_tokenizer(model_args)
print("Models and tokenizer loaded.")
# Load and process datasets
print("Loading, processing, and formatting dataset...")
dataset = load_and_format_oasst_dataset(
tokenizer=tokenizer,
)
# Initialize trainer
print("Initializing trainer...")
trainer = KTOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
)
# Training
print("Starting training...")
trainer.train()
print("Training completed.")
# Evaluation
print("Evaluating model...")
metrics = trainer.evaluate()
print(f"Metrics: {metrics}")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Log evaluation metrics to wandb (trainer.evaluate() prefixes its keys with "eval_", e.g. "eval_loss")
wandb.log(metrics)
# Save model and optionally push to hub
trainer.save_model(training_args.output_dir)
if script_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
print("Process completed.")
# Finish wandb run
wandb.finish()
if __name__ == "__main__":
main()