---
tags:
- EasyDeL
- llama
- CausalLM
- splash
- safetensors
- Flax
- JAX
- TPU
---
<p align="center">
<a href="https://github.com/erfanzar/EasyDeL">
<img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80">
</a>
</p>
<p align="center">
<a href="https://github.com/erfanzar/EasyDeL">
<img src="https://img.shields.io/badge/🤗_EasyDeL-v0.1.5-blue.svg" />
</a>
<a href="https://github.com/erfanzar/EasyDeL">
<img src="https://img.shields.io/badge/Model_Arch-llama-green.svg" />
</a>
</p>
# Training Run: marin-8b-instruct-dpo
This document outlines the configuration and parameters used for training the model `marin-8b-instruct-dpo` using the [EasyDeL](https://github.com/erfanzar/EasyDeL) library.
EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on JAX/Flax for TPU/GPU environments.
## How to Load This Checkpoint
You can load the checkpoint generated from this training run using EasyDeL as follows:
```python
import easydel as ed
from jax import numpy as jnp, lax

# Path to the directory where this README.md is located
repo_id = "user/model-id"  # <-- TODO: Update this path with the actual save directory or model repo

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        # use_scan_mlp=False,  # Set to True to potentially reduce memory usage
        attn_dtype=jnp.float16,  # Or jnp.bfloat16
        # freq_max_position_embeddings=max_length,  # Set if using RoPE and need truncation
        # mask_max_position_embeddings=max_length,  # Set if max length is defined
        attn_mechanism=ed.AttentionMechanisms.SPLASH,  # Matches the mechanism used by this model
    ),
    dtype=jnp.float16,  # Or jnp.bfloat16 (the dtype used in training) - Computation data type
    param_dtype=jnp.float16,  # Or jnp.bfloat16 (the dtype used in training) - Parameter data type
    precision=lax.Precision("fastest"),  # Like "default", "fastest", "high", "highest"
    auto_shard_model=True,  # Auto-shard across available devices
)
```
*Note: Replace `repo_id` with the actual path to the saved checkpoint directory or with the model repository ID.*
*In the snippet above, `from_pretrained` returns a single `model` object; no separate `params` value is unpacked.*
## Training Configuration Summary
### Model & Hardware
- **Model Name (Run Name)**: `marin-8b-instruct-dpo`
- **Base Model Architecture**: `llama`
- **Platform**: `TPU`
- **Number of Devices Used**: `4` (total), `4` (local)
- **EasyDeL Version**: `v0.1.5`
### Key Training Parameters
- **Learning Rate (Start → End)**: `1e-05` → `5e-07`
- **Optimizer**: `EasyDeLOptimizers.ADAMW`
- **Scheduler**: `EasyDeLSchedulers.COSINE`
- **Warmup Steps**: `0`
- **Weight Decay**: `0.01`
- **Loss Configuration** (`LossConfig`):
  - `ignore_index`: `-100`
  - `label_smoothing`: `0.0`
  - `z_loss`: `0.0`
  - `loss_normalizing_factor`: `SpecialLossNormalizingFactor.NO_WEIGHT_NUM_REAL_TARGET_TOKENS`
  - `num_labels`: `None`
  - `problem_type`: `None`
  - `divide_weight_sum`: `False`
  - `shift_tokens`: `True`
  - `break_on_nan`: `True`
  - `reduction`: `None`
  - `num_classification_labels`: `None`
  - `classification_problem_type`: `None`
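To give a feel for the learning-rate values listed above, here is a minimal sketch of a standard cosine decay from `1e-05` to `5e-07` with no warmup. It assumes the textbook cosine formula; the exact schedule EasyDeL builds for `EasyDeLSchedulers.COSINE` is not confirmed here, and `total_steps` is a hypothetical value because the run did not set a maximum step count.

```python
import math

LR_START = 1e-5  # start learning rate from the run configuration
LR_END = 5e-7    # end learning rate from the run configuration


def cosine_lr(step: int, total_steps: int) -> float:
    """Textbook cosine decay from LR_START to LR_END (assumed formula)."""
    # Clamp so that steps past the end hold the final learning rate.
    progress = min(max(step, 0), total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return LR_END + (LR_START - LR_END) * cosine


# Hypothetical step count, chosen only for illustration.
total_steps = 1000
for step in (0, total_steps // 2, total_steps):
    print(step, cosine_lr(step, total_steps))
```

Under this assumed formula the rate starts at `1e-05`, passes the midpoint of the two values halfway through, and ends at `5e-07`.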
### Data & Batching
- **Number of Training Epochs**: `1`
- **Total Batch Size (per step)**: `4`
- **Maximum Sequence Length**: `8192`
- **Gradient Accumulation Steps**: `1`
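The batching values above imply a simple upper bound on how much data each optimizer update sees. The sketch below only multiplies the listed numbers; the variable names are illustrative, and how the length limit applies to the chosen and rejected completions of a DPO pair is not stated in this document.

```python
# Values copied from the configuration listed above.
total_batch_size = 4
max_sequence_length = 8192
gradient_accumulation_steps = 1

# Sequences contributing to each optimizer update.
sequences_per_update = total_batch_size * gradient_accumulation_steps

# Upper bound on token positions per update; shorter or padded
# sequences contain fewer real tokens than this.
max_tokens_per_update = sequences_per_update * max_sequence_length

print(sequences_per_update)   # 4
print(max_tokens_per_update)  # 32768
```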
### Datatypes & Precision
- **Computation `dtype`**: `<class 'jax.numpy.bfloat16'>`
- **Parameter `param_dtype`**: `<class 'jax.numpy.bfloat16'>`
- **Gradient Checkpointing Method**: `EasyDeLGradientCheckPointers.NOTHING_SAVEABLE`
- **Attention Mechanism Used in Training**: `splash` (can be loaded as `AttentionMechanisms.SPLASH` if using `EasyDeLConfig`)
### Run Control
- **Max Training Steps**: `Not Set`
- **Max Evaluation Steps**: `Not Set`
- **Training Time Limit**: `Not Set`
## Citation
If you use EasyDeL in your research or work, please cite it:
```bibtex
@misc{ZareChavoshi_2023,
title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
url={https://github.com/erfanzar/EasyDeL},
author={Zare Chavoshi, Erfan},
year={2023}
}
```
---
*This document was automatically generated by EasyDeL v0.1.5 during the training run.*