---
tags:
- EasyDeL
- llama
- CausalLM
- splash
- safetensors
- Flax
- JAX
- TPU
---
<p align="center">
<a href="https://github.com/erfanzar/EasyDeL">
<img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80">
</a>
</p>
<p align="center">
<a href="https://github.com/erfanzar/EasyDeL">
<img src="https://img.shields.io/badge/🤗_EasyDeL-v0.1.5-blue.svg" />
</a>
<a href="https://github.com/erfanzar/EasyDeL">
<img src="https://img.shields.io/badge/Model_Arch-llama-green.svg" />
</a>
</p>

# Training Run: marin-8b-instruct-dpo

This document records the configuration and parameters used to train the model `marin-8b-instruct-dpo` with the [EasyDeL](https://github.com/erfanzar/EasyDeL) library.

EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on JAX/Flax on TPU and GPU.

## How to Load This Checkpoint

You can load the checkpoint produced by this training run with EasyDeL as follows:

```python
import easydel as ed
from jax import numpy as jnp, lax

# Model repo ID, or path to the directory where this README.md is located
repo_id = "user/model-id"  # <-- TODO: Update this with the actual save directory or model repo

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        # use_scan_mlp=False,  # Set to True to potentially reduce memory usage
        attn_dtype=jnp.float16,  # Or jnp.bfloat16
        # freq_max_position_embeddings=max_length,  # Set if using RoPE and need truncation
        # mask_max_position_embeddings=max_length,  # Set if max length is defined
        attn_mechanism=ed.AttentionMechanisms.SPLASH,  # Matches the mechanism used by this model
    ),
    dtype=jnp.float16,  # Or jnp.bfloat16 (used in training) - Computation data type
    param_dtype=jnp.float16,  # Or jnp.bfloat16 (used in training) - Parameter data type
    precision=lax.Precision("fastest"),  # Like "default", "fastest", "high", "highest"
    auto_shard_model=True,  # Auto-shard across available devices
)
```
*Note: Replace `repo_id` with the actual path to the saved checkpoint directory or with the model repo ID.*
*The call above returns a single `model` object; the snippet does not produce a separate `params` value.*

## Training Configuration Summary

### Model & Hardware

- **Model Name (Run Name)**: `marin-8b-instruct-dpo`
- **Base Model Architecture**: `llama`
- **Platform**: `TPU`
- **Number of Devices Used**: `4` (total), `4` (local)
- **EasyDeL Version**: `v0.1.5`

### Key Training Parameters

- **Learning Rate (Start → End)**: `1e-07` → `5e-07`
- **Optimizer**: `EasyDeLOptimizers.ADAMW`
- **Scheduler**: `EasyDeLSchedulers.LINEAR`
- **Warmup Steps**: `0`
- **Weight Decay**: `0.01`
- **Loss Configuration** (`LossConfig`):
  - `ignore_index`: `-100`
  - `label_smoothing`: `0.0`
  - `z_loss`: `0.0`
  - `loss_normalizing_factor`: `SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS`
  - `num_labels`: `None`
  - `problem_type`: `None`
  - `divide_weight_sum`: `False`
  - `shift_tokens`: `True`
  - `break_on_nan`: `True`
  - `reduction`: `None`
  - `num_classification_labels`: `None`
  - `classification_problem_type`: `None`

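
The learning-rate entries above describe a linear schedule with no warmup. The sketch below shows what that implies, assuming plain linear interpolation between the two listed values. It is not EasyDeL's own scheduler code, and `TOTAL_STEPS` is a hypothetical number, since the run's total step count is not recorded in this document.

```python
def linear_lr(step: int, total_steps: int, start: float = 1e-7, end: float = 5e-7) -> float:
    """Linearly interpolate from `start` to `end` over `total_steps` optimizer steps."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp the step so the schedule holds its final value once it is finished.
    fraction = min(max(step, 0), total_steps) / total_steps
    return start + (end - start) * fraction


TOTAL_STEPS = 1_000  # hypothetical; the real value depends on the dataset size

for step in (0, 250, 500, 1_000):
    print(f"step {step:>5}: lr = {linear_lr(step, TOTAL_STEPS):.2e}")
```

With the values listed above, the rate rises from `1e-07` at the first step to `5e-07` at the last one. `optax.linear_schedule` computes the same kind of interpolation, although whether EasyDeL builds its `LINEAR` scheduler on it is an assumption.
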
### Data & Batching

- **Number of Training Epochs**: `4`
- **Total Batch Size (per step)**: `4`
- **Maximum Sequence Length**: `8192`
- **Gradient Accumulation Steps**: `1`

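
The batching figures above determine how much data each optimizer update sees. The following is a rough sketch of that arithmetic; it assumes one sequence per batch element and that every epoch visits each example once, with a final partial batch still counted as a step. `num_examples` is a hypothetical input, because the dataset size is not recorded in this document.

```python
import math

TOTAL_BATCH_SIZE = 4    # "Total Batch Size (per step)"
MAX_SEQ_LEN = 8192      # "Maximum Sequence Length"
GRAD_ACCUM_STEPS = 1    # "Gradient Accumulation Steps"
NUM_EPOCHS = 4          # "Number of Training Epochs"

# Sequences that contribute to a single optimizer update.
sequences_per_update = TOTAL_BATCH_SIZE * GRAD_ACCUM_STEPS

# Upper bound on tokens per update; real batches hold fewer if sequences are shorter.
max_tokens_per_update = sequences_per_update * MAX_SEQ_LEN


def optimizer_steps(num_examples: int) -> int:
    """Estimate optimizer steps for the whole run from a (hypothetical) dataset size."""
    batches_per_epoch = math.ceil(num_examples / TOTAL_BATCH_SIZE)
    updates_per_epoch = math.ceil(batches_per_epoch / GRAD_ACCUM_STEPS)
    return updates_per_epoch * NUM_EPOCHS


print("sequences per update:", sequences_per_update)
print("max tokens per update:", max_tokens_per_update)
print("steps for 10,000 examples:", optimizer_steps(10_000))
```

Because no maximum step count is set for this run, an estimate like this, rather than a configured limit, is what bounds the length of training.
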
### Datatypes & Precision

- **Computation `dtype`**: `jax.numpy.bfloat16`
- **Parameter `param_dtype`**: `jax.numpy.bfloat16`
- **Gradient Checkpointing Method**: `EasyDeLGradientCheckPointers.NOTHING_SAVEABLE`
- **Attention Mechanism Used in Training**: `splash` (pass `ed.AttentionMechanisms.SPLASH` as `attn_mechanism` when loading)

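
Training used `bfloat16` for both computation and parameters, while the loading snippet shows `float16` with `bfloat16` as the alternative. The two 16-bit formats trade range for precision differently, and the sketch below derives that from their bit layouts (5 exponent and 10 mantissa bits for `float16`, 8 and 7 for `bfloat16`). It uses plain Python arithmetic rather than JAX, and the helper name `float_format_summary` is made up for illustration.

```python
def float_format_summary(exponent_bits: int, mantissa_bits: int) -> dict:
    """Largest finite value and machine epsilon of an IEEE-style binary float format."""
    bias = 2 ** (exponent_bits - 1) - 1
    # Largest finite value: all mantissa bits set, at the largest normal exponent.
    max_finite = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    # Machine epsilon: gap between 1.0 and the next representable value.
    epsilon = 2.0 ** -mantissa_bits
    return {"max": max_finite, "eps": epsilon}


FORMATS = {
    "float16": (5, 10),   # (exponent bits, mantissa bits)
    "bfloat16": (8, 7),
}

for name, (exp_bits, man_bits) in FORMATS.items():
    info = float_format_summary(exp_bits, man_bits)
    print(f"{name:>8}: max = {info['max']:.4e}, eps = {info['eps']:.4e}")
```

`bfloat16` covers roughly the same range as `float32` at lower precision, whereas `float16` is more precise but overflows far sooner. Loading in the dtype used during training avoids an extra conversion of the weights.
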
### Run Control

- **Max Training Steps**: `Not Set`
- **Max Evaluation Steps**: `Not Set`
- **Training Time Limit**: `Not Set`

## Citation

If you use EasyDeL in your research or work, please cite it:

```bibtex
@misc{Zare_Chavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}
```

---
*This document was automatically generated by EasyDeL v0.1.5 during the training run.*