import json
from typing import TYPE_CHECKING, Literal, Union

from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration
from typing_extensions import override

if TYPE_CHECKING:
    from distilabel.steps.typing import StepColumns, StepOutput
CHOSEN_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of poor quality. Your task is to regenerate one of high quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
High quality response:
""".rstrip()

CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a high quality response when other assistants have created a poor quality one."
REJECT_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of high quality. Your task is to regenerate one of poor quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
Poor quality response:
""".rstrip()

REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when other assistants have created a high quality one."
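# A minimal sketch of what TextGeneration will render from CHOSEN_TEMPLATE
# (illustrative conversation; distilabel applies the Jinja2 template
# internally, so this is only for eyeballing the prompt):
#
#   You are provided with a conversation between a human and an AI assistant.
#   The final message is of poor quality. Your task is to regenerate one of high quality.
#   user: What is the capital of France?
#   assistant: maybe lyon?
#   High quality response: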
class FilterConversationRatings(Step):
"""Filters conversations based on the rating of the last message."""
target_column: Union[Literal["chosen"], Literal["rejected"]]
batch_size: int = 5
    @override
    def process(self, dataset: StepInput) -> "StepOutput":
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]
        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            for row in batch:
                _conversation = row["conversation"]
                conversation = None
                # Truncate at the first message whose rating matches the
                # target (1 = chosen, -1 = rejected); unrated messages are
                # skipped.
                for idx, message in enumerate(_conversation, 1):
                    if not isinstance(message["rating"], int):
                        continue
                    if message["rating"] == target_rating:
                        conversation = _conversation[:idx]
                        break
                if conversation:
                    filtered_batch.append({"conversation": conversation})
            yield filtered_batch
@property
def outputs(self) -> "StepColumns":
return ["conversation"]
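# Expected input row shape (an assumption based on how ratings are read in the
# filter above): each message carries an integer "rating" alongside its role
# and content, e.g.
#   {"conversation": [
#       {"role": "user", "content": "...", "rating": None},
#       {"role": "assistant", "content": "...", "rating": -1},
#   ]}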
class AppendToConversationStep(Step):
"""Appends a generated message to a conversation."""
@property
def inputs(self) -> "StepColumns":
return ["generation", "conversation"]
@property
def outputs(self) -> "StepColumns":
return ["generated_conversation", "conversation"]
    def process(self, inputs: StepInput) -> "StepOutput":
        outputs = []
        for row in inputs:
            # Skip rows with a failed generation or an empty conversation so
            # they never reach the downstream column mappings.
            if not row["generation"] or not row["conversation"]:
                continue
            # Replace the final assistant turn with the newly generated one.
            row["generated_conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in row["conversation"][:-1]
            ] + [{"role": "assistant", "content": row["generation"]}]
            row["conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in row["conversation"]
            ]
            outputs.append(row)
        yield outputs
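# Sketch of the transformation performed above (field names as produced by
# FilterConversationRatings plus TextGeneration's "generation" output):
#   {"conversation": [user_msg, poor_reply], "generation": "better reply"}
# becomes
#   {"conversation":           [user_msg, poor_reply],
#    "generated_conversation": [user_msg, {"role": "assistant", "content": "better reply"}]}
# i.e. the final assistant turn is swapped for the newly generated one.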
with Pipeline(
name="conversation_rejection",
description="Generate a chosen response to a rejected conversation.",
) as rejection_pipeline:
rejected_dataset = FilterConversationRatings(target_column="rejected")
chosen_text_gen = TextGeneration(
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
system_prompt=CHOSEN_SYSTEM_PROMPT,
template=CHOSEN_TEMPLATE,
columns=["conversation"],
)
append_chosen = AppendToConversationStep(
output_mappings={
"generated_conversation": "chosen",
"conversation": "rejected",
},
)
keep_columns = KeepColumns(
columns=["chosen", "rejected"],
)
rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns
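# The DAG above reads left to right: keep only conversations ending in a
# rejected (-1) message, ask the LLM for a better final turn, splice that turn
# in as "chosen" (the original conversation becomes "rejected"), then keep
# just the two preference columns. The pipeline below mirrors this in reverse.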
with Pipeline(
name="conversation_chosen",
description="Generate a rejected response to a chosen conversation.",
) as chosen_pipeline:
chosen_dataset = FilterConversationRatings(target_column="chosen")
rejected_text_gen = TextGeneration(
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
system_prompt=REJECT_SYSTEM_PROMPT,
template=REJECT_TEMPLATE,
columns=["conversation"],
)
append_rejected = AppendToConversationStep(
output_mappings={
"generated_conversation": "rejected",
"conversation": "chosen",
},
)
keep_columns = KeepColumns(
columns=["chosen", "rejected"],
)
chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns
if __name__ == "__main__":
dataset_path = "example_data.json"
    with open(dataset_path) as f:
        data = json.load(f)
dataset = Dataset.from_list(data)
rejected_dataset = rejection_pipeline.run(dataset=dataset, use_cache=False)
chosen_dataset = chosen_pipeline.run(dataset=dataset, use_cache=False)
dataset = concatenate_datasets(
dsets=[rejected_dataset["default"]["train"], chosen_dataset["default"]["train"]]
)
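    # Optionally persist the combined preference pairs; the repo id below is a
    # placeholder:
    #   dataset.save_to_disk("preference_pairs")
    #   dataset.push_to_hub("<your-org>/preference-pairs")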