import json
from typing import TYPE_CHECKING, Literal, Union

from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration

if TYPE_CHECKING:
    from distilabel.steps.typing import StepColumns, StepOutput
CHOSEN_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of poor quality. Your task is to regenerate one of high quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
High quality response:
""".rstrip()

CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a high quality response when another assistant has produced a poor quality one."
REJECT_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of high quality. Your task is to regenerate one of poor quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
Poor quality response:
""".rstrip()

REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when another assistant has produced a high quality one."
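
# For illustration, rendering CHOSEN_TEMPLATE against a short conversation
# would produce a prompt roughly like this (the turns are made up):
#
#   You are provided with a conversation between a human and an AI assistant.
#   The final message is of poor quality. Your task is to regenerate one of high quality.
#   user: What is the capital of France?
#   assistant: Paris is the capital of Germany.
#   High quality response: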
class FilterConversationRatings(Step):
    """Filters conversations based on the rating of the last message.

    Truncates each conversation at the first message whose rating matches
    the target column: 1 for "chosen", -1 for "rejected".
    """

    target_column: Union[Literal["chosen"], Literal["rejected"]]
    batch_size: int = 5

    def process(self, dataset: StepInput) -> "StepOutput":
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]
        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            for row in batch:
                _conversation = row["conversation"]
                conversation = None
                # Keep messages up to and including the first one whose
                # rating matches the target; skip unrated messages.
                for idx, message in enumerate(_conversation, 1):
                    if not isinstance(message["rating"], int):
                        continue
                    if message["rating"] == target_rating:
                        conversation = _conversation[:idx]
                        break
                if conversation:
                    filtered_batch.append({"conversation": conversation})
            yield filtered_batch

    def outputs(self) -> "StepColumns":
        return ["conversation"]
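
# The step above assumes each input row looks roughly like this (an
# illustrative example, not taken from the actual dataset):
#
#   {
#       "conversation": [
#           {"role": "user", "content": "...", "rating": None},
#           {"role": "assistant", "content": "...", "rating": -1},
#       ]
#   }
#
# where "rating" is 1 for a positively rated message, -1 for a negatively
# rated one, and None (or any non-int) for unrated messages.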
class AppendToConversationStep(Step):
    """Appends a generated message to a conversation."""

    def inputs(self) -> "StepColumns":
        return ["generation", "conversation"]

    def outputs(self) -> "StepColumns":
        return ["generated_conversation", "conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":
        for row in inputs:
            # Skip rows missing a generation or a source conversation.
            if not row["generation"]:
                continue
            if not row["conversation"]:
                continue
            # Replace the final (rated) message with the generated one.
            row["generated_conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in row["conversation"][:-1]
            ] + [{"role": "assistant", "content": row["generation"]}]
            # Re-emit the original conversation without the "rating" keys.
            row["conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in row["conversation"]
            ]
        yield inputs
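
# Given a conversation ending in a rated assistant message plus a fresh
# "generation", the step yields two parallel views of the same dialogue:
# "generated_conversation" swaps in the new assistant reply, while
# "conversation" keeps the original reply; both drop the rating fields.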
with Pipeline(
    name="conversation_rejection",
    description="Generate a chosen response to a rejected conversation.",
) as rejection_pipeline:
    rejected_dataset = FilterConversationRatings(target_column="rejected")
    chosen_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=CHOSEN_SYSTEM_PROMPT,
        template=CHOSEN_TEMPLATE,
        columns=["conversation"],
    )
    append_chosen = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "chosen",
            "conversation": "rejected",
        },
    )
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )
    rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns
with Pipeline(
    name="conversation_chosen",
    description="Generate a rejected response to a chosen conversation.",
) as chosen_pipeline:
    chosen_dataset = FilterConversationRatings(target_column="chosen")
    rejected_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=REJECT_SYSTEM_PROMPT,
        template=REJECT_TEMPLATE,
        columns=["conversation"],
    )
    append_rejected = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "rejected",
            "conversation": "chosen",
        },
    )
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )
    chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns
if __name__ == "__main__":
    dataset_path = "example_data.json"
    with open(dataset_path) as f:
        data = json.load(f)
    dataset = Dataset.from_list(data)
    rejected_dataset = rejection_pipeline.run(dataset=dataset, use_cache=False)
    chosen_dataset = chosen_pipeline.run(dataset=dataset, use_cache=False)
    dataset = concatenate_datasets(
        dsets=[rejected_dataset["default"]["train"], chosen_dataset["default"]["train"]]
    )
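    # Persisting the combined preference pairs is left open here; a minimal
    # option (the output path is illustrative) is to write JSON lines:
    dataset.to_json("preference_dataset.jsonl")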