Update README.md
I performed DPO based on the already fine-tuned Hide101111001111000/llm-jp-3-13b model; the code below shows the full procedure.
```python
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

from unsloth import PatchDPOTrainer
PatchDPOTrainer()

from unsloth import FastLanguageModel
import torch

max_seq_length = 2200  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
HF_TOKEN = "your-token"  # your own Hugging Face token

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Hide101111001111000/llm-jp-3-13b-it_lora_3.",  # the model I fine-tuned with Unsloth (only the LoRA adapter is uploaded)
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    token = HF_TOKEN,
)

from huggingface_hub import login

# Paste your generated token
login(HF_TOKEN)

from datasets import load_dataset

# Load the dataset
ds = load_dataset("llm-jp/hh-rlhf-12k-ja")

# Filtering: process the 'conversations' column
def extract_anthropic_prompt(sample):
    # Take the last 'value' from the 'conversations' column
    conversations = sample.get("conversations", [])
    if not conversations or not isinstance(conversations, list):
        raise KeyError(f"Key 'conversations' not found or is not a list in sample: {sample}")

    last_conversation = conversations[-1]
    if 'value' not in last_conversation:
        raise KeyError(f"Key 'value' not found in last conversation: {last_conversation}")

    prompt = last_conversation['value']

    # Use the 'chosen' and 'rejected' fields
    chosen_text = sample["chosen"].replace("\\n ", "\\n")
    rejected_text = sample["rejected"].replace("\\n ", "\\n")

    return {
        "prompt": prompt,
        "chosen": chosen_text,
        "rejected": rejected_text,
    }

# Define the filtering function
def filter_short_examples(example):
    return (
        len(example['prompt']) <= 2000 and
        len(example['chosen']) <= 2000 and
        len(example['rejected']) <= 2000
    )

# Filter the training data
ds_filter = ds['train'].map(extract_anthropic_prompt)
filtered_train = ds_filter.filter(filter_short_examples)

# Split the dataset into training and evaluation sets (80% train, 20% eval)
train_size = int(0.8 * len(filtered_train))  # size of the training split
eval_size = len(filtered_train) - train_size  # size of the evaluation split

# Generate indices in order (no shuffling)
train_indices = list(range(train_size))  # training indices
eval_indices = list(range(train_size, len(filtered_train)))  # evaluation indices

# Select the training and evaluation data
train_dataset = filtered_train.select(train_indices)
eval_dataset = filtered_train.select(eval_indices)

# Print the dataset sizes
print(f"Training dataset size: {len(train_dataset)}")
print(f"Evaluation dataset size: {len(eval_dataset)}")

use_dataset = train_dataset.select(range(500))
use_dataset

# One must patch the DPO Trainer first!
from unsloth import PatchDPOTrainer
PatchDPOTrainer()

from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig
from unsloth import is_bfloat16_supported

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = DPOConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 1,
        learning_rate = 5e-6,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "outputs",
        report_to = "none",  # Use this for WandB etc
    ),
    beta = 0.1,
    train_dataset = use_dataset,  # raw_datasets["train"],
    # eval_dataset = raw_datasets["test"],
    tokenizer = tokenizer,
    max_length = 2048,
    max_prompt_length = 1024,
)

dpo_trainer.train()

# Load ELYZA-tasks-100-TV. Upload the file beforehand.
# In the omnicampus development environment, drag and drop the task JSONL into the left pane before running.
import json
datasets = []
with open("/content/elyza-tasks-100-TV_0.jsonl", "r") as f:
    item = ""
    for line in f:
        line = line.strip()
        item += line
        if item.endswith("}"):
            datasets.append(json.loads(item))
            item = ""

# Run the tasks with the trained model
from tqdm import tqdm

# Switch the model to inference mode
FastLanguageModel.for_inference(model)

results = []
for dt in tqdm(datasets):
    input = dt["input"]

    prompt = f"""### 指示\n{input}\n### 回答\n"""

    inputs = tokenizer([prompt], return_tensors = "pt").to(model.device)

    outputs = model.generate(**inputs, max_new_tokens = 2048, use_cache = True, do_sample=False, repetition_penalty=1.2)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).split('\n### 回答')[-1]

    results.append({"task_id": dt["task_id"], "input": input, "output": prediction})

# Upload the model to Hugging Face
# Save only the LoRA adapter
model.push_to_hub_merged(
    "llm-jp-3-13b-it_lora-DPO-ja",  # name of the model to save
    tokenizer=tokenizer,
    save_method="lora",  # save only the LoRA adapter
    token=HF_TOKEN,
    private=True
)
```
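The inference loop above accumulates the generated answers in `results` but stops before writing them anywhere. A minimal sketch of persisting them as JSONL follows; the output file name `elyza-tasks-100-TV_output.jsonl` is my assumption, not part of the original code, so adjust it to whatever your submission format expects.

```python
# Hypothetical follow-up step: persist the generated answers as JSONL.
# The output file name is an assumption; adjust it as needed.
import json

with open("elyza-tasks-100-TV_output.jsonl", "w", encoding="utf-8") as f:
    for r in results:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")
```

Each line then holds one `task_id` / `input` / `output` record, matching the structure built in the `results` list above.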