tomaarsen HF Staff commited on
Commit
d10f6aa
·
verified ·
1 Parent(s): f655c13

Create demo_train.py

Browse files
Files changed (1) hide show
  1. demo_train.py +329 -0
demo_train.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import traceback
4
+ from collections import defaultdict
5
+ from collections.abc import Iterable
6
+ from enum import Enum, auto
7
+
8
+ import torch
9
+ from datasets import load_dataset
10
+ from torch import Tensor
11
+
12
+ from sentence_transformers import (
13
+ SentenceTransformer,
14
+ SentenceTransformerModelCardData,
15
+ SentenceTransformerTrainer,
16
+ SentenceTransformerTrainingArguments,
17
+ )
18
+ from sentence_transformers.evaluation import InformationRetrievalEvaluator, NanoBEIREvaluator, SequentialEvaluator
19
+ from sentence_transformers.losses import (
20
+ CachedMultipleNegativesRankingLoss,
21
+ DistillKLDivLoss,
22
+ MarginMSELoss,
23
+ MultipleNegativesRankingLoss,
24
+ )
25
+ from sentence_transformers.training_args import BatchSamplers
26
+ from sentence_transformers.util import pairwise_dot_score
27
+
28
+ # Set the log level to INFO to get more information
29
+ logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
30
+
31
+
32
+ class LossType(Enum):
33
+ MNRL = auto()
34
+ CMNRL = auto()
35
+ MARGIN_MSE = auto()
36
+ KLDIV = auto()
37
+ MARGIN_MSE_KLDIV = auto()
38
+
39
+ def __str__(self):
40
+ return self.name.lower()
41
+
42
+
43
+ class MarginMSEKLDivLoss(torch.nn.Module):
44
+ def __init__(
45
+ self,
46
+ model: SentenceTransformer,
47
+ similarity_fct=pairwise_dot_score,
48
+ temperature=1.0,
49
+ margin_mse_weight=1.0,
50
+ kldiv_weight=1.0,
51
+ ) -> None:
52
+ super().__init__()
53
+ self.model = model
54
+ self.similarity_fct = similarity_fct
55
+ self.temperature = temperature
56
+ self.margin_mse_weight = margin_mse_weight
57
+ self.kldiv_weight = kldiv_weight
58
+
59
+ self.margin_mse_loss = MarginMSELoss(self.model, similarity_fct=self.similarity_fct)
60
+ self.kl_div_loss = DistillKLDivLoss(
61
+ self.model, similarity_fct=self.similarity_fct, temperature=self.temperature
62
+ )
63
+
64
+ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
65
+ embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
66
+
67
+ return self.compute_loss_from_embeddings(embeddings, labels)
68
+
69
+ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> Tensor:
70
+ return {
71
+ "margin_mse": self.margin_mse_loss.compute_loss_from_embeddings(embeddings, labels) * self.margin_mse_weight,
72
+ "kl_div": self.kl_div_loss.compute_loss_from_embeddings(embeddings, labels) * self.kldiv_weight
73
+ }
74
+
75
+
76
+ def main(
77
+ model_name_or_path: str,
78
+ loss_type: LossType,
79
+ kldiv_temperature: float,
80
+ margin_mse_weight: float,
81
+ kldiv_weight: float,
82
+ mini_batch_size: int,
83
+ mnrl_scale: float,
84
+ num_train_epochs: int,
85
+ per_device_batch_size: int,
86
+ learning_rate: float,
87
+ warmup_ratio: float,
88
+ fp16: bool,
89
+ bf16: bool,
90
+ eval_save_steps: int,
91
+ save_total_limit: int,
92
+ logging_steps: int,
93
+ evaluator_batch_size: int,
94
+ quick: bool,
95
+ ):
96
+ # 1. Load a model with prompts to finetune with 2. (Optional) model card data
97
+ model = SentenceTransformer(
98
+ model_name_or_path,
99
+ model_card_data=SentenceTransformerModelCardData(
100
+ language="en",
101
+ license="apache-2.0",
102
+ model_name=f"{model_name_or_path} trained on RLHN MS MARCO using {loss_type}",
103
+ ),
104
+ prompts={ # prompts with "query" and "document" keys are automatically used in evaluation via model.encode_query/model.encode_document
105
+ "query": "query: ",
106
+ "document": "document: ",
107
+ },
108
+ )
109
+
110
+ # 3. Load a dataset to finetune on
111
+ # TODO: Eventually we want this:
112
+ """
113
+ train_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="train")
114
+ eval_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="eval")
115
+ test_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="test")
116
+ train_dataset = train_dataset.select_columns([column for column in train_dataset.column_names if column != 'logits'])
117
+ eval_dataset = eval_dataset.select_columns([column for column in eval_dataset.column_names if column != 'logits'])
118
+ test_dataset = test_dataset.select_columns([column for column in test_dataset.column_names if column != 'logits'])
119
+ """
120
+ # But for now we do it manually:
121
+ dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs-scored", split="train")
122
+ dataset = dataset.select_columns([column for column in dataset.column_names if column != "logits"])
123
+ split_dataset = dataset.train_test_split(test_size=3_000)
124
+ dataset = split_dataset["train"]
125
+ eval_dataset = split_dataset["test"]
126
+ split_dataset = dataset.train_test_split(test_size=10_000)
127
+ train_dataset = split_dataset["train"]
128
+ test_dataset = split_dataset["test"]
129
+
130
+ # 4. Define a loss function
131
+ batch_sampler = BatchSamplers.BATCH_SAMPLER
132
+ gather_across_devices = torch.distributed.is_initialized() if torch.distributed.is_available() else False
133
+ if loss_type == LossType.MNRL:
134
+ loss = MultipleNegativesRankingLoss(model, scale=mnrl_scale, gather_across_devices=gather_across_devices)
135
+ batch_sampler = BatchSamplers.NO_DUPLICATES
136
+ elif loss_type == LossType.CMNRL:
137
+ loss = CachedMultipleNegativesRankingLoss(
138
+ model, scale=mnrl_scale, mini_batch_size=mini_batch_size, gather_across_devices=gather_across_devices
139
+ )
140
+ batch_sampler = BatchSamplers.NO_DUPLICATES
141
+ elif loss_type == LossType.MARGIN_MSE:
142
+ loss = MarginMSELoss(model)
143
+ elif loss_type == LossType.KLDIV:
144
+ loss = DistillKLDivLoss(model, temperature=kldiv_temperature)
145
+ elif loss_type == LossType.MARGIN_MSE_KLDIV:
146
+ loss = MarginMSEKLDivLoss(
147
+ model, temperature=kldiv_temperature, margin_mse_weight=margin_mse_weight, kldiv_weight=kldiv_weight
148
+ )
149
+
150
+ # 5. (Optional) Specify training arguments
151
+ short_model_name_or_path = model_name_or_path.split("/")[-1]
152
+ run_name = f"{short_model_name_or_path}-{loss_type}-lr{learning_rate}-bs{per_device_batch_size}"
153
+ column_names_to_prompts = {
154
+ column_name: "query" if column_name == "query" else "document"
155
+ for column_name in dataset.column_names
156
+ if column_name != "scores"
157
+ }
158
+ args = SentenceTransformerTrainingArguments(
159
+ # Required parameter:
160
+ output_dir=f"models/{run_name}",
161
+ # Optional training parameters:
162
+ num_train_epochs=0.05 if quick else num_train_epochs,
163
+ per_device_train_batch_size=per_device_batch_size,
164
+ per_device_eval_batch_size=per_device_batch_size,
165
+ learning_rate=learning_rate,
166
+ warmup_ratio=warmup_ratio,
167
+ fp16=fp16, # Set to False if you get an error that your GPU can't run on FP16
168
+ bf16=bf16, # Set to True if you have a GPU that supports BF16
169
+ batch_sampler=batch_sampler, # (C)MNRL benefits from no duplicate samples in a batch
170
+ prompts=column_names_to_prompts, # Let's incorporate prompts for a ~1% improvement
171
+ # Optional tracking/debugging parameters:
172
+ eval_strategy="steps",
173
+ eval_steps=eval_save_steps,
174
+ save_strategy="steps",
175
+ save_steps=eval_save_steps,
176
+ save_total_limit=save_total_limit,
177
+ logging_steps=logging_steps,
178
+ run_name=run_name,
179
+ )
180
+
181
+ # 6. (Optional) Create evaluator & evaluate the base model
182
+ nano_beir_evaluator = NanoBEIREvaluator(
183
+ dataset_names=["msmarco", "nfcorpus", "nq"],
184
+ batch_size=evaluator_batch_size,
185
+ query_prompts=model.prompts["query"], # This will be done automatically starting from the next version
186
+ corpus_prompts=model.prompts["document"], # This will be done automatically starting from the next version
187
+ )
188
+ eval_queries = {}
189
+ eval_documents = {}
190
+ eval_relevant_docs = defaultdict(set)
191
+ for query, positive in zip(eval_dataset["query"], eval_dataset["positive"]):
192
+ query_id = len(eval_queries)
193
+ eval_queries[query_id] = query
194
+ document_id = len(eval_documents)
195
+ eval_documents[document_id] = positive
196
+ eval_relevant_docs[query_id].add(document_id)
197
+ for column_name in test_dataset.column_names:
198
+ if column_name.startswith("negative"):
199
+ for negative in test_dataset[column_name]:
200
+ document_id = len(eval_documents)
201
+ eval_documents[document_id] = negative
202
+ eval_ir_evaluator = InformationRetrievalEvaluator(
203
+ queries=eval_queries,
204
+ corpus=eval_documents,
205
+ relevant_docs=eval_relevant_docs,
206
+ name="rlhn-msmarco-eval",
207
+ batch_size=evaluator_batch_size,
208
+ query_prompt_name="query", # This will be done automatically starting from the next version
209
+ corpus_prompt_name="document", # This will be done automatically starting from the next version
210
+ )
211
+ eval_evaluator = SequentialEvaluator([nano_beir_evaluator, eval_ir_evaluator])
212
+ if not quick:
213
+ eval_evaluator(model)
214
+
215
+ test_queries = {}
216
+ test_documents = {}
217
+ test_relevant_docs = defaultdict(set)
218
+ for query, positive in zip(test_dataset["query"], test_dataset["positive"]):
219
+ query_id = len(test_queries)
220
+ test_queries[query_id] = query
221
+ document_id = len(test_documents)
222
+ test_documents[document_id] = positive
223
+ test_relevant_docs[query_id].add(document_id)
224
+ for column_name in test_dataset.column_names:
225
+ if column_name.startswith("negative"):
226
+ for negative in test_dataset[column_name]:
227
+ document_id = len(test_documents)
228
+ test_documents[document_id] = negative
229
+ test_ir_evaluator = InformationRetrievalEvaluator(
230
+ queries=test_queries,
231
+ corpus=test_documents,
232
+ relevant_docs=test_relevant_docs,
233
+ name="rlhn-msmarco-test",
234
+ batch_size=evaluator_batch_size,
235
+ query_prompt_name="query", # This will be done automatically starting from the next version
236
+ corpus_prompt_name="document", # This will be done automatically starting from the next version
237
+ )
238
+ test_evaluator = SequentialEvaluator([test_ir_evaluator])
239
+ if not quick:
240
+ test_evaluator(model)
241
+
242
+ # 7. Create a trainer & train
243
+ trainer = SentenceTransformerTrainer(
244
+ model=model,
245
+ args=args,
246
+ train_dataset=train_dataset,
247
+ eval_dataset=eval_dataset,
248
+ loss=loss,
249
+ evaluator=eval_evaluator,
250
+ )
251
+ trainer.train()
252
+
253
+ # (Optional) Evaluate the trained model on the eval & test sets again
254
+ eval_evaluator(model)
255
+ test_evaluator(model)
256
+
257
+ # 8. Save the final model
258
+ final_output_dir = f"models/{run_name}/final"
259
+ model.save_pretrained(final_output_dir)
260
+
261
+ # 9. (Optional) save the model to the Hugging Face Hub!
262
+ # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
263
+ try:
264
+ model.push_to_hub(run_name, private=True)
265
+ except Exception:
266
+ logging.error(
267
+ f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
268
+ f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
269
+ f"and saving it using `model.push_to_hub('{run_name}')`."
270
+ )
271
+
272
+
273
+ if __name__ == "__main__":
274
+ parser = argparse.ArgumentParser(description="Train a sentence transformer model on RLHN MS MARCO dataset")
275
+ parser.add_argument(
276
+ "--model_name_or_path", type=str, default="jhu-clsp/ettin-encoder-17m", help="Model name or path to load"
277
+ )
278
+ parser.add_argument(
279
+ "--loss_type",
280
+ type=lambda x: LossType[x.upper()],
281
+ default=LossType.CMNRL,
282
+ choices=list(LossType),
283
+ help="Loss function to use",
284
+ )
285
+ parser.add_argument("--kldiv_temperature", type=float, default=1.0, help="Temperature for KL divergence loss")
286
+ parser.add_argument("--margin_mse_weight", type=float, default=1.0, help="Weight for margin MSE in combined loss")
287
+ parser.add_argument("--kldiv_weight", type=float, default=1.0, help="Weight for KL divergence in combined loss")
288
+ parser.add_argument("--mini_batch_size", type=int, default=16, help="Mini-batch size for cached MNRL")
289
+ parser.add_argument("--mnrl_scale", type=float, default=20.0, help="Scale factor for MNRL loss")
290
+ parser.add_argument("--num_train_epochs", type=int, default=1, help="Number of training epochs")
291
+ parser.add_argument("--per_device_batch_size", type=int, default=128, help="Batch size per device")
292
+ parser.add_argument("--evaluator_batch_size", type=int, default=32, help="Batch size for the evaluators")
293
+ parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
294
+ parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps")
295
+ parser.add_argument("--fp16", action="store_true", help="Use FP16 precision")
296
+ parser.add_argument("--bf16", action="store_true", default=True, help="Use BF16 precision")
297
+ parser.add_argument(
298
+ "--eval_save_steps",
299
+ type=float,
300
+ default=0.2,
301
+ help="Steps between evaluations and checkpoint saves. If less than 1, "
302
+ "it will be treated as a fraction of the total steps.",
303
+ )
304
+ parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep")
305
+ parser.add_argument("--logging_steps", type=int, default=100, help="Steps between logging")
306
+ parser.add_argument("--quick", action="store_true", help="Run with only 5% of training data for quick testing")
307
+
308
+ args = parser.parse_args()
309
+
310
+ main(
311
+ model_name_or_path=args.model_name_or_path,
312
+ loss_type=args.loss_type,
313
+ kldiv_temperature=args.kldiv_temperature,
314
+ margin_mse_weight=args.margin_mse_weight,
315
+ kldiv_weight=args.kldiv_weight,
316
+ mini_batch_size=args.mini_batch_size,
317
+ mnrl_scale=args.mnrl_scale,
318
+ num_train_epochs=args.num_train_epochs,
319
+ per_device_batch_size=args.per_device_batch_size,
320
+ learning_rate=args.learning_rate,
321
+ warmup_ratio=args.warmup_ratio,
322
+ fp16=args.fp16,
323
+ bf16=args.bf16,
324
+ eval_save_steps=args.eval_save_steps,
325
+ save_total_limit=args.save_total_limit,
326
+ logging_steps=args.logging_steps,
327
+ evaluator_batch_size=args.evaluator_batch_size,
328
+ quick=args.quick,
329
+ )