burtenshaw commited on
Commit
6d59547
·
unverified ·
2 Parent(s): aac30ac 873f98c

Merge pull request #1 from huggingface/generate-dpo-dataset

Browse files
data/download_data.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ wget https://huggingface.co/datasets/feel-fl/open-human-feedback-chat-en/resolve/main/data/data_d6f0f072-348e-4f61-9a44-26dbd2ccba75.json
data/example_data.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "conversation": [
4
+ {
5
+ "role": "user",
6
+ "content": "hello how are you?",
7
+ "options": "",
8
+ "rating": 0
9
+ },
10
+ {
11
+ "role": "assistant",
12
+ "content": "leave me alone you weirdo!",
13
+ "options": "",
14
+ "rating": -1
15
+ }
16
+ ],
17
+ "timestamp": "2024-12-10T15:35:52.363635",
18
+ "session_id": "9c5b367d-12c2-4ae0-a868-e2e783e50935",
19
+ "conversation_id": "870fac58-2b2c-45ac-93f7-7cd8a43644be"
20
+ },
21
+ {
22
+ "conversation": [
23
+ {
24
+ "role": "user",
25
+ "content": "hello",
26
+ "options": "",
27
+ "rating": 0
28
+ },
29
+ {
30
+ "role": "assistant",
31
+ "content": "Hello! How can I assist you today? If you have any questions or just want to chat, feel free \ud83d\ude0a.",
32
+ "options": "",
33
+ "rating": 1
34
+ }
35
+ ],
36
+ "timestamp": "2024-12-10T15:35:52.363635",
37
+ "session_id": "9c5b367d-12c2-4ae0-a868-e2e783e50935",
38
+ "conversation_id": "870fac58-2b2c-45ac-93f7-7cd8a43644be"
39
+ },
40
+ {
41
+ "conversation": [
42
+ {
43
+ "role": "user",
44
+ "content": "hello",
45
+ "options": "",
46
+ "rating": 0
47
+ },
48
+ {
49
+ "role": "assistant",
50
+ "content": "Hello! How can I assist you today? If you have any questions or just want to chat, feel free \ud83d\ude0a.",
51
+ "options": "",
52
+ "rating": 1
53
+ }
54
+ ],
55
+ "timestamp": "2024-12-10T15:35:52.363635",
56
+ "session_id": "9c5b367d-12c2-4ae0-a868-e2e783e50935",
57
+ "conversation_id": "870fac58-2b2c-45ac-93f7-7cd8a43644be"
58
+ }
59
+ ]
data/generate_dpo.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import TYPE_CHECKING, List, Literal, Union
3
+
4
+ from datasets import Dataset, concatenate_datasets
5
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
6
+ from distilabel.pipeline import Pipeline
7
+ from distilabel.steps import CombineOutputs, GeneratorStep, KeepColumns, Step, StepInput
8
+ from distilabel.steps.tasks import TextGeneration
9
+ from typing_extensions import override
10
+
11
+ CHOSEN_TEMPLATE = """
12
+ You are provide with a conversation between a human and an AI assistant.
13
+ The final message is of poor quality positively. Your task is to regenerate one of high quality.
14
+ {% for message in conversation %}
15
+ {{ message["role"] }}: {{ message["content"] }}
16
+ {% endfor %}
17
+ High quality response:
18
+ """.rstrip()
19
+
20
+ CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate high quality response when other assistants created a poor quality response."
21
+
22
+ REJECT_TEMPLATE = """
23
+ You are provide with a conversation between a human and an AI assistant.
24
+ The final message is of high quality positively. Your task is to regenerate one of poor quality.
25
+ {% for message in conversation %}
26
+ {{ message["role"] }}: {{ message["content"] }}
27
+ {% endfor %}
28
+ Poor quality response:
29
+ """.rstrip()
30
+
31
+ REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when other assistants created a high quality response."
32
+
33
+
34
+ class FilterConversationRatings(Step):
35
+ """Filters conversations based on the rating of the last message."""
36
+
37
+ target_column: Union[Literal["chosen"], Literal["rejected"]]
38
+ batch_size: int = 5
39
+
40
+ @override
41
+ def process(self, dataset: StepInput) -> "GeneratorStepOutput":
42
+
43
+ column_rating_map = {
44
+ "chosen": 1,
45
+ "rejected": -1,
46
+ }
47
+
48
+ target_rating = column_rating_map[self.target_column]
49
+
50
+ for batch_start in range(0, len(dataset), self.batch_size):
51
+ batch = dataset[batch_start : batch_start + self.batch_size]
52
+ filtered_batch = []
53
+ for conversation in batch:
54
+ for row in batch:
55
+ _conversation = row["conversation"]
56
+ conversation = None
57
+ for idx, message in enumerate(_conversation, 1):
58
+ if not isinstance(message["rating"], int):
59
+ continue
60
+ if message["rating"] == target_rating:
61
+ conversation = _conversation[:idx]
62
+ break
63
+ if conversation:
64
+ filtered_batch.append({"conversation": conversation})
65
+ yield filtered_batch
66
+
67
+ @property
68
+ def outputs(self) -> "StepColumns":
69
+ return ["conversation"]
70
+
71
+
72
+ class AppendToConversationStep(Step):
73
+ """Appends a generated message to a conversation."""
74
+
75
+ @property
76
+ def inputs(self) -> "StepColumns":
77
+ return ["generation", "conversation"]
78
+
79
+ @property
80
+ def outputs(self) -> "StepColumns":
81
+ return ["generated_conversation", "conversation"]
82
+
83
+ def process(self, inputs: StepInput) -> "StepOutput":
84
+
85
+ for input in inputs:
86
+ if not input["generation"]:
87
+ continue
88
+ if not input["conversation"]:
89
+ continue
90
+ input["generated_conversation"] = [
91
+ {"role": message["role"], "content": message["content"]}
92
+ for message in input["conversation"][:-1]
93
+ ] + [{"role": "assistant", "content": input["generation"]}]
94
+ input["conversation"] = [
95
+ {"role": message["role"], "content": message["content"]}
96
+ for message in input["conversation"]
97
+ ]
98
+ yield inputs
99
+
100
+
101
+ with Pipeline(
102
+ name="conversation_rejection",
103
+ description="Generate a chosen response to a rejected conversation.",
104
+ ) as rejection_pipeline:
105
+
106
+ rejected_dataset = FilterConversationRatings(target_column="rejected")
107
+
108
+ chosen_text_gen = TextGeneration(
109
+ llm=InferenceEndpointsLLM(
110
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
111
+ ),
112
+ system_prompt=CHOSEN_SYSTEM_PROMPT,
113
+ template=CHOSEN_TEMPLATE,
114
+ columns=["conversation"],
115
+ )
116
+
117
+ append_chosen = AppendToConversationStep(
118
+ output_mappings={
119
+ "generated_conversation": "chosen",
120
+ "conversation": "rejected",
121
+ },
122
+ )
123
+
124
+ keep_columns = KeepColumns(
125
+ columns=["chosen", "rejected"],
126
+ )
127
+
128
+ rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns
129
+
130
+ with Pipeline(
131
+ name="conversation_chosen",
132
+ description="Generate a rejected response to a chosen conversation.",
133
+ ) as chosen_pipeline:
134
+
135
+ chosen_dataset = FilterConversationRatings(target_column="chosen")
136
+
137
+ rejected_text_gen = TextGeneration(
138
+ llm=InferenceEndpointsLLM(
139
+ model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
140
+ ),
141
+ system_prompt=REJECT_SYSTEM_PROMPT,
142
+ template=REJECT_TEMPLATE,
143
+ columns=["conversation"],
144
+ )
145
+ append_rejected = AppendToConversationStep(
146
+ output_mappings={
147
+ "generated_conversation": "rejected",
148
+ "conversation": "chosen",
149
+ },
150
+ )
151
+ keep_columns = KeepColumns(
152
+ columns=["chosen", "rejected"],
153
+ )
154
+ chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns
155
+
156
+ if __name__ == "__main__":
157
+
158
+ dataset_path = "example_data.json"
159
+ data = json.load(open(dataset_path))
160
+
161
+ dataset = Dataset.from_list(data)
162
+ rejected_dataset = rejection_pipeline.run(dataset=dataset, use_cache=False)
163
+ chosen_dataset = chosen_pipeline.run(dataset=dataset, use_cache=False)
164
+
165
+ dataset = concatenate_datasets(
166
+ dsets=[rejected_dataset["default"]["train"], chosen_dataset["default"]["train"]]
167
+ )
pyproject.toml CHANGED
@@ -6,6 +6,8 @@ readme = "README.md"
6
  requires-python = ">=3.11"
7
  dependencies = [
8
  "datasets>=3.1.0",
 
 
9
  ]
10
 
11
  [dependency-groups]
 
6
  requires-python = ">=3.11"
7
  dependencies = [
8
  "datasets>=3.1.0",
9
+ "distilabel>=1.4.1",
10
+ "ipykernel>=6.29.5",
11
  ]
12
 
13
  [dependency-groups]