This is an attempt to replicate the RLHF (reinforcement learning from human feedback) pipeline.
Base Model
We used bloomz-7b1-mt because of its less restrictive license and its multilingual ability.
Supervised Finetuning
For supervised finetuning (SFT), we used a combination of several datasets, including:
- RyokoAI/ShareGPT52K
- GPTeacher
- Alpaca-GPT4 en & zh
- A filtered subset of the ShareGPT dataset, machine-translated into Chinese
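The card does not say how conversations from these datasets were serialized for training beyond using the Vicuna v1.1 template. The sketch below shows one plausible way to turn a ShareGPT-style record into that template while marking assistant replies as the loss targets. The record schema (`conversations`, `from`, `value` with `human`/`gpt` speakers), the `</s>` end-of-turn marker, the way turns are joined, and the helper name `to_vicuna_segments` are all assumptions, not the authors' code. The first turn reproduces the single-turn prompt used in the Example section.

```python
# Hypothetical sketch: serialize a ShareGPT-style conversation in the Vicuna v1.1
# format and flag which spans are assistant replies (the usual SFT loss targets).
SYSTEM = ("A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions. ")


def to_vicuna_segments(record, eos="</s>"):
    """Return a list of (text, is_target) pairs for one conversation.

    Assumed schema (common ShareGPT layout, not confirmed by the source):
    {"conversations": [{"from": "human", "value": ...}, {"from": "gpt", "value": ...}]}
    """
    segments = [(SYSTEM, False)]
    for turn in record["conversations"]:
        if turn["from"] == "human":
            # User turn: never a loss target
            segments.append((f"USER: {turn['value']}\nASSISTANT:", False))
        elif turn["from"] == "gpt":
            # Assistant reply, terminated by the (assumed) end-of-sequence marker
            segments.append((f" {turn['value']}{eos}", True))
        else:
            raise ValueError(f"unexpected speaker: {turn['from']!r}")
    return segments


record = {"conversations": [{"from": "human", "value": "Hi"},
                            {"from": "gpt", "value": "Hello!"}]}
print("".join(text for text, _ in to_vicuna_segments(record)))
```

A training loop would tokenize each segment separately and set the labels of non-target segments to the ignore index, so that only the assistant replies contribute to the loss.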
Reward Model
For the reward model (RM), we used the code from the reward-modeling repo and datasets from
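The reward-modeling code is not reproduced here. Reward models of this kind are commonly trained with a pairwise ranking loss on (chosen, rejected) response pairs, `-log(sigmoid(r_chosen - r_rejected))`. The framework-free sketch below assumes that objective; the function name `pairwise_rm_loss` is hypothetical and the reward-modeling repo may use a different formulation.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_rm_loss(chosen_rewards, rejected_rewards):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over comparison pairs.

    Uses the identity -log(sigmoid(m)) == softplus(-m), which avoids
    overflow for large positive or negative margins.
    """
    if not chosen_rewards or len(chosen_rewards) != len(rejected_rewards):
        raise ValueError("need equal-length, non-empty reward lists")
    total = sum(softplus(rr - rc) for rc, rr in zip(chosen_rewards, rejected_rewards))
    return total / len(chosen_rewards)


# Scalar rewards the RM assigned to preferred and dispreferred responses
print(pairwise_rm_loss([1.5, 0.2], [0.3, 0.9]))
```

The loss shrinks toward zero as the model scores each chosen response increasingly above its rejected counterpart, and equals `log(2)` when the two rewards are tied.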
Reinforcement Learning
For RL, we used trlx with slight modifications.
Instead of building the value network on top of the policy network with a single linear layer, we add another hydra head on top of the reference network's frozen bottom layers to serve as the value network.
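A minimal PyTorch sketch of the value-network layout described above. It assumes a toy stand-in for the language model (`ToyLM`) and a hypothetical split point `n_frozen`: the embedding and bottom blocks are shared with the reference network and frozen, while the value network owns a trainable copy of the top blocks plus a scalar linear head. This illustrates the idea only; it is not the modified trlx code.

```python
import copy

import torch
import torch.nn as nn


class ToyLM(nn.Module):
    """Hypothetical stand-in for a decoder-only LM: embedding + a stack of blocks."""

    def __init__(self, vocab_size=100, hidden=32, n_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, hidden), nn.GELU()) for _ in range(n_layers)
        )

    def forward(self, input_ids):
        h = self.embed(input_ids)
        for layer in self.layers:
            h = layer(h)
        return h


class HydraValueNetwork(nn.Module):
    """Value head branching off the reference network's frozen bottom layers."""

    def __init__(self, reference: ToyLM, n_frozen: int):
        super().__init__()
        # Shared trunk: the same module objects as the reference network, frozen.
        self.embed = reference.embed
        self.frozen = reference.layers[:n_frozen]
        for p in list(self.embed.parameters()) + list(self.frozen.parameters()):
            p.requires_grad_(False)
        # Trainable "hydra head": independent copies of the remaining top blocks.
        self.head_layers = copy.deepcopy(reference.layers[n_frozen:])
        for p in self.head_layers.parameters():
            p.requires_grad_(True)
        # Scalar value per token
        self.value_head = nn.Linear(reference.embed.embedding_dim, 1)

    def forward(self, input_ids):
        with torch.no_grad():  # the shared trunk receives no gradients
            h = self.embed(input_ids)
            for layer in self.frozen:
                h = layer(h)
        for layer in self.head_layers:
            h = layer(h)
        return self.value_head(h).squeeze(-1)  # shape: (batch, seq_len)


reference = ToyLM()
value_net = HydraValueNetwork(reference, n_frozen=2)
values = value_net(torch.randint(0, 100, (3, 5)))
print(values.shape)
```

With this layout, gradients from the value loss only reach the copied top blocks and the linear head, never the shared frozen trunk, and the reference network's own top blocks stay untouched.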
Example
We used the Vicuna v1.1 template for model training, so prompts at inference time should follow the same format:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "keyfan/bloomz-rlhf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Load the model and move it to the GPU
model = AutoModelForCausalLM.from_pretrained(checkpoint).cuda()

# Vicuna v1.1 prompt template used during training
template = ("A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions. "
            "USER: {}\nASSISTANT:")
question = template.format("Who was the president of the United States in 1955?")

# Tokenize the prompt and move the input IDs to the same device as the model
inputs = tokenizer.encode(question, return_tensors="pt").cuda()
# Nucleus sampling (top_p=0.8), generating up to 512 new tokens
outputs = model.generate(inputs, do_sample=True, top_p=0.8, max_new_tokens=512)
# The decoded text contains the prompt followed by the model's reply
print(tokenizer.decode(outputs[0]))
```
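`tokenizer.decode(outputs[0])` returns the prompt together with the generated continuation. A small post-processing helper such as the hypothetical `extract_reply` below can isolate the assistant's answer. It assumes a single-turn prompt, that the decoded text begins with the prompt verbatim, and that `</s>` marks the end of the reply.

```python
def extract_reply(decoded: str, prompt: str, eos: str = "</s>") -> str:
    """Return only the assistant's answer from a decoded single-turn generation."""
    if decoded.startswith(prompt):
        reply = decoded[len(prompt):]
    else:
        # Fallback for an inexact round trip: take the text after the first
        # "ASSISTANT:" marker (correct only for single-turn prompts).
        reply = decoded.split("ASSISTANT:", 1)[-1]
    # Cut at the end-of-sequence marker, or at a new user turn if the model kept going
    for stop in (eos, "USER:"):
        idx = reply.find(stop)
        if idx != -1:
            reply = reply[:idx]
    return reply.strip()
```

Alternatively, slicing off the prompt tokens before decoding, `tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)`, avoids depending on an exact text round trip.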
Evaluations
Results on the Chinese BELLE eval set
| others | rewrite | classification | generation | summarization | extract | open qa | brainstorming | closed qa | macro avg | macro avg (w/o others) |
|---|---|---|---|---|---|---|---|---|---|---|
| 0.619 | 0.873 | 0.706 | 0.934 | 0.755 | 0.619 | 0.527 | 0.908 | 0.615 | 0.728 | 0.742 |
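As a quick consistency check, the two macro averages in the BELLE results table can be recomputed from the nine per-category scores: the first is the unweighted mean over all categories, the second is the same mean with `others` excluded.

```python
# Per-category scores copied from the BELLE results table
scores = {
    "others": 0.619, "rewrite": 0.873, "classification": 0.706,
    "generation": 0.934, "summarization": 0.755, "extract": 0.619,
    "open qa": 0.527, "brainstorming": 0.908, "closed qa": 0.615,
}


def macro_average(values):
    """Unweighted mean: every category counts equally."""
    values = list(values)
    return sum(values) / len(values)


macro = macro_average(scores.values())
macro_wo_others = macro_average(v for k, v in scores.items() if k != "others")
print(f"{macro:.3f} {macro_wo_others:.3f}")  # 0.728 0.742
```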
- In GPT-4-based evaluation, we found that the order in which the responses were presented had a non-negligible effect on the final score, even with the well-designed Vicuna prompt. We therefore removed the score on the Vicuna eval set.
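One way to observe the ordering effect described above is to have the judge score the same pair of answers in both presentation orders and compare the score a given answer receives in each position. The sketch assumes a hypothetical `judge(question, first, second)` callable that returns one score per answer; it does not call GPT-4 or use the Vicuna prompt.

```python
from typing import Callable, Tuple

# Hypothetical judge: (question, first_answer, second_answer) -> (score_first, score_second)
Judge = Callable[[str, str, str], Tuple[float, float]]


def order_sensitivity(judge: Judge, question: str, answer_a: str, answer_b: str):
    """Score answer A when shown first and when shown second.

    Returns (score_a_first, score_a_second, gap). A gap far from 0 means
    the judge's verdict on A depends on where A appears in the prompt.
    """
    score_a_first, _ = judge(question, answer_a, answer_b)
    _, score_a_second = judge(question, answer_b, answer_a)
    return score_a_first, score_a_second, score_a_first - score_a_second


# Toy judge that favors whichever answer is shown first
def biased_judge(question, first, second):
    return len(first) + 1.0, float(len(second))


print(order_sensitivity(biased_judge, "q", "aa", "b"))
```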