🦙⚗️ Using Llama3 and distilabel to build fine-tuning datasets
In this post, I explain how you can build LLM fine-tuning datasets using distilabel and Hugging Face Inference Endpoints.
What's the goal?
At Argilla, we have released a series of impactful open datasets for aligning LLMs. Unfortunately, all of them used closed models (mostly GPT-4) for the AI Feedback (AIF), or LLM-as-Judge, step. This step uses an LLM to judge the quality of several responses to the same prompt so they can be used for preference tuning. We used closed models because the AI feedback step requires a powerful, highly capable model to approximate human preferences. The end goal is a dataset that can be used to improve OSS models with alignment methods like DPO, ORPO, or KTO. Now that Llama3 is closing the performance gap, we're a step closer to our vision: fully open data generation pipelines!
Ingredients
To build a high-quality preference dataset from scratch, we need:
- A dataset with prompts: I use DIBT/10k_prompts_ranked. I love this dataset because it contains high-quality prompts, curated by 314 amazing DIBT community members! Check out the Argilla Space if you want to look at the data yourself. Personally, I find that spending a few minutes looking at the data is the most impactful way to learn how AI models and methods work!
- One or several models to generate responses to the prompts: I use Llama3 models (the 8B and 70B instruct versions). Running these models can be costly, and deploying them requires certain skills. For small experiments and prototypes, you can use Inference for Pros.
- A model to judge the quality of generated responses: As mentioned above, this is one of the first examples in the wild of using Llama3-70B-Instruct for this step. It certainly won't be the last!
- Code to perform and orchestrate the data generation pipeline: You can develop your own code to define the data preparation, configuration, prompts, inference code, etc., or you can use our shiny new distilabel 1.0, which greatly simplifies this process and comes with everything you need to build complex data synthesis and AIF pipelines!
- Human feedback: I use Argilla for this. To me, this is the key step and the one that makes distilabel stand out: you can make your dataset available to human experts through a nice, transparent UI. AI-generated datasets come with a lot of limitations (all sorts of biases, overly confident ratings, limited reasoning capabilities, and so on). If you want to build a high-quality dataset, I highly encourage you to spend at least a few hours verifying the generated data. Even if you have limited resources and want to generate a fully synthetic dataset, you'll always find ways to improve the data generation pipeline (see our work on Notus, for example). For more critical use cases, this step means you can make the AI-generated dataset available to your pool of experts before spending any money fine-tuning a model on data of unknown quality.
Recipe
Now let's see how to create a distilabel pipeline that takes our prompts dataset and builds a preference dataset end to end.
The pipeline looks like this:
load_dataset \
> [generate_with_llama3_70B, generate_with_llama3_8B] \
> combine_columns \
> ultrafeedback \
> [keep_columns, push_to_argilla]
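The diagram is a directed acyclic graph (DAG): each step can start once all of its inputs are ready. As a plain-Python sketch (it uses the standard library's graphlib, not distilabel's own scheduler), you can recover which steps are free to run side by side:

```python
from graphlib import TopologicalSorter

# Each key lists the steps it depends on, i.e. its predecessors in the diagram above.
dag = {
    "generate_with_llama3_70B": {"load_dataset"},
    "generate_with_llama3_8B": {"load_dataset"},
    "combine_columns": {"generate_with_llama3_70B", "generate_with_llama3_8B"},
    "ultrafeedback": {"combine_columns"},
    "keep_columns": {"ultrafeedback"},
    "push_to_argilla": {"ultrafeedback"},
}

sorter = TopologicalSorter(dag)
sorter.prepare()
stages = []
while sorter.is_active():
    ready = sorted(sorter.get_ready())  # steps whose inputs are all available now
    stages.append(ready)
    sorter.done(*ready)

for i, stage in enumerate(stages):
    print(i, stage)
```

The five stages follow the diagram: the two generators share a stage, and so do keep_columns and push_to_argilla.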
If you want to understand how distilabel works before reading the next sections, check out this blogpost.
For experimentation, I'm using Inference for Pros; for larger datasets, you can deploy an Inference Endpoint and use the same InferenceEndpointsLLM class, replacing model_id with endpoint_name and endpoint_namespace.
Load dataset
This step loads the source data. distilabel provides a convenient way to read datasets from the Hub. As data pipelines are complex and potentially resource-consuming, we recommend starting with a very small sample to make sure everything works before launching the full generation job. I use a small sample of high-quality prompts, loaded with the LoadDataFromDicts step:
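from datasets import load_dataset
from distilabel.steps import LoadDataFromDicts

# get great prompts: average rating >= 4, annotated by at least 2 contributors
dataset = load_dataset(
    "DIBT/10k_prompts_ranked",
    split="train"
).filter(
    lambda r: r['avg_rating']>=4 and r['num_responses']>=2
)
dataset = dataset.to_list()

# as in the full code below, steps are created inside a `with Pipeline(...) as pipeline:` block
load_dataset = LoadDataFromDicts(
    name="load_dataset",
    data=dataset[0:500], # during development I used [0:1] to iterate quickly
    output_mappings={"prompt": "instruction"}
)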
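To make the filter and the output_mappings rename concrete, here is a plain-Python sketch of what they do to each row. The rows are hypothetical and only carry the three columns the code above touches; the real dataset has other columns too.

```python
# Hypothetical rows shaped like DIBT/10k_prompts_ranked (only the fields used above).
rows = [
    {"prompt": "Explain DPO in two sentences.", "avg_rating": 4.5, "num_responses": 3},
    {"prompt": "hi", "avg_rating": 1.0, "num_responses": 2},
    {"prompt": "Write a haiku about llamas.", "avg_rating": 4.0, "num_responses": 1},
]


def keep(row):
    # Same predicate as the .filter() call: average rating >= 4 from at least 2 raters.
    return row["avg_rating"] >= 4 and row["num_responses"] >= 2


def apply_output_mappings(row, mappings):
    # Rename columns the way output_mappings={"prompt": "instruction"} does.
    return {mappings.get(key, key): value for key, value in row.items()}


selected = [apply_output_mappings(r, {"prompt": "instruction"}) for r in rows if keep(r)]
print(selected)
```

The third row has a high average but only one rater, so it is dropped along with the low-rated one.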
Generate responses
These steps take the prompts from our load_dataset step and generate responses. I set up two steps that will run in parallel:
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration

generate_with_llama3_70B = TextGeneration(
    name="generate_with_llama3_70B",
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    ),
)
generate_with_llama3_8B = TextGeneration(
    name="generate_with_llama3_8B",
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
    ),
)
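Both steps read the same rows from load_dataset, and each one adds a generation and a model_name column. The sketch below mimics that fan-out with a thread pool and hypothetical fake_llm stand-ins instead of real model calls; distilabel manages the concurrency between steps itself, so this only illustrates the shape of the data.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_llm(model_name):
    """Hypothetical stand-in for a TextGeneration step; a real step would call the LLM."""
    def generate(instruction):
        return {"generation": f"[{model_name}] answer to: {instruction}", "model_name": model_name}
    return generate


branches = [
    fake_llm("meta-llama/Meta-Llama-3-70B-Instruct"),
    fake_llm("meta-llama/Meta-Llama-3-8B-Instruct"),
]


def run_branch(generate, batch):
    # Keep the incoming columns (e.g. "instruction") and add generation + model_name.
    return [{**row, **generate(row["instruction"])} for row in batch]


def fan_out(batch):
    # Every branch receives the same batch and runs independently of the others.
    with ThreadPoolExecutor(max_workers=len(branches)) as pool:
        futures = [pool.submit(run_branch, generate, batch) for generate in branches]
        return [future.result() for future in futures]


outputs = fan_out([{"instruction": "What is KTO?"}])
for branch_rows in outputs:
    print(branch_rows[0]["model_name"], "->", branch_rows[0]["generation"])
```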
Combine columns
The UltraFeedback LLM-as-Judge step requires a list of responses in a column called generations (you can change this using the input_mappings parameter). Each of the previous parallel steps outputs a generation (the response) and a model_name (which model generated it). The "combine columns" step prepares the input for UltraFeedback by merging these into lists:
from distilabel.steps import CombineColumns

combine_columns = CombineColumns(
    name="combine_columns",
    columns=["generation", "model_name"],
    output_columns=["generations", "generation_models"],
)
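Conceptually, this step zips the row-aligned outputs of the parallel branches and turns each listed column into a list. A minimal plain-Python sketch under that assumption (combine_columns here is a hypothetical helper, not the distilabel implementation):

```python
def combine_columns(branch_outputs, columns, output_columns):
    """Merge row-aligned outputs from parallel branches into list-valued columns."""
    combined = []
    for rows in zip(*branch_outputs):  # the i-th row from every branch
        merged = dict(rows[0])  # keep shared columns such as "instruction"
        for col, out_col in zip(columns, output_columns):
            merged.pop(col, None)  # the single-value column is replaced by its list form
            merged[out_col] = [row[col] for row in rows]
        combined.append(merged)
    return combined


branch_70b = [{"instruction": "Q1", "generation": "A1 (70B)", "model_name": "Meta-Llama-3-70B-Instruct"}]
branch_8b = [{"instruction": "Q1", "generation": "A1 (8B)", "model_name": "Meta-Llama-3-8B-Instruct"}]

rows = combine_columns(
    [branch_70b, branch_8b],
    columns=["generation", "model_name"],
    output_columns=["generations", "generation_models"],
)
print(rows[0])
```

The i-th entries of generations and generation_models describe the same response, which is what lets each rating be traced back to the model that produced the response.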
UltraFeedback
This step takes the prompt and the two responses under generations and rates their quality from 1 to 5 using Llama3-70B-Instruct. distilabel gives you the easiest-to-use implementation of UltraFeedback, a ground-breaking work from OpenBMB. To make it even easier, we have developed a novel prompt to judge the overall quality of responses, taking into account the original UltraFeedback dimensions (helpfulness, honesty, instruction-following, truthfulness).
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import UltraFeedback

ultrafeedback = UltraFeedback(
    name="ultrafeedback",
    aspect='overall-rating',
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    ),
    #llm=OpenAILLM(model="gpt-4"), # it used to be this! Now you can 🦙 -> 🥳
)
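The judge sees the instruction plus every response in a single prompt and answers with a rating and a rationale per response, which end up in the ratings and rationales columns. The sketch below shows that LLM-as-judge pattern in plain Python; the prompt template, the "Rating: N. Rationale: ..." reply format, and both helper functions are hypothetical and are not distilabel's actual UltraFeedback prompt or parser.

```python
import re


def build_judge_prompt(instruction, generations):
    # Hypothetical template, NOT distilabel's UltraFeedback prompt.
    texts = "\n".join(f"<text {i + 1}> {g}" for i, g in enumerate(generations))
    return (
        "Rate each text from 1 to 5 for overall quality (helpfulness, honesty, "
        "instruction-following, truthfulness).\n"
        f"Instruction: {instruction}\n{texts}\n"
        "Answer with one line per text: 'Rating: <1-5>. Rationale: <why>'"
    )


# One match per line of the (hypothetical) reply format.
LINE = re.compile(r"Rating:\s*([1-5])\.\s*Rationale:\s*(.+)")


def parse_judge_reply(reply, n):
    ratings, rationales = [], []
    for match in LINE.finditer(reply):
        ratings.append(int(match.group(1)))
        rationales.append(match.group(2).strip())
    if len(ratings) != n:  # malformed reply: keep the row but mark every rating as missing
        return [None] * n, [None] * n
    return ratings, rationales


reply = "Rating: 2. Rationale: wrong label.\nRating: 5. Rationale: correct label."
ratings, rationales = parse_judge_reply(reply, n=2)
print(ratings, rationales)
```

Judging both responses in one call keeps the comparison relative, and treating a malformed reply as missing ratings avoids silently misaligning ratings with generations.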
Keep columns
This step takes the outputs from previous steps and keeps the columns we want:
from distilabel.steps import KeepColumns

keep_columns = KeepColumns(
    name="keep_columns",
    columns=[
        "instruction",
        "generations",
        "generation_models",
        "ratings",
        "rationales",
    ],
)
Push to Argilla for human feedback
This step automatically creates a fully-fledged UI for human feedback with the generated dataset!
Let's look at one example in the Argilla UI.
This example shows that our AI judge (Llama3-70B) rates the response from the 8B model higher than the one from the 70B model.
Instruction
With no explanation, label A→B with either "downplays", "exaggerates" or "same". A: Body clock function can break down when light and temperature levels throughout the day are out of sync, finds new UCL research in fruit flies. B: The behavioral disruption depends heavily on the phase relationship between light and temperature signals.
Response from Llama3-70B (rating 2)
A→B: downplays
Response from Llama3-8B (rating 5)
same
It also provides the following rationales:
The label "downplays" is incorrect, as the second statement (B) is a more specific and detailed explanation of the first statement (A), rather than downplaying it. The model's output does not align with the instruction and lacks correctness.
The label "same" is accurate, as the second statement (B) is a more detailed explanation of the first statement (A), without downplaying or exaggerating the information. The model's output is correct, confident, and aligned with the instruction.
I really like this example because:
- It shows how nuanced and difficult some questions and rating responses can be, even for humans. To me, "same" could be more accurate, but one could argue there's a slight degree of downplaying too (I'd love to hear your thoughts on this!).
- Based on the above, it's a perfect example of where human experts can help improve the quality of training data!
- It highlights that responses from larger models are not always the best. I've been saying this for a while. We even showed the impact of not relying on this assumption when building preference datasets in our reranked version of orca-pairs (used by 142 models already!).
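The ratings exist so the dataset can be used for preference tuning. As a minimal sketch (not part of the original pipeline), the hypothetical to_preference_pair helper below turns the example above into the prompt/chosen/rejected layout commonly used for DPO, assuming generation_models stores the model ids. It picks the 8B response as chosen, which is exactly the kind of judgment a human reviewer may want to double-check.

```python
# The Argilla example above, as a row with (a subset of) the columns kept by keep_columns.
row = {
    "instruction": (
        'With no explanation, label A→B with either "downplays", "exaggerates" or "same". '
        "A: Body clock function can break down when light and temperature levels throughout "
        "the day are out of sync, finds new UCL research in fruit flies. "
        "B: The behavioral disruption depends heavily on the phase relationship between "
        "light and temperature signals."
    ),
    "generations": ["A→B: downplays", "same"],
    "generation_models": [
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
    ],
    "ratings": [2, 5],
}


def to_preference_pair(row):
    """Pick the highest- and lowest-rated responses as chosen/rejected; None if no signal."""
    scored = [
        (rating, generation, model)
        for rating, generation, model in zip(row["ratings"], row["generations"], row["generation_models"])
        if rating is not None  # skip responses the judge failed to rate
    ]
    if len(scored) < 2:
        return None
    best = max(scored, key=lambda t: t[0])
    worst = min(scored, key=lambda t: t[0])
    if best[0] == worst[0]:
        return None  # tied ratings carry no pairwise preference
    return {
        "prompt": row["instruction"],
        "chosen": best[1],
        "rejected": worst[1],
        "chosen_model": best[2],
        "rejected_model": worst[2],
    }


pair = to_preference_pair(row)
print(pair["chosen_model"], "beats", pair["rejected_model"])
```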
Finally, I've made this dataset available for the open community. Sign in and explore it yourself: it's the best way to learn how preference datasets and AI feedback work!
Result
The full pipeline with 500 examples takes less than 30 min to run and costs $0. Check out the last section to see the full code.
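A rough back-of-the-envelope check (my arithmetic on the numbers above, not a measured figure): each prompt triggers two generation calls plus one UltraFeedback call that judges both responses together, so the run makes

```latex
N_{\text{calls}} = 500 \times (2 + 1) = 1500,
\qquad
\frac{1500\ \text{calls}}{30 \times 60\ \text{s}} = \frac{1500}{1800} \approx 0.83\ \text{calls/s}
```

Since the run finishes in under 30 minutes, the three LLM steps together averaged more than about 0.83 requests per second.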
An important feature of distilabel is that pipelines are fully reproducible and can be shared via the Hub. I've made this pipeline and the dataset available for the community: https://huggingface.co/datasets/dvilasuero/distillama3-prompts10k. This means you can run it yourself like this:
distilabel pipeline run --config "https://huggingface.co/datasets/dvilasuero/distillama3-prompts10k/raw/main/pipeline.yaml"
Next steps
This post explains the basics and shows an end-to-end pipeline. In future posts, I'll compare the results of this pipeline with a GPT-4-turbo Judge to understand how far we are from replacing closed models for AIF datasets.
If you made it to this point: thanks so much for reading, and please share this post with your network if you find it interesting!
Full pipeline code
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineColumns, LoadDataFromDicts, KeepColumns, PreferenceToArgilla
from distilabel.steps.tasks import TextGeneration, UltraFeedback
from datasets import load_dataset
dataset = load_dataset("DIBT/10k_prompts_ranked", split="train").filter(lambda r: r['avg_rating']>=4 and r['num_responses']>=2)
dataset = dataset.to_list()
with Pipeline(
    name="prefs-with-llama-3",
    description="Pipeline for building preference datasets using Llama 3",
) as pipeline:
    load_dataset = LoadDataFromDicts(
        name="load_dataset",
        data=dataset[0:100],  # the walkthrough above used dataset[0:500]
        output_mappings={"prompt": "instruction"},
    )

    generate_with_llama3_70B = TextGeneration(
        name="generate_with_llama3_70B",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        ),
    )

    generate_with_llama3_8B = TextGeneration(
        name="generate_with_llama3_8B",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-8B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
        ),
    )

    combine_columns = CombineColumns(
        name="combine_columns",
        columns=["generation", "model_name"],
        output_columns=["generations", "generation_models"],
    )

    ultrafeedback = UltraFeedback(
        name="ultrafeedback",
        aspect='overall-rating',
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        ),
        #llm=OpenAILLM(model="gpt-4"),
    )

    keep_columns = KeepColumns(
        name="keep_columns",
        columns=[
            "instruction",
            "generations",
            "generation_models",
            "ratings",
            "rationales",
        ],
    )

    # Push the generated dataset to Argilla
    # You need to `pip install argilla`
    # and have an instance running: https://docs.argilla.io/en/latest/getting_started/quickstart_installation.html
    push_to_argilla = PreferenceToArgilla(
        name="push_to_argilla",
        api_url="https://<argilla url>",
        api_key="<super secret api key>",
        dataset_name="ultrallama3",
        dataset_workspace="admin",
        num_generations=2,
    )

    # Wire the DAG: load -> two parallel generators -> combine -> judge -> keep / Argilla
    load_dataset.connect(generate_with_llama3_70B)
    load_dataset.connect(generate_with_llama3_8B)
    generate_with_llama3_70B.connect(combine_columns)
    generate_with_llama3_8B.connect(combine_columns)
    combine_columns.connect(ultrafeedback)
    ultrafeedback.connect(keep_columns)
    ultrafeedback.connect(push_to_argilla)

if __name__ == "__main__":
    # Runtime parameters are keyed by step name; the stop sequences are Llama 3's
    # end-of-turn and end-of-text tokens.
    distiset = pipeline.run(
        parameters={
            "generate_with_llama3_70B": {
                "llm": {
                    "generation_kwargs": {"max_new_tokens": 1024, "temperature": 0.7, "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]}
                }
            },
            "generate_with_llama3_8B": {
                "llm": {
                    "generation_kwargs": {"max_new_tokens": 1024, "temperature": 0.7, "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]}
                }
            },
            "ultrafeedback": {
                "llm": {
                    "generation_kwargs": {"max_new_tokens": 1024, "temperature": 0.1, "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]}
                }
            },
        }
    )
    # Requires being logged in to the Hugging Face Hub with write access
    distiset.push_to_hub("dvilasuero/distillama3-prompts10k")