# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Dict, List, Optional

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
r""" | |
Arguments for the script. | |
Args: | |
push_to_hub (`bool`, *optional*, defaults to `False`): | |
Whether to push the dataset to the Hugging Face Hub. | |
repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`): | |
Hugging Face repository ID to push the dataset to. | |
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): | |
Number of workers to use for dataset processing. | |
""" | |
push_to_hub: bool = False | |
repo_id: str = "trl-lib/hh-rlhf-helpful-base" | |
dataset_num_proc: Optional[int] = None | |
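    # Illustrative only (not part of the upstream script): HfArgumentParser turns these
    # fields into CLI flags, so e.g. passing
    #   --push_to_hub --repo_id your-org/hh-rlhf-helpful-base --dataset_num_proc 4
    # (a hypothetical repository ID) overrides the defaults above.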
def common_start(str1: str, str2: str) -> str:
    # Zip the two strings and iterate over them together
    common_chars = []
    for c1, c2 in zip(str1, str2):
        if c1 == c2:
            common_chars.append(c1)
        else:
            break
    # Join the common characters and return as a string
    return "".join(common_chars)
def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
    # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
    prompt_text = common_start(example["chosen"], example["rejected"])
    # The chosen and rejected responses may themselves share a common start, so the common
    # prefix can run past the prompt; trim it back to the last "\n\nAssistant: " marker
    if not prompt_text.endswith("\n\nAssistant: "):
        prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: "
    # Extract the chosen and rejected lines
    chosen_line = example["chosen"][len(prompt_text) :]
    rejected_line = example["rejected"][len(prompt_text) :]

    # Remove the generation prompt ("\n\nAssistant: ") from the prompt
    prompt_text = prompt_text[: -len("\n\nAssistant: ")]

    # Split the string at every occurrence of "Human: " or "Assistant: "
    prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)

    # Remove the first element as it's empty
    prompt_lines = prompt_lines[1:]

    prompt = []
    for idx in range(0, len(prompt_lines), 2):
        role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant"
        content = prompt_lines[idx + 1]
        prompt.append({"role": role, "content": content})
    # Wrap the chosen and rejected completions (prompt already removed) as single assistant turns
    chosen = [{"role": "assistant", "content": chosen_line}]
    rejected = [{"role": "assistant", "content": rejected_line}]

    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
def runner(arguments):
    parser = HfArgumentParser(arguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
    dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
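
# Minimal entry-point sketch (an assumption, not part of the code above): run the
# conversion directly from the command line by passing the ScriptArguments dataclass
# to runner(). The script filename below is hypothetical.
#   python hh_rlhf_helpful_base.py --push_to_hub --repo_id your-org/hh-rlhf-helpful-base
if __name__ == "__main__":
    runner(ScriptArguments)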