---
tags:
- EasyDeL
- llama
- CausalLM
- splash
- safetensors
- Flax
- JAX
- TPU
---
# Training Run: marin-8b-instruct-dpo

This document outlines the configuration and parameters used for training the model `marin-8b-instruct-dpo` using the EasyDeL library. EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on JAX/Flax for TPU/GPU environments.
## How to Load This Checkpoint

You can load the checkpoint generated from this training run using EasyDeL as follows:
```python
import easydel as ed
from jax import numpy as jnp, lax

# Path to the local checkpoint directory or the Hugging Face Hub repo id
repo_id = "user/model-id"  # <-- TODO: Update this with the actual save directory or model repo

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        # use_scan_mlp=False,  # Set to True to potentially reduce memory usage
        attn_dtype=jnp.bfloat16,  # Matches the training dtype; jnp.float16 also works
        # freq_max_position_embeddings=max_length,  # Set if using RoPE and truncation is needed
        # mask_max_position_embeddings=max_length,  # Set if a maximum length is defined
        attn_mechanism=ed.AttentionMechanisms.SPLASH,  # Matches the mechanism used by this model
    ),
    dtype=jnp.bfloat16,        # Computation data type (bfloat16 was used in training)
    param_dtype=jnp.bfloat16,  # Parameter data type (bfloat16 was used in training)
    precision=lax.Precision("fastest"),  # One of "default"/"fastest", "high", "highest"
    auto_shard_model=True,     # Auto-shard parameters across available devices
)
```
**Note:** Replace `repo_id` with the actual path to the saved checkpoint directory or the model repository on the Hugging Face Hub. The returned `model` is loaded with its trained parameters and is ready to use.
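As a quick smoke test, the loaded module can be called directly on token ids. The following is a minimal sketch, not an official EasyDeL recipe: it assumes the repository also ships a compatible Hugging Face tokenizer and that the module follows the usual causal-LM call signature, returning an output object with a `.logits` field.

```python
# Minimal smoke-test sketch (assumptions: a tokenizer is available at repo_id,
# and the module returns an output object with a .logits field).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(repo_id)
inputs = tokenizer("Hello, how are you?", return_tensors="np")

# Forward pass: logits have shape (batch, sequence_length, vocab_size).
outputs = model(input_ids=jnp.asarray(inputs["input_ids"]))
next_token_id = int(jnp.argmax(outputs.logits[0, -1]))
print(tokenizer.decode([next_token_id]))
```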
## Training Configuration Summary

### Model & Hardware
- **Model Name (Run Name):** marin-8b-instruct-dpo
- **Base Model Architecture:** llama
- **Platform:** TPU
- **Number of Devices Used:** 4 (total), 4 (local) (see the check after this list)
- **EasyDeL Version:** v0.1.5
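The device counts above are what JAX reports at runtime; you can confirm that your own environment matches before loading the checkpoint:

```python
import jax

print(jax.device_count())        # total devices across all hosts (4 for this run)
print(jax.local_device_count())  # devices visible to this host (4 for this run)
```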
### Key Training Parameters
- **Learning Rate (Start → End):** 1e-07 → 5e-07
- **Optimizer:** `EasyDeLOptimizers.ADAMW`
- **Scheduler:** `EasyDeLSchedulers.LINEAR` (an optax equivalent is sketched after this list)
- **Warmup Steps:** 0
- **Weight Decay:** 0.01
- **Loss Configuration:** `LossConfig(ignore_index=-100, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS, num_labels=None, problem_type=None, divide_weight_sum=False, shift_tokens=True, break_on_nan=True, reduction=None, num_classification_labels=None, classification_problem_type=None)`
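For readers who want to reproduce the schedule outside of EasyDeL, the sketch below reconstructs it with optax directly. It approximates what `EasyDeLSchedulers.LINEAR` configures rather than showing the library's internal code, and `total_steps` is a placeholder you would derive from your dataset size, epoch count, and batch size.

```python
import optax

total_steps = 10_000  # hypothetical placeholder; derive from your data pipeline

# Linear ramp from the start LR to the end LR over the whole run (0 warmup steps).
schedule = optax.linear_schedule(
    init_value=1e-7,
    end_value=5e-7,
    transition_steps=total_steps,
)

# AdamW with the weight decay listed above.
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)
```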
### Data & Batching
- **Number of Training Epochs:** 4
- **Total Batch Size (per step):** 4
- **Maximum Sequence Length:** 8192
- **Gradient Accumulation Steps:** 1 (the per-step token budget is worked out below)
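Taken together, these settings bound each optimizer step at 4 × 8192 = 32,768 tokens (total batch size × maximum sequence length, with no gradient accumulation).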
### Datatypes & Precision
- **Computation `dtype`:** `jnp.bfloat16`
- **Parameter `param_dtype`:** `jnp.bfloat16`
- **Gradient Checkpointing Method:** `EasyDeLGradientCheckPointers.NOTHING_SAVEABLE`
- **Attention Mechanism Used in Training:** `splash` (can be loaded as `AttentionMechanisms.SPLASH` if using `EasyDeLConfig`; see the sketch after this list)
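To mirror these training-time settings when reloading, both enums can be passed through `config_kwargs`. This is a sketch under the assumption that this EasyDeL version's `EasyDeLBaseConfigDict` accepts a `gradient_checkpointing` key:

```python
import easydel as ed
from jax import numpy as jnp

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        # Assumption: gradient_checkpointing is accepted here in this version.
        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
        attn_mechanism=ed.AttentionMechanisms.SPLASH,
    ),
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
)
```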
### Run Control
- **Max Training Steps:** Not set
- **Max Evaluation Steps:** Not set
- **Training Time Limit:** Not set
## Citation

If you use EasyDeL in your research or work, please cite it:
```bibtex
@misc{ZareChavoshi_2023,
  title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
  url={https://github.com/erfanzar/EasyDeL},
  author={Zare Chavoshi, Erfan},
  year={2023}
}
```
*This document was automatically generated by EasyDeL v0.1.5 during the training run.*