Create demo_train.py
Browse files- demo_train.py +329 -0
demo_train.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import traceback
|
4 |
+
from collections import defaultdict
|
5 |
+
from collections.abc import Iterable
|
6 |
+
from enum import Enum, auto
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from datasets import load_dataset
|
10 |
+
from torch import Tensor
|
11 |
+
|
12 |
+
from sentence_transformers import (
|
13 |
+
SentenceTransformer,
|
14 |
+
SentenceTransformerModelCardData,
|
15 |
+
SentenceTransformerTrainer,
|
16 |
+
SentenceTransformerTrainingArguments,
|
17 |
+
)
|
18 |
+
from sentence_transformers.evaluation import InformationRetrievalEvaluator, NanoBEIREvaluator, SequentialEvaluator
|
19 |
+
from sentence_transformers.losses import (
|
20 |
+
CachedMultipleNegativesRankingLoss,
|
21 |
+
DistillKLDivLoss,
|
22 |
+
MarginMSELoss,
|
23 |
+
MultipleNegativesRankingLoss,
|
24 |
+
)
|
25 |
+
from sentence_transformers.training_args import BatchSamplers
|
26 |
+
from sentence_transformers.util import pairwise_dot_score
|
27 |
+
|
28 |
+
# Set the log level to INFO to get more information
|
29 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
|
30 |
+
|
31 |
+
|
32 |
+
class LossType(Enum):
|
33 |
+
MNRL = auto()
|
34 |
+
CMNRL = auto()
|
35 |
+
MARGIN_MSE = auto()
|
36 |
+
KLDIV = auto()
|
37 |
+
MARGIN_MSE_KLDIV = auto()
|
38 |
+
|
39 |
+
def __str__(self):
|
40 |
+
return self.name.lower()
|
41 |
+
|
42 |
+
|
43 |
+
class MarginMSEKLDivLoss(torch.nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
model: SentenceTransformer,
|
47 |
+
similarity_fct=pairwise_dot_score,
|
48 |
+
temperature=1.0,
|
49 |
+
margin_mse_weight=1.0,
|
50 |
+
kldiv_weight=1.0,
|
51 |
+
) -> None:
|
52 |
+
super().__init__()
|
53 |
+
self.model = model
|
54 |
+
self.similarity_fct = similarity_fct
|
55 |
+
self.temperature = temperature
|
56 |
+
self.margin_mse_weight = margin_mse_weight
|
57 |
+
self.kldiv_weight = kldiv_weight
|
58 |
+
|
59 |
+
self.margin_mse_loss = MarginMSELoss(self.model, similarity_fct=self.similarity_fct)
|
60 |
+
self.kl_div_loss = DistillKLDivLoss(
|
61 |
+
self.model, similarity_fct=self.similarity_fct, temperature=self.temperature
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
|
65 |
+
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
|
66 |
+
|
67 |
+
return self.compute_loss_from_embeddings(embeddings, labels)
|
68 |
+
|
69 |
+
def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
|
70 |
+
return {
|
71 |
+
"margin_mse": self.margin_mse_loss.compute_loss_from_embeddings(embeddings, labels) * self.margin_mse_weight,
|
72 |
+
"kl_div": self.kl_div_loss.compute_loss_from_embeddings(embeddings, labels) * self.kldiv_weight
|
73 |
+
}
|
74 |
+
|
75 |
+
|
76 |
+
def main(
|
77 |
+
model_name_or_path: str,
|
78 |
+
loss_type: LossType,
|
79 |
+
kldiv_temperature: float,
|
80 |
+
margin_mse_weight: float,
|
81 |
+
kldiv_weight: float,
|
82 |
+
mini_batch_size: int,
|
83 |
+
mnrl_scale: float,
|
84 |
+
num_train_epochs: int,
|
85 |
+
per_device_batch_size: int,
|
86 |
+
learning_rate: float,
|
87 |
+
warmup_ratio: float,
|
88 |
+
fp16: bool,
|
89 |
+
bf16: bool,
|
90 |
+
eval_save_steps: int,
|
91 |
+
save_total_limit: int,
|
92 |
+
logging_steps: int,
|
93 |
+
evaluator_batch_size: int,
|
94 |
+
quick: bool,
|
95 |
+
):
|
96 |
+
# 1. Load a model with prompts to finetune with 2. (Optional) model card data
|
97 |
+
model = SentenceTransformer(
|
98 |
+
model_name_or_path,
|
99 |
+
model_card_data=SentenceTransformerModelCardData(
|
100 |
+
language="en",
|
101 |
+
license="apache-2.0",
|
102 |
+
model_name=f"{model_name_or_path} trained on RLHN MS MARCO using {loss_type}",
|
103 |
+
),
|
104 |
+
prompts={ # prompts with "query" and "document" keys are automatically used in evaluation via model.encode_query/model.encode_document
|
105 |
+
"query": "query: ",
|
106 |
+
"document": "document: ",
|
107 |
+
},
|
108 |
+
)
|
109 |
+
|
110 |
+
# 3. Load a dataset to finetune on
|
111 |
+
# TODO: Eventually we want this:
|
112 |
+
"""
|
113 |
+
train_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="train")
|
114 |
+
eval_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="eval")
|
115 |
+
test_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="test")
|
116 |
+
train_dataset = train_dataset.select_columns([column for column in train_dataset.column_names if column != 'logits'])
|
117 |
+
eval_dataset = eval_dataset.select_columns([column for column in eval_dataset.column_names if column != 'logits'])
|
118 |
+
test_dataset = test_dataset.select_columns([column for column in test_dataset.column_names if column != 'logits'])
|
119 |
+
"""
|
120 |
+
# But for now we do it manually:
|
121 |
+
dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs-scored", split="train")
|
122 |
+
dataset = dataset.select_columns([column for column in dataset.column_names if column != "logits"])
|
123 |
+
split_dataset = dataset.train_test_split(test_size=3_000)
|
124 |
+
dataset = split_dataset["train"]
|
125 |
+
eval_dataset = split_dataset["test"]
|
126 |
+
split_dataset = dataset.train_test_split(test_size=10_000)
|
127 |
+
train_dataset = split_dataset["train"]
|
128 |
+
test_dataset = split_dataset["test"]
|
129 |
+
|
130 |
+
# 4. Define a loss function
|
131 |
+
batch_sampler = BatchSamplers.BATCH_SAMPLER
|
132 |
+
gather_across_devices = torch.distributed.is_initialized() if torch.distributed.is_available() else False
|
133 |
+
if loss_type == LossType.MNRL:
|
134 |
+
loss = MultipleNegativesRankingLoss(model, scale=mnrl_scale, gather_across_devices=gather_across_devices)
|
135 |
+
batch_sampler = BatchSamplers.NO_DUPLICATES
|
136 |
+
elif loss_type == LossType.CMNRL:
|
137 |
+
loss = CachedMultipleNegativesRankingLoss(
|
138 |
+
model, scale=mnrl_scale, mini_batch_size=mini_batch_size, gather_across_devices=gather_across_devices
|
139 |
+
)
|
140 |
+
batch_sampler = BatchSamplers.NO_DUPLICATES
|
141 |
+
elif loss_type == LossType.MARGIN_MSE:
|
142 |
+
loss = MarginMSELoss(model)
|
143 |
+
elif loss_type == LossType.KLDIV:
|
144 |
+
loss = DistillKLDivLoss(model, temperature=kldiv_temperature)
|
145 |
+
elif loss_type == LossType.MARGIN_MSE_KLDIV:
|
146 |
+
loss = MarginMSEKLDivLoss(
|
147 |
+
model, temperature=kldiv_temperature, margin_mse_weight=margin_mse_weight, kldiv_weight=kldiv_weight
|
148 |
+
)
|
149 |
+
|
150 |
+
# 5. (Optional) Specify training arguments
|
151 |
+
short_model_name_or_path = model_name_or_path.split("/")[-1]
|
152 |
+
run_name = f"{short_model_name_or_path}-{loss_type}-lr{learning_rate}-bs{per_device_batch_size}"
|
153 |
+
column_names_to_prompts = {
|
154 |
+
column_name: "query" if column_name == "query" else "document"
|
155 |
+
for column_name in dataset.column_names
|
156 |
+
if column_name != "scores"
|
157 |
+
}
|
158 |
+
args = SentenceTransformerTrainingArguments(
|
159 |
+
# Required parameter:
|
160 |
+
output_dir=f"models/{run_name}",
|
161 |
+
# Optional training parameters:
|
162 |
+
num_train_epochs=0.05 if quick else num_train_epochs,
|
163 |
+
per_device_train_batch_size=per_device_batch_size,
|
164 |
+
per_device_eval_batch_size=per_device_batch_size,
|
165 |
+
learning_rate=learning_rate,
|
166 |
+
warmup_ratio=warmup_ratio,
|
167 |
+
fp16=fp16, # Set to False if you get an error that your GPU can't run on FP16
|
168 |
+
bf16=bf16, # Set to True if you have a GPU that supports BF16
|
169 |
+
batch_sampler=batch_sampler, # (C)MNRL benefits from no duplicate samples in a batch
|
170 |
+
prompts=column_names_to_prompts, # Let's incorporate prompts for a ~1% improvement
|
171 |
+
# Optional tracking/debugging parameters:
|
172 |
+
eval_strategy="steps",
|
173 |
+
eval_steps=eval_save_steps,
|
174 |
+
save_strategy="steps",
|
175 |
+
save_steps=eval_save_steps,
|
176 |
+
save_total_limit=save_total_limit,
|
177 |
+
logging_steps=logging_steps,
|
178 |
+
run_name=run_name,
|
179 |
+
)
|
180 |
+
|
181 |
+
# 6. (Optional) Create evaluator & evaluate the base model
|
182 |
+
nano_beir_evaluator = NanoBEIREvaluator(
|
183 |
+
dataset_names=["msmarco", "nfcorpus", "nq"],
|
184 |
+
batch_size=evaluator_batch_size,
|
185 |
+
query_prompts=model.prompts["query"], # This will be done automatically starting from the next version
|
186 |
+
corpus_prompts=model.prompts["document"], # This will be done automatically starting from the next version
|
187 |
+
)
|
188 |
+
eval_queries = {}
|
189 |
+
eval_documents = {}
|
190 |
+
eval_relevant_docs = defaultdict(set)
|
191 |
+
for query, positive in zip(eval_dataset["query"], eval_dataset["positive"]):
|
192 |
+
query_id = len(eval_queries)
|
193 |
+
eval_queries[query_id] = query
|
194 |
+
document_id = len(eval_documents)
|
195 |
+
eval_documents[document_id] = positive
|
196 |
+
eval_relevant_docs[query_id].add(document_id)
|
197 |
+
for column_name in test_dataset.column_names:
|
198 |
+
if column_name.startswith("negative"):
|
199 |
+
for negative in test_dataset[column_name]:
|
200 |
+
document_id = len(eval_documents)
|
201 |
+
eval_documents[document_id] = negative
|
202 |
+
eval_ir_evaluator = InformationRetrievalEvaluator(
|
203 |
+
queries=eval_queries,
|
204 |
+
corpus=eval_documents,
|
205 |
+
relevant_docs=eval_relevant_docs,
|
206 |
+
name="rlhn-msmarco-eval",
|
207 |
+
batch_size=evaluator_batch_size,
|
208 |
+
query_prompt_name="query", # This will be done automatically starting from the next version
|
209 |
+
corpus_prompt_name="document", # This will be done automatically starting from the next version
|
210 |
+
)
|
211 |
+
eval_evaluator = SequentialEvaluator([nano_beir_evaluator, eval_ir_evaluator])
|
212 |
+
if not quick:
|
213 |
+
eval_evaluator(model)
|
214 |
+
|
215 |
+
test_queries = {}
|
216 |
+
test_documents = {}
|
217 |
+
test_relevant_docs = defaultdict(set)
|
218 |
+
for query, positive in zip(test_dataset["query"], test_dataset["positive"]):
|
219 |
+
query_id = len(test_queries)
|
220 |
+
test_queries[query_id] = query
|
221 |
+
document_id = len(test_documents)
|
222 |
+
test_documents[document_id] = positive
|
223 |
+
test_relevant_docs[query_id].add(document_id)
|
224 |
+
for column_name in test_dataset.column_names:
|
225 |
+
if column_name.startswith("negative"):
|
226 |
+
for negative in test_dataset[column_name]:
|
227 |
+
document_id = len(test_documents)
|
228 |
+
test_documents[document_id] = negative
|
229 |
+
test_ir_evaluator = InformationRetrievalEvaluator(
|
230 |
+
queries=test_queries,
|
231 |
+
corpus=test_documents,
|
232 |
+
relevant_docs=test_relevant_docs,
|
233 |
+
name="rlhn-msmarco-test",
|
234 |
+
batch_size=evaluator_batch_size,
|
235 |
+
query_prompt_name="query", # This will be done automatically starting from the next version
|
236 |
+
corpus_prompt_name="document", # This will be done automatically starting from the next version
|
237 |
+
)
|
238 |
+
test_evaluator = SequentialEvaluator([test_ir_evaluator])
|
239 |
+
if not quick:
|
240 |
+
test_evaluator(model)
|
241 |
+
|
242 |
+
# 7. Create a trainer & train
|
243 |
+
trainer = SentenceTransformerTrainer(
|
244 |
+
model=model,
|
245 |
+
args=args,
|
246 |
+
train_dataset=train_dataset,
|
247 |
+
eval_dataset=eval_dataset,
|
248 |
+
loss=loss,
|
249 |
+
evaluator=eval_evaluator,
|
250 |
+
)
|
251 |
+
trainer.train()
|
252 |
+
|
253 |
+
# (Optional) Evaluate the trained model on the eval & test sets again
|
254 |
+
eval_evaluator(model)
|
255 |
+
test_evaluator(model)
|
256 |
+
|
257 |
+
# 8. Save the final model
|
258 |
+
final_output_dir = f"models/{run_name}/final"
|
259 |
+
model.save_pretrained(final_output_dir)
|
260 |
+
|
261 |
+
# 9. (Optional) save the model to the Hugging Face Hub!
|
262 |
+
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
|
263 |
+
try:
|
264 |
+
model.push_to_hub(run_name, private=True)
|
265 |
+
except Exception:
|
266 |
+
logging.error(
|
267 |
+
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
|
268 |
+
f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
|
269 |
+
f"and saving it using `model.push_to_hub('{run_name}')`."
|
270 |
+
)
|
271 |
+
|
272 |
+
|
273 |
+
if __name__ == "__main__":
|
274 |
+
parser = argparse.ArgumentParser(description="Train a sentence transformer model on RLHN MS MARCO dataset")
|
275 |
+
parser.add_argument(
|
276 |
+
"--model_name_or_path", type=str, default="jhu-clsp/ettin-encoder-17m", help="Model name or path to load"
|
277 |
+
)
|
278 |
+
parser.add_argument(
|
279 |
+
"--loss_type",
|
280 |
+
type=lambda x: LossType[x.upper()],
|
281 |
+
default=LossType.CMNRL,
|
282 |
+
choices=list(LossType),
|
283 |
+
help="Loss function to use",
|
284 |
+
)
|
285 |
+
parser.add_argument("--kldiv_temperature", type=float, default=1.0, help="Temperature for KL divergence loss")
|
286 |
+
parser.add_argument("--margin_mse_weight", type=float, default=1.0, help="Weight for margin MSE in combined loss")
|
287 |
+
parser.add_argument("--kldiv_weight", type=float, default=1.0, help="Weight for KL divergence in combined loss")
|
288 |
+
parser.add_argument("--mini_batch_size", type=int, default=16, help="Mini-batch size for cached MNRL")
|
289 |
+
parser.add_argument("--mnrl_scale", type=float, default=20.0, help="Scale factor for MNRL loss")
|
290 |
+
parser.add_argument("--num_train_epochs", type=int, default=1, help="Number of training epochs")
|
291 |
+
parser.add_argument("--per_device_batch_size", type=int, default=128, help="Batch size per device")
|
292 |
+
parser.add_argument("--evaluator_batch_size", type=int, default=32, help="Batch size for the evaluators")
|
293 |
+
parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
|
294 |
+
parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps")
|
295 |
+
parser.add_argument("--fp16", action="store_true", help="Use FP16 precision")
|
296 |
+
parser.add_argument("--bf16", action="store_true", default=True, help="Use BF16 precision")
|
297 |
+
parser.add_argument(
|
298 |
+
"--eval_save_steps",
|
299 |
+
type=float,
|
300 |
+
default=0.2,
|
301 |
+
help="Steps between evaluations and checkpoint saves. If less than 1, "
|
302 |
+
"it will be treated as a fraction of the total steps.",
|
303 |
+
)
|
304 |
+
parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep")
|
305 |
+
parser.add_argument("--logging_steps", type=int, default=100, help="Steps between logging")
|
306 |
+
parser.add_argument("--quick", action="store_true", help="Run with only 5% of training data for quick testing")
|
307 |
+
|
308 |
+
args = parser.parse_args()
|
309 |
+
|
310 |
+
main(
|
311 |
+
model_name_or_path=args.model_name_or_path,
|
312 |
+
loss_type=args.loss_type,
|
313 |
+
kldiv_temperature=args.kldiv_temperature,
|
314 |
+
margin_mse_weight=args.margin_mse_weight,
|
315 |
+
kldiv_weight=args.kldiv_weight,
|
316 |
+
mini_batch_size=args.mini_batch_size,
|
317 |
+
mnrl_scale=args.mnrl_scale,
|
318 |
+
num_train_epochs=args.num_train_epochs,
|
319 |
+
per_device_batch_size=args.per_device_batch_size,
|
320 |
+
learning_rate=args.learning_rate,
|
321 |
+
warmup_ratio=args.warmup_ratio,
|
322 |
+
fp16=args.fp16,
|
323 |
+
bf16=args.bf16,
|
324 |
+
eval_save_steps=args.eval_save_steps,
|
325 |
+
save_total_limit=args.save_total_limit,
|
326 |
+
logging_steps=args.logging_steps,
|
327 |
+
evaluator_batch_size=args.evaluator_batch_size,
|
328 |
+
quick=args.quick,
|
329 |
+
)
|