File size: 5,782 Bytes
dc616b0
 
 
 
 
 
 
 
 
 
 
 
01ba912
dc616b0
 
 
01ba912
dc616b0
 
01ba912
dc616b0
 
 
01ba912
dc616b0
 
 
01ba912
dc616b0
 
01ba912
dc616b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import json
from typing import TYPE_CHECKING, List, Literal, Union

from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, GeneratorStep, KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration
from typing_extensions import override

# Jinja2 prompt templates used to regenerate the final assistant turn of a
# conversation. Fixed grammar/typos in the original prompts ("You are provide
# with", stray "positively", missing article) so the LLM receives clean
# instructions.
CHOSEN_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of poor quality. Your task is to regenerate one of high quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
High quality response:
""".rstrip()

CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a high quality response when other assistants created a poor quality response."

REJECT_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of high quality. Your task is to regenerate one of poor quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
Poor quality response:
""".rstrip()

REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when other assistants created a high quality response."


class FilterConversationRatings(Step):
    """Filters conversations based on the rating of the last message.

    Each input row's ``conversation`` is truncated at the first message whose
    integer ``rating`` equals the target value (+1 for ``"chosen"``, -1 for
    ``"rejected"``); rows with no matching rating are dropped.
    """

    # Which preference side to keep; decides the rating value searched for.
    target_column: Union[Literal["chosen"], Literal["rejected"]]
    batch_size: int = 5

    @override
    def process(self, dataset: StepInput) -> "GeneratorStepOutput":
        """Yield batches of truncated conversations.

        Args:
            dataset: Rows each holding a ``conversation`` list of message
                dicts with ``role``, ``content`` and ``rating`` keys.

        Yields:
            Lists of ``{"conversation": [...]}`` rows, at most
            ``batch_size`` per batch.
        """
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]

        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            # BUG FIX: the original nested `for conversation in batch: for row
            # in batch:` iterated the batch against itself, emitting every
            # matching row len(batch) times; a single pass is intended.
            for row in batch:
                messages = row["conversation"]
                truncated = None
                # Keep everything up to and including the first message whose
                # rating matches the target (enumerate starts at 1 so the
                # slice includes the matched message).
                for idx, message in enumerate(messages, 1):
                    if not isinstance(message["rating"], int):
                        continue
                    if message["rating"] == target_rating:
                        truncated = messages[:idx]
                        break
                if truncated:
                    filtered_batch.append({"conversation": truncated})
            yield filtered_batch

    @property
    def outputs(self) -> "StepColumns":
        return ["conversation"]


class AppendToConversationStep(Step):
    """Appends a generated message to a conversation.

    Produces ``generated_conversation`` — the history with its final turn
    replaced by the generated text — and normalizes ``conversation`` to bare
    role/content message dicts. Rows lacking a generation or a conversation
    pass through unchanged.
    """

    @property
    def inputs(self) -> "StepColumns":
        return ["generation", "conversation"]

    @property
    def outputs(self) -> "StepColumns":
        return ["generated_conversation", "conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":
        for row in inputs:
            generation = row["generation"]
            history = row["conversation"]
            # Skip rows with nothing to append or nothing to append to.
            if not generation or not history:
                continue
            # History minus its last turn, capped with the fresh generation.
            row["generated_conversation"] = [
                {"role": msg["role"], "content": msg["content"]}
                for msg in history[:-1]
            ] + [{"role": "assistant", "content": generation}]
            # Strip any extra keys (e.g. ratings) from the original turns.
            row["conversation"] = [
                {"role": msg["role"], "content": msg["content"]}
                for msg in history
            ]
        yield inputs


# Pipeline 1: start from conversations whose final rated message was
# *rejected* (rating == -1) and generate a high-quality "chosen" alternative,
# yielding (chosen, rejected) preference pairs.
with Pipeline(
    name="conversation_rejection",
    description="Generate a chosen response to a rejected conversation.",
) as rejection_pipeline:

    # Keep only conversations truncated at a rejected (-1) final message.
    rejected_dataset = FilterConversationRatings(target_column="rejected")

    # Ask the LLM for a high-quality replacement of the poor final response.
    chosen_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=CHOSEN_SYSTEM_PROMPT,
        template=CHOSEN_TEMPLATE,
        columns=["conversation"],
    )

    # Rename outputs: the regenerated conversation becomes "chosen", the
    # original (poor) one becomes "rejected".
    append_chosen = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "chosen",
            "conversation": "rejected",
        },
    )

    # Drop everything except the preference-pair columns.
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )

    rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns

# Pipeline 2: mirror image of the rejection pipeline — start from
# conversations whose final rated message was *chosen* (rating == +1) and
# generate a deliberately poor "rejected" alternative.
with Pipeline(
    name="conversation_chosen",
    description="Generate a rejected response to a chosen conversation.",
) as chosen_pipeline:

    # Keep only conversations truncated at a chosen (+1) final message.
    chosen_dataset = FilterConversationRatings(target_column="chosen")

    # Ask the LLM for a low-quality replacement of the good final response.
    rejected_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=REJECT_SYSTEM_PROMPT,
        template=REJECT_TEMPLATE,
        columns=["conversation"],
    )
    # Rename outputs: the regenerated conversation becomes "rejected", the
    # original (good) one becomes "chosen".
    append_rejected = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "rejected",
            "conversation": "chosen",
        },
    )
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )
    chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns

if __name__ == "__main__":

    dataset_path = "example_data.json"
    # FIX: use a context manager so the file handle is closed deterministically
    # (the original `json.load(open(...))` leaked it) and pin the encoding.
    with open(dataset_path, encoding="utf-8") as f:
        data = json.load(f)

    # Run both pipelines over the same seed conversations.
    dataset = Dataset.from_list(data)
    rejected_dataset = rejection_pipeline.run(dataset=dataset, use_cache=False)
    chosen_dataset = chosen_pipeline.run(dataset=dataset, use_cache=False)

    # Merge the "train" splits of each pipeline's default leaf step into a
    # single preference dataset with `chosen`/`rejected` columns.
    dataset = concatenate_datasets(
        dsets=[rejected_dataset["default"]["train"], chosen_dataset["default"]["train"]]
    )