cyberagent/ca-reward-3b-ja
- 軽量な日本語報酬モデルの開発を目的として実装したモデルを公開する。
- 既存の指示文と新たに合成した指示文に対して、応答文を複数生成し、llm-as-a-judgeで疑似選好ラベルを付与することで疑似選好データセットを作成した。
- 上記の疑似選好データセットを分類するモデルを学習することで、指示文に対する応答文の好ましさを定量化する報酬モデルを作成した。
評価
- 人手で選好ラベル(好ましい応答文か、好ましくない応答文)が付与された既存のデータセットを収集し、選好ラベルの分類精度(Accuracy)を評価した。
- 既存の報酬モデルによる分類精度と、
gpt-4o-2024-08-06
を用いたllm-as-a-judgeによる分類精度を記載した。
HelpSteer3 | llm-jp-chatbot-arena | |
---|---|---|
OpenAssistant/reward-model-deberta-v3-large-v2 | 0.5124 | 0.5610 |
Skywork/Skywork-Reward-Gemma-2-27B-v0.2 | 0.6854 | 0.5166 |
cyberagent/ca-reward-3b-ja | 0.7032 | 0.5366 |
cyberagent/calm3-22b-chat-selfimprove-experimental | 0.7216 | 0.6075 |
gpt-4o-2024-08-06 | 0.7845 | 0.6842 |
環境構築
- 開発環境: Ubuntu 24.04.2 LTS
- ライブラリのインストール
pip install -U -q pip
pip install -q torch==2.8.0
pip install -q transformers==4.51.3
pip install -q accelerate==0.29.3
pip install -q sentencepiece==0.2.0
- バージョンの確認
import torch, transformers, accelerate, tokenizers, sentencepiece
print(torch.__version__)
print(transformers.__version__)
print(accelerate.__version__)
print(tokenizers.__version__)
print(sentencepiece.__version__)
# 2.8.0+cu128
# 4.51.3
# 0.29.3
# 0.21.4
# 0.2.0
コードサンプル
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained(
"cyberagent/ca-reward-3b-ja",
device_map="auto",
num_labels=1,
)
tokenizer = AutoTokenizer.from_pretrained("cyberagent/ca-reward-3b-ja")
prompt = """手軽に栄養を補給できる食事を教えてください。"""
response1 = """栄養補給が手軽にできるものとしては、野菜たっぷりのスムージー、ゆで卵とサラダ、納豆ご飯などがおすすめです。特に納豆は手間なく良質なタンパク質が摂れますよ。お身体を大切にしてくださいね。"""
response2 = """手軽な栄養補給なら冷凍食品でいいと思います。レンジで温めるだけだし、時間がない時はコンビニ弁当でも悪くないですよ。"""
chat1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}]
chat2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}]
chat1_formatted = tokenizer.apply_chat_template(chat1, tokenize=False)
chat2_formatted = tokenizer.apply_chat_template(chat2, tokenize=False)
chat1_tokenized = tokenizer(chat1_formatted, return_tensors="pt", max_length=4096, truncation=True, padding="max_length",).to(model.device)
chat2_tokenized = tokenizer(chat2_formatted, return_tensors="pt", max_length=4096, truncation=True, padding="max_length",).to(model.device)
with torch.no_grad():
chat1_score = model(**chat1_tokenized).logits.item()
chat2_score = model(**chat2_tokenized).logits.item()
print(f"Score for response 1: {chat1_score}")
print(f"Score for response 2: {chat2_score}")
# Score for response 1: 1.2595189809799194
# Score for response 2: -2.917454242706299
学習モデル
- sbintuitions/sarashina2.2-3b-instruct-v0.1をベースモデルに用いた。
学習データセット
- ChatBotArenaのinstruction文と、社内で作成したデータセットを用いた。
- cyberagent/calm3-22b-chat-selfimprove-experimentalによるLLM-as-a-judgeを用いて、作成したデータセットに対して疑似選好ラベルを付与して、疑似選好データセットを作成した。
報酬モデル単体の分類性能の評価に使用したデータセット
評価データセット名 | サンプル数 | 備考 |
---|---|---|
llm-jp/llm-jp-chatbot-arena-conversations | 448 | アノテーション結果がどちらも悪い、どちらも良い、同程度(winnerカラムがtie, tie(both good), tie(both bad))のサンプルは評価から除いた。シングルターンのみの学習データで訓練したため、評価にはシングルターンのサンプルのみ使用した。 |
nvidia/HelpSteer3 | 283 | 日本語性能を測るため、日本語サンプル(languageカラム=japanese)のみ評価に使用した。学習データセットにgemmaによって生成された応答文が含まれると記載されていたため、学習には用いず、train/validation split両方の分類性能の平均値を評価に使用した。シングルターンのみの学習データで訓練したため、評価にはシングルターンのサンプルのみ使用した。 |
学習時のハイパーパラメータ
- learning_rateは[8e-07, 1e-06, 2.5e-06, 5e-06]で学習し、learning_rate=5e-06のモデルを採用した。
training type | batchsize | learning_rate | lr_scheduler | num_epoch | max_length | num_of_training_sample | num_of_validation_sample | hardware | |
---|---|---|---|---|---|---|---|---|---|
ca-reward-3b-ja | full parameter tuning | 32 | 5e-06 | linear | 1 | 4096 | 601,348 | 8,000 | NVIDIA_A100_80GB |
リリース
v0.1, 2025/08/08
Authors
- 三橋亮太(corresponding author)
- 開発中にフィードバックをいただいた皆様:陣内佑、坂本充生、森村哲郎、阿部拳之、蟻生開人、藤本悠雅、暮石航大
引用
@misc{cyberagent-ca-reward-3b-ja,
title={cyberagent/ca-reward-3b-ja},
url={https://huggingface.co/cyberagent/ca-reward-3b-ja},
author={Ryota Mitsuhashi},
year={2025},
}
ライセンス
Apache-2.0
- Downloads last month
- 10
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support
Model tree for cyberagent/ca-reward-3b-ja
Base model
sbintuitions/sarashina2.2-3b