Llama-3.2-1B-DPO

Model Details

Training Details

devices: 4 * NPU 910B-64GB
precision: bf16 mixed-precision
global_batch_size: 64
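The global batch size of 64 follows from the device count together with the per-device batch size of 8 and the 2 gradient accumulation steps listed under Training Hyperparameters. Here is a quick check, assuming plain data parallelism with one process per NPU. Each example is one preference pair.

```python
def global_batch_size(num_devices: int, per_device_batch: int, grad_accum_steps: int) -> int:
    """Number of examples that contribute to one optimizer step under data parallelism."""
    return num_devices * per_device_batch * grad_accum_steps


# 4 NPUs x 8 pairs per device x 2 accumulation steps
print(global_batch_size(4, 8, 2))
```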

Training Hyperparameters

attn_implementation: None
beta: 0.1
bf16: True
learning_rate: 1e-6
lr_scheduler_type: cosine
per_device_train_batch_size: 8
gradient_accumulation_steps: 2
torch_dtype: bfloat16
num_train_epochs: 1
max_prompt_length: 512
max_length: 1024
warmup_ratio: 0.05
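Together, learning_rate 1e-6, lr_scheduler_type cosine and warmup_ratio 0.05 describe a linear warmup followed by cosine decay. The sketch below models that shape on the documented behaviour of `transformers.get_cosine_schedule_with_warmup` with its default half cycle, and it assumes warmup steps are rounded up from the ratio. The total step count is a hypothetical placeholder, because the dataset size is not given in this card.

```python
import math


def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    # Round up, so any nonzero ratio yields at least one warmup step
    return math.ceil(total_steps * warmup_ratio)


def cosine_lr(step: int, total_steps: int, peak_lr: float = 1e-6, warmup_ratio: float = 0.05) -> float:
    """Linear warmup from 0 to peak_lr, then half-cosine decay to 0 at total_steps."""
    warm = warmup_steps(total_steps, warmup_ratio)
    if step < warm:
        return peak_lr * step / max(1, warm)
    progress = (step - warm) / max(1, total_steps - warm)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Hypothetical run length; the real step count depends on the dataset size
TOTAL = 1000
for s in (0, 25, 50, 525, 1000):
    print(s, f"{cosine_lr(s, TOTAL):.3e}")
```

With one epoch, the step count equals the number of training pairs divided by the global batch size of 64.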

Results

init_train_loss: 0.6958
final_train_loss: 0.5375
accuracy: 0.7188
reward_margin: 0.7227
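These numbers can be read against the sigmoid DPO objective, which is `DPOTrainer`'s default loss_type. Each completion's implicit reward is beta times the gap between its policy and reference log-probabilities. The loss is -log sigmoid(chosen reward minus rejected reward). Accuracy is the fraction of pairs in which the chosen reward is higher, and the margin is the mean reward difference.

The following is a pure-Python sketch over per-pair summed log-probabilities. It assumes that the card's accuracy and reward_margin correspond to TRL's logged rewards/accuracies and rewards/margins. When the policy still equals the reference, every margin is 0 and the loss is ln 2 ≈ 0.693, which is close to the reported init_train_loss of 0.6958.

```python
import math
from dataclasses import dataclass


@dataclass
class PairLogps:
    """Summed log-probabilities of the chosen and rejected completions for one pair."""
    policy_chosen: float
    policy_rejected: float
    ref_chosen: float
    ref_rejected: float


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) for large |x|
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_metrics(batch: list[PairLogps], beta: float = 0.1) -> dict[str, float]:
    losses, margins, correct = [], [], 0
    for p in batch:
        # Implicit rewards: beta * (log pi_theta - log pi_ref)
        chosen_reward = beta * (p.policy_chosen - p.ref_chosen)
        rejected_reward = beta * (p.policy_rejected - p.ref_rejected)
        margin = chosen_reward - rejected_reward
        losses.append(-log_sigmoid(margin))
        margins.append(margin)
        correct += int(chosen_reward > rejected_reward)
    n = len(batch)
    return {
        "loss": sum(losses) / n,
        "accuracy": correct / n,
        "reward_margin": sum(margins) / n,
    }


# At initialization the policy equals the reference, so the margin is 0 and the loss is ln 2
print(dpo_metrics([PairLogps(-42.0, -40.0, -42.0, -40.0)]))
```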

Training script

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Map the dtype string (e.g. "bfloat16") to a torch.dtype; "auto" and None pass through unchanged.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    quantization_config = get_quantization_config(model_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=not training_args.gradient_checkpointing,  # the KV cache is incompatible with gradient checkpointing
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

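    # Without PEFT, load a separate frozen reference model. With PEFT, pass None:
    # DPOTrainer then uses the base model with adapters disabled as the reference.
    peft_config = get_peft_config(model_config)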
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    # Fall back to EOS for padding and to TRL's simple chat template if the tokenizer lacks them.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    # Optionally keep DDP from synchronizing boolean buffers (e.g. precomputed causal masks).
    if script_args.ignore_bias_buffers:
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    # Keep only the preference columns that DPOTrainer consumes.
    dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
    dataset = dataset.select_columns(["chosen", "prompt", "rejected"])

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
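
The script keeps only the `prompt`, `chosen` and `rejected` columns, which is the preference layout `DPOTrainer` consumes. These columns hold either plain strings or, for conversational datasets, lists of role/content messages that the tokenizer's chat template formats. The helper `row_format` below is hypothetical and not part of TRL. It is a sketch for checking a dataset row before training, under the assumption that all three columns use the same format.

```python
from typing import Any


def row_format(row: dict[str, Any]) -> str:
    """Classify a preference row as 'standard' (plain strings) or 'conversational'
    (lists of {"role", "content"} messages); raise if it fits neither."""
    keys = ("prompt", "chosen", "rejected")
    missing = [k for k in keys if k not in row]
    if missing:
        raise KeyError(f"missing columns: {missing}")
    values = [row[k] for k in keys]
    if all(isinstance(v, str) for v in values):
        return "standard"
    if all(
        isinstance(v, list)
        and all(isinstance(m, dict) and {"role", "content"}.issubset(m) for m in v)
        for v in values
    ):
        return "conversational"
    raise ValueError("prompt/chosen/rejected must be all strings or all message lists")


standard = {"prompt": "2+2=", "chosen": "4", "rejected": "5"}
chat = {
    "prompt": [{"role": "user", "content": "2+2="}],
    "chosen": [{"role": "assistant", "content": "4"}],
    "rejected": [{"role": "assistant", "content": "5"}],
}
print(row_format(standard), row_format(chat))
```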

Format: Safetensors
Model size: 1.24B params
Tensor type: BF16

Hub repository: AIR-hl/Llama-3.2-1B-DPO