Training Run: marin-8b-instruct-dpo

This document summarizes the configuration and hyperparameters used to train marin-8b-instruct-dpo with the EasyDeL library.

EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on JAX/Flax for TPU/GPU environments.

How to Load This Checkpoint

You can load the checkpoint generated from this training run using EasyDeL as follows:

import easydel as ed
from jax import numpy as jnp, lax

# Local directory containing this README.md, or the model's Hugging Face Hub repo ID
repo_id = "user/model-id" # <-- TODO: Update this path with the actual save directory or model repo

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        # use_scan_mlp=False, # Set to True to potentially reduce memory usage
        attn_dtype=jnp.bfloat16, # Training used bfloat16; jnp.float16 is also accepted
        # freq_max_position_embeddings=max_length, # Set if using RoPE and need truncation
        # mask_max_position_embeddings=max_length, # Set if max length is defined
        attn_mechanism=ed.AttentionMechanisms.SPLASH # Same mechanism used during training
    ),
    dtype=jnp.bfloat16, # Computation dtype (training used bfloat16; or jnp.float16)
    param_dtype=jnp.bfloat16, # Parameter dtype (training used bfloat16; or jnp.float16)
    precision=lax.Precision("fastest"), # Accepts "default"/"fastest", "high", or "highest"
    auto_shard_model=True, # Auto-shard across available devices
)

Note: Replace repo_id with the actual checkpoint directory or Hub repo ID. The returned model already holds the loaded parameters and is ready to use.

Training Configuration Summary

Model & Hardware

  • Model Name (Run Name): marin-8b-instruct-dpo
  • Base Model Architecture: llama
  • Platform: TPU
  • Number of Devices Used: 4 (total), 4 (local)
  • EasyDeL Version: v0.1.5

Key Training Parameters

  • Learning Rate (Start → End): 1e-05 → 5e-07
  • Optimizer: EasyDeLOptimizers.ADAMW
  • Scheduler: EasyDeLSchedulers.COSINE
  • Warmup Steps: 0
  • Weight Decay: 0.01
  • Loss Configuration: LossConfig(ignore_index=-100, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS, num_labels=None, problem_type=None, divide_weight_sum=False, shift_tokens=True, break_on_nan=True, reduction=None, num_classification_labels=None, classification_problem_type=None)
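
The learning-rate settings above describe a cosine decay from 1e-05 to 5e-07 with no warmup. The sketch below assumes the common half-cosine form (the same shape as Optax's cosine_decay_schedule with alpha = end / start); EasyDeL's exact implementation may differ. The helper cosine_lr and its total_steps argument are hypothetical, since this card does not record the number of optimizer steps.

```python
import math

LR_START = 1e-05  # "Learning Rate (Start → End)", start value
LR_END = 5e-07    # "Learning Rate (Start → End)", end value


def cosine_lr(step: int, total_steps: int,
              lr_start: float = LR_START, lr_end: float = LR_END) -> float:
    """Half-cosine decay from lr_start to lr_end with no warmup (Warmup Steps: 0).

    Assumed form: lr_end + (lr_start - lr_end) * 0.5 * (1 + cos(pi * t / T)).
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp so the rate stays at lr_end once t reaches T
    t = min(max(step, 0), total_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / total_steps))
    return lr_end + (lr_start - lr_end) * cosine
```

Because warmup is 0, the peak rate applies from the first step; under this form the rate at the midpoint of the run is (1e-05 + 5e-07) / 2 = 5.25e-06.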

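The LossConfig fields describe a token-level cross-entropy setup. Read conventionally, shift_tokens=True scores the logits at position i against the label at position i + 1, labels equal to ignore_index (-100) are skipped, and NO_WEIGHT_NUM_REAL_TARGET_TOKENS divides the summed loss by the number of real (non-ignored) targets. The pure-Python sketch below illustrates those semantics under that reading; it is not EasyDeL's implementation, the function names are hypothetical, and it does not cover the DPO preference objective.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # LossConfig.ignore_index


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def shifted_cross_entropy(logits: Sequence[Sequence[float]],
                          labels: Sequence[int],
                          ignore_index: int = IGNORE_INDEX) -> float:
    """Mean next-token NLL over real (non-ignored) targets.

    logits: one vocabulary row per position, shape [seq_len][vocab]
    labels: token ids, shape [seq_len]; ignore_index marks positions to skip
    label_smoothing=0.0 and z_loss=0.0 contribute nothing, so they are omitted.
    """
    # shift_tokens=True: the row at position i predicts the label at i + 1
    pred_rows = logits[:-1]
    targets = labels[1:]
    total, count = 0.0, 0
    for row, target in zip(pred_rows, targets):
        if target == ignore_index:
            continue
        total -= log_softmax(row)[target]
        count += 1
    # Divide by the number of real target tokens (no per-token weights)
    return total / count if count else 0.0
```

With uniform logits over a vocabulary of size V, every real target contributes log(V), so the normalized loss is log(V) no matter how many positions are masked with -100.
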
Data & Batching

  • Number of Training Epochs: 1
  • Total Batch Size (per step): 4
  • Maximum Sequence Length: 8192
  • Gradient Accumulation Steps: 1
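
The batching settings combine as simple arithmetic. The sketch below assumes the total batch of 4 is sharded evenly over the 4 devices and that every sequence is padded to the 8192-token maximum. batching_summary and num_examples are hypothetical: the card does not record the dataset size, the trainer may drop a final partial batch, and DPO trainers may process chosen and rejected responses as separate sequences, which this count ignores.

```python
import math

TOTAL_BATCH_SIZE = 4  # "Total Batch Size (per step)"
MAX_SEQ_LEN = 8192    # "Maximum Sequence Length"
NUM_DEVICES = 4       # "Number of Devices Used"
NUM_EPOCHS = 1        # "Number of Training Epochs"


def batching_summary(num_examples: int) -> dict:
    """Derived per-step quantities; num_examples is a hypothetical dataset size."""
    # With gradient accumulation at 1, each optimizer step is one forward/backward
    # pass; assumes the batch axis is split evenly across devices
    per_device = TOTAL_BATCH_SIZE // NUM_DEVICES
    # Padded token positions per optimizer step
    tokens_per_step = TOTAL_BATCH_SIZE * MAX_SEQ_LEN
    # Max Training Steps is not set, so epochs and dataset size decide the length;
    # assumes a final partial batch is kept
    steps = NUM_EPOCHS * math.ceil(num_examples / TOTAL_BATCH_SIZE)
    return {
        "per_device_batch": per_device,
        "tokens_per_step": tokens_per_step,
        "optimizer_steps": steps,
    }
```

For example, a hypothetical 1,000-example dataset gives 1 example per device, 32,768 padded token positions per step, and 250 optimizer steps.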

Datatypes & Precision

  • Computation dtype: bfloat16 (jax.numpy.bfloat16)
  • Parameter dtype (param_dtype): bfloat16 (jax.numpy.bfloat16)
  • Gradient Checkpointing Method: EasyDeLGradientCheckPointers.NOTHING_SAVEABLE
  • Attention Mechanism Used in Training: splash (select it at load time with attn_mechanism=ed.AttentionMechanisms.SPLASH)
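
Training ran in bfloat16, which keeps float32's 8-bit exponent, while float16 tops out at 65504. The stdlib-only sketch below makes the range difference concrete. It approximates bfloat16 by keeping the upper 16 bits of a float32 encoding (plain truncation, whereas JAX and most hardware round to nearest even) and uses the struct module's half-precision "e" format to check whether a value fits in float16. Both helper names are illustrative.

```python
import struct

FLOAT16_MAX = 65504.0  # largest finite IEEE half-precision value


def to_bfloat16_truncated(x: float) -> float:
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding.

    bfloat16 keeps float32's sign bit and 8-bit exponent but only the top
    7 mantissa bits. Truncation is used here for simplicity.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y


def fits_in_float16(x: float) -> bool:
    """True if x can be packed as IEEE half precision without overflow."""
    try:
        struct.pack("<e", x)
    except (OverflowError, struct.error):
        return False
    return True
```

For example, 1e5 overflows float16 but survives truncation to bfloat16 as 99840.0, a relative error of about 0.16%.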

Run Control

  • Max Training Steps: Not Set
  • Max Evaluation Steps: Not Set
  • Training Time Limit: Not Set

Citation

If you use EasyDeL in your research or work, please cite it:

@misc{Zare_Chavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}

This document was automatically generated by EasyDeL v0.1.5 during the training run.

Checkpoint Files

  • Format: Safetensors
  • Model Size: 8.03B params
  • Tensor Type: BF16