import argparse
import logging
import traceback
from collections import defaultdict
from collections.abc import Iterable
from enum import Enum, auto

import torch
from datasets import load_dataset
from torch import Tensor

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator, NanoBEIREvaluator, SequentialEvaluator
from sentence_transformers.losses import (
    CachedMultipleNegativesRankingLoss,
    DistillKLDivLoss,
    MarginMSELoss,
    MultipleNegativesRankingLoss,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import pairwise_dot_score

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


class LossType(Enum):
    MNRL = auto()
    CMNRL = auto()
    MARGIN_MSE = auto()
    KLDIV = auto()
    MARGIN_MSE_KLDIV = auto()

    def __str__(self):
        return self.name.lower()
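
# str(loss_type) is used in the run name below and in argparse's choices
# display, hence the lowercase rendering above.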


class MarginMSEKLDivLoss(torch.nn.Module):
    """Weighted combination of MarginMSELoss and DistillKLDivLoss for distillation."""

    def __init__(
        self,
        model: SentenceTransformer,
        similarity_fct=pairwise_dot_score,
        temperature=1.0,
        margin_mse_weight=1.0,
        kldiv_weight=1.0,
    ) -> None:
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.temperature = temperature
        self.margin_mse_weight = margin_mse_weight
        self.kldiv_weight = kldiv_weight

        self.margin_mse_loss = MarginMSELoss(self.model, similarity_fct=self.similarity_fct)
        self.kl_div_loss = DistillKLDivLoss(
            self.model, similarity_fct=self.similarity_fct, temperature=self.temperature
        )

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> dict[str, Tensor]:
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        return self.compute_loss_from_embeddings(embeddings, labels)

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> dict[str, Tensor]:
        return {
            "margin_mse": self.margin_mse_loss.compute_loss_from_embeddings(embeddings, labels)
            * self.margin_mse_weight,
            "kl_div": self.kl_div_loss.compute_loss_from_embeddings(embeddings, labels) * self.kldiv_weight,
        }
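
# Returning a dict of named loss components assumes a recent sentence-transformers
# release, where the trainer sums the values for backpropagation and logs each
# component separately. On versions expecting a single Tensor, sum the two
# weighted terms in compute_loss_from_embeddings instead.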


def main(
    model_name_or_path: str,
    loss_type: LossType,
    kldiv_temperature: float,
    margin_mse_weight: float,
    kldiv_weight: float,
    mini_batch_size: int,
    mnrl_scale: float,
    num_train_epochs: int,
    per_device_batch_size: int,
    learning_rate: float,
    warmup_ratio: float,
    fp16: bool,
    bf16: bool,
    eval_save_steps: float,
    save_total_limit: int,
    logging_steps: int,
    evaluator_batch_size: int,
    quick: bool,
):
    # 1. Load the model to finetune, with prompts for queries and documents
    model = SentenceTransformer(
        model_name_or_path,
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"{model_name_or_path} trained on RLHN MS MARCO using {loss_type}",
        ),
        prompts={
            "query": "query: ",
            "document": "document: ",
        },
    )
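
    # The "query: " / "document: " prefixes follow the prompt convention used by
    # many retrieval models; the evaluators below reference them by prompt name.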

    # 2. Load the RLHN MS MARCO dataset. The commented-out block below is the
    # alternative path using the unscored dataset with ready-made splits.
    """
    train_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="train")
    eval_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="eval")
    test_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="test")
    train_dataset = train_dataset.select_columns([column for column in train_dataset.column_names if column != "logits"])
    eval_dataset = eval_dataset.select_columns([column for column in eval_dataset.column_names if column != "logits"])
    test_dataset = test_dataset.select_columns([column for column in test_dataset.column_names if column != "logits"])
    """

    dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs-scored", split="train")
    dataset = dataset.select_columns([column for column in dataset.column_names if column != "logits"])
    # Hold out 3,000 samples for evaluation during training, then 10,000 of the
    # remainder for the final test set
    split_dataset = dataset.train_test_split(test_size=3_000)
    dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]
    split_dataset = dataset.train_test_split(test_size=10_000)
    train_dataset = split_dataset["train"]
    test_dataset = split_dataset["test"]
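
    # Each sample is assumed to carry "query", "positive", "negative_1" ...
    # "negative_7" text columns plus a "scores" column with the teacher
    # (cross-encoder) scores that the distillation losses use as labels.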

    # 3. Define the loss function based on the selected loss type
    batch_sampler = BatchSamplers.BATCH_SAMPLER
    gather_across_devices = torch.distributed.is_initialized() if torch.distributed.is_available() else False
    if loss_type == LossType.MNRL:
        loss = MultipleNegativesRankingLoss(model, scale=mnrl_scale, gather_across_devices=gather_across_devices)
        batch_sampler = BatchSamplers.NO_DUPLICATES
    elif loss_type == LossType.CMNRL:
        loss = CachedMultipleNegativesRankingLoss(
            model, scale=mnrl_scale, mini_batch_size=mini_batch_size, gather_across_devices=gather_across_devices
        )
        batch_sampler = BatchSamplers.NO_DUPLICATES
    elif loss_type == LossType.MARGIN_MSE:
        loss = MarginMSELoss(model)
    elif loss_type == LossType.KLDIV:
        loss = DistillKLDivLoss(model, temperature=kldiv_temperature)
    elif loss_type == LossType.MARGIN_MSE_KLDIV:
        loss = MarginMSEKLDivLoss(
            model, temperature=kldiv_temperature, margin_mse_weight=margin_mse_weight, kldiv_weight=kldiv_weight
        )
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")
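
    # The (C)MNRL losses treat the other passages in a batch as negatives, so the
    # NO_DUPLICATES sampler keeps duplicate texts out of a batch; the distillation
    # losses only score each sample's own columns, so the default sampler suffices.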

    short_model_name_or_path = model_name_or_path.split("/")[-1]
    run_name = f"{short_model_name_or_path}-{loss_type}-lr{learning_rate}-bs{per_device_batch_size}"
    # Map every text column to the matching prompt; "scores" holds the labels
    column_names_to_prompts = {
        column_name: "query" if column_name == "query" else "document"
        for column_name in train_dataset.column_names
        if column_name != "scores"
    }

    # 4. Define the training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=0.05 if quick else num_train_epochs,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        learning_rate=learning_rate,
        warmup_ratio=warmup_ratio,
        fp16=fp16,
        bf16=bf16,
        batch_sampler=batch_sampler,
        prompts=column_names_to_prompts,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=eval_save_steps,
        save_strategy="steps",
        save_steps=eval_save_steps,
        save_total_limit=save_total_limit,
        logging_steps=logging_steps,
        run_name=run_name,
    )

    # 5. Create evaluators: NanoBEIR plus an IR evaluator over the held-out eval split
    nano_beir_evaluator = NanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=evaluator_batch_size,
        query_prompts=model.prompts["query"],
        corpus_prompts=model.prompts["document"],
    )
    eval_queries = {}
    eval_documents = {}
    eval_relevant_docs = defaultdict(set)
    for query, positive in zip(eval_dataset["query"], eval_dataset["positive"]):
        query_id = len(eval_queries)
        eval_queries[query_id] = query
        document_id = len(eval_documents)
        eval_documents[document_id] = positive
        eval_relevant_docs[query_id].add(document_id)
    # Add the negatives from the eval split as distractors in the corpus
    for column_name in eval_dataset.column_names:
        if column_name.startswith("negative"):
            for negative in eval_dataset[column_name]:
                document_id = len(eval_documents)
                eval_documents[document_id] = negative
    eval_ir_evaluator = InformationRetrievalEvaluator(
        queries=eval_queries,
        corpus=eval_documents,
        relevant_docs=eval_relevant_docs,
        name="rlhn-msmarco-eval",
        batch_size=evaluator_batch_size,
        query_prompt_name="query",
        corpus_prompt_name="document",
    )
    eval_evaluator = SequentialEvaluator([nano_beir_evaluator, eval_ir_evaluator])
    # Establish baseline metrics before training (skipped in quick runs)
    if not quick:
        eval_evaluator(model)

    # Build a matching IR evaluator over the held-out test split
    test_queries = {}
    test_documents = {}
    test_relevant_docs = defaultdict(set)
    for query, positive in zip(test_dataset["query"], test_dataset["positive"]):
        query_id = len(test_queries)
        test_queries[query_id] = query
        document_id = len(test_documents)
        test_documents[document_id] = positive
        test_relevant_docs[query_id].add(document_id)
    for column_name in test_dataset.column_names:
        if column_name.startswith("negative"):
            for negative in test_dataset[column_name]:
                document_id = len(test_documents)
                test_documents[document_id] = negative
    test_ir_evaluator = InformationRetrievalEvaluator(
        queries=test_queries,
        corpus=test_documents,
        relevant_docs=test_relevant_docs,
        name="rlhn-msmarco-test",
        batch_size=evaluator_batch_size,
        query_prompt_name="query",
        corpus_prompt_name="document",
    )
    test_evaluator = SequentialEvaluator([test_ir_evaluator])
    if not quick:
        test_evaluator(model)

    # 6. Create the trainer and train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=eval_evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model on the eval and test sets
    eval_evaluator(model)
    test_evaluator(model)

    # 8. Save the trained model locally
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) Push the trained model to the Hugging Face Hub
    try:
        model.push_to_hub(run_name, private=True)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a sentence transformer model on the RLHN MS MARCO dataset")
    parser.add_argument(
        "--model_name_or_path", type=str, default="jhu-clsp/ettin-encoder-17m", help="Model name or path to load"
    )
    parser.add_argument(
        "--loss_type",
        type=lambda x: LossType[x.upper()],
        default=LossType.CMNRL,
        choices=list(LossType),
        help="Loss function to use",
    )
    parser.add_argument("--kldiv_temperature", type=float, default=1.0, help="Temperature for KL divergence loss")
    parser.add_argument("--margin_mse_weight", type=float, default=1.0, help="Weight for margin MSE in combined loss")
    parser.add_argument("--kldiv_weight", type=float, default=1.0, help="Weight for KL divergence in combined loss")
    parser.add_argument("--mini_batch_size", type=int, default=16, help="Mini-batch size for cached MNRL")
    parser.add_argument("--mnrl_scale", type=float, default=20.0, help="Scale factor for MNRL loss")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--per_device_batch_size", type=int, default=128, help="Batch size per device")
    parser.add_argument("--evaluator_batch_size", type=int, default=32, help="Batch size for the evaluators")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps")
    parser.add_argument("--fp16", action="store_true", help="Use FP16 precision")
    parser.add_argument(
        "--bf16",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use BF16 precision (enabled by default; disable with --no-bf16)",
    )
    parser.add_argument(
        "--eval_save_steps",
        type=float,
        default=0.2,
        help="Steps between evaluations and checkpoint saves. If less than 1, "
        "it will be treated as a fraction of the total steps.",
    )
    parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep")
    parser.add_argument("--logging_steps", type=int, default=100, help="Steps between logging")
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Quick test run: train for only 0.05 epochs and skip the pre-training evaluations",
    )

    args = parser.parse_args()

    main(
        model_name_or_path=args.model_name_or_path,
        loss_type=args.loss_type,
        kldiv_temperature=args.kldiv_temperature,
        margin_mse_weight=args.margin_mse_weight,
        kldiv_weight=args.kldiv_weight,
        mini_batch_size=args.mini_batch_size,
        mnrl_scale=args.mnrl_scale,
        num_train_epochs=args.num_train_epochs,
        per_device_batch_size=args.per_device_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        fp16=args.fp16,
        bf16=args.bf16,
        eval_save_steps=args.eval_save_steps,
        save_total_limit=args.save_total_limit,
        logging_steps=args.logging_steps,
        evaluator_batch_size=args.evaluator_batch_size,
        quick=args.quick,
    )
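
# Example invocation (the script filename is illustrative; adjust to your file):
#   python train_st_rlhn_msmarco.py --loss_type margin_mse_kldiv \
#       --per_device_batch_size 64 --learning_rate 2e-5 --quick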