Fine-tuning Mistral on Your Dataset

Community Article Published July 22, 2024

This tutorial will walk you through the process of fine-tuning the Mistral-7B-Instruct model on your own dataset using the Hugging Face Transformers, TRL, and PEFT libraries.

Step 0: Install required libraries

!pip install -q transformers datasets accelerate evaluate trl bitsandbytes peft

Step 1: Load and format your dataset

We'll load the dataset and define a function that formats each example into a single prompt string. The column names used here ('instruction' and 'output') are placeholders; adapt them to whatever your dataset actually uses:

from datasets import load_dataset

def format_prompts(examples):
    """
    Format each batch of examples into prompt strings.
    Must return a dictionary with a 'text' key containing the formatted prompts.
    The 'instruction' and 'output' column names are placeholders; swap them
    for your dataset's real column names.
    """
    texts = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        # Mistral-Instruct uses the [INST] ... [/INST] chat format.
        # No leading <s> here: the tokenizer adds the BOS token itself.
        texts.append(f"[INST] {instruction} [/INST] {output}</s>")
    return {"text": texts}

dataset = load_dataset("your_dataset_name", split="train")
dataset = dataset.map(format_prompts, batched=True)

print(dataset["text"][2])  # check that the fields were formatted correctly

Step 2: Set up the model and tokenizer

Next, we'll load the pre-trained Mistral-7B-Instruct model and tokenizer, set up 4-bit quantization, and enable gradient checkpointing. Quantizing the 7B weights to 4-bit NF4 shrinks them to roughly 4 GB, which is what makes fine-tuning on a single GPU feasible.

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
import torch

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer has no pad token by default; training needs one for padding

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

Step 3: Set up PEFT (Parameter-Efficient Fine-Tuning)

We'll fine-tune the model efficiently with LoRA via the PEFT library. This involves setting up a LoraConfig and wrapping the model with get_peft_model.

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
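
Before training, it's worth checking how small the trainable fraction actually is. PEFT-wrapped models expose print_trainable_parameters for exactly this:

model.print_trainable_parameters()  # prints trainable vs. total parameter counts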

Step 4: Set up the training arguments

We'll define the training arguments for the fine-tuning process.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="your_model_name",
    num_train_epochs=4,                # adjust this depending on your dataset size
    per_device_train_batch_size=16,    # lower this if you hit out-of-memory errors
    learning_rate=1e-5,
    optim="sgd",
)

Replace "your_model_name" with the desired name for your fine-tuned model.

Step 5: Initialize the trainer and fine-tune the model

Now, we'll initialize the SFTTrainer from the trl library and train the model.

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field='text',
    max_seq_length=1024,
)

trainer.train()

Note that in more recent trl releases, dataset_text_field and max_seq_length have moved out of the SFTTrainer call and into an SFTConfig object; if you hit an unexpected-keyword error here, check your trl version.
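
Before merging, you may want to save the raw LoRA adapter on its own, so you can re-attach or share it later without shipping the full model. For a PEFT-wrapped model, save_model writes just the adapter weights (the path below is only a suggestion):

trainer.save_model("your_model_name-adapter")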

Step 6: Merge the adapter and model back together

When fine-tuning is finished, you can merge the LoRA adapter back into the base model.

adapter_model = trainer.model
merged_model = adapter_model.merge_and_unload()

trained_tokenizer = trainer.tokenizer
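
As a quick sanity check before pushing anywhere, generate a completion from the merged model (the prompt here is just an example):

prompt = "[INST] Hello, who are you? [/INST]"
inputs = trained_tokenizer(prompt, return_tensors="pt").to(merged_model.device)
outputs = merged_model.generate(**inputs, max_new_tokens=64)
print(trained_tokenizer.decode(outputs[0], skip_special_tokens=True))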

Step 7: Push the fine-tuned model to the Hugging Face Hub

After all of that, you can optionally push the fine-tuned model to the Hugging Face Hub for easier sharing and deployment.

repo_id = "your_repo_name"

merged_model.push_to_hub(repo_id)
trained_tokenizer.push_to_hub(repo_id)
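
Once it's up, anyone (including future you) can load it straight from the Hub. A minimal sketch, assuming the repo is public or you're logged in; note that the pushed weights are still 4-bit, so this needs a GPU with bitsandbytes installed (Step 8 fixes that):

from transformers import pipeline

pipe = pipeline("text-generation", model="your_repo_name")
print(pipe("[INST] Hello, who are you? [/INST]", max_new_tokens=64)[0]["generated_text"])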

Step 8: The cursed child

If you are feeling extra spicy, you can dequantize the model back to full precision in a new script, so that people can load it without bitsandbytes.

!pip install accelerate bitsandbytes peft transformers # make sure to install dependencies again

import torch
from transformers import AutoModelForCausalLM

model_id = "your_repo_name"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# Strip the quantization metadata from the config so the re-uploaded
# checkpoint no longer advertises itself as 4-bit
config = model.config
del config.quantization_config
del config._pre_quantization_dtype
model.config = config

# Convert the bitsandbytes 4-bit weights back into regular fp16 tensors
model.dequantize()

model.push_to_hub(model_id) # the tokenizer will stay the same

Be warned that Mistral is a very big model, and dequantizing and re-uploading it will take quite a bit of compute, RAM, and bandwidth.