import argparse
import logging
import traceback
from collections import defaultdict
from collections.abc import Iterable
from enum import Enum, auto

import torch
from datasets import load_dataset
from torch import Tensor
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerModelCardData,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator, NanoBEIREvaluator, SequentialEvaluator
from sentence_transformers.losses import (
CachedMultipleNegativesRankingLoss,
DistillKLDivLoss,
MarginMSELoss,
MultipleNegativesRankingLoss,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import pairwise_dot_score

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


class LossType(Enum):
MNRL = auto()
CMNRL = auto()
MARGIN_MSE = auto()
KLDIV = auto()
MARGIN_MSE_KLDIV = auto()

    def __str__(self):
return self.name.lower()


class MarginMSEKLDivLoss(torch.nn.Module):
def __init__(
self,
model: SentenceTransformer,
similarity_fct=pairwise_dot_score,
temperature=1.0,
margin_mse_weight=1.0,
kldiv_weight=1.0,
) -> None:
super().__init__()
self.model = model
self.similarity_fct = similarity_fct
self.temperature = temperature
self.margin_mse_weight = margin_mse_weight
self.kldiv_weight = kldiv_weight
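        # Combine two distillation objectives: MarginMSE matches the teacher's
        # positive-negative score margins, while KL-divergence matches the
        # teacher's score distribution over the candidates.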
self.margin_mse_loss = MarginMSELoss(self.model, similarity_fct=self.similarity_fct)
self.kl_div_loss = DistillKLDivLoss(
self.model, similarity_fct=self.similarity_fct, temperature=self.temperature
)

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> dict[str, Tensor]:
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
return self.compute_loss_from_embeddings(embeddings, labels)

    def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) -> dict[str, Tensor]:
        # Return the two weighted loss components under separate names.
        return {
            "margin_mse": self.margin_mse_loss.compute_loss_from_embeddings(embeddings, labels)
            * self.margin_mse_weight,
            "kl_div": self.kl_div_loss.compute_loss_from_embeddings(embeddings, labels) * self.kldiv_weight,
        }
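

# Note: recent sentence-transformers releases support losses that return a dict of
# named tensors: the trainer sums the components for backpropagation and logs each
# entry (here "margin_mse" and "kl_div") separately. On older versions, returning
# the weighted sum as a single tensor is an equivalent fallback.
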
def main(
model_name_or_path: str,
loss_type: LossType,
kldiv_temperature: float,
margin_mse_weight: float,
kldiv_weight: float,
mini_batch_size: int,
mnrl_scale: float,
num_train_epochs: int,
per_device_batch_size: int,
learning_rate: float,
warmup_ratio: float,
fp16: bool,
bf16: bool,
eval_save_steps: int,
save_total_limit: int,
logging_steps: int,
evaluator_batch_size: int,
quick: bool,
):
    # 1. Load a model to finetune with prompts
    # 2. (Optional) Add model card data
model = SentenceTransformer(
model_name_or_path,
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name=f"{model_name_or_path} trained on RLHN MS MARCO using {loss_type}",
),
prompts={ # prompts with "query" and "document" keys are automatically used in evaluation via model.encode_query/model.encode_document
"query": "query: ",
"document": "document: ",
},
)
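    # With these prompts, model.encode_query("...") embeds "query: ..." and
    # model.encode_document("...") embeds "document: ...".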

    # 3. Load a dataset to finetune on
# TODO: Eventually we want this:
"""
train_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="train")
eval_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="eval")
test_dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs", split="test")
train_dataset = train_dataset.select_columns([column for column in train_dataset.column_names if column != 'logits'])
eval_dataset = eval_dataset.select_columns([column for column in eval_dataset.column_names if column != 'logits'])
test_dataset = test_dataset.select_columns([column for column in test_dataset.column_names if column != 'logits'])
"""
# But for now we do it manually:
dataset = load_dataset("mixedbread-ai/rlhn-680k-msmarco-7negs-scored", split="train")
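    # Drop the teacher "logits" column; the remaining precomputed "scores" column
    # is expected to provide the labels for the distillation losses below.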
dataset = dataset.select_columns([column for column in dataset.column_names if column != "logits"])
split_dataset = dataset.train_test_split(test_size=3_000)
dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
split_dataset = dataset.train_test_split(test_size=10_000)
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]
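    # Resulting splits: 3,000 eval examples, 10,000 test examples, and the rest
    # (roughly 667k of the ~680k total) for training.
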
# 4. Define a loss function
batch_sampler = BatchSamplers.BATCH_SAMPLER
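    # In distributed runs, gather in-batch negatives across all devices so the
    # (C)MNRL losses can rank against a larger pool of negatives.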
    gather_across_devices = torch.distributed.is_available() and torch.distributed.is_initialized()
if loss_type == LossType.MNRL:
loss = MultipleNegativesRankingLoss(model, scale=mnrl_scale, gather_across_devices=gather_across_devices)
batch_sampler = BatchSamplers.NO_DUPLICATES
elif loss_type == LossType.CMNRL:
loss = CachedMultipleNegativesRankingLoss(
model, scale=mnrl_scale, mini_batch_size=mini_batch_size, gather_across_devices=gather_across_devices
)
batch_sampler = BatchSamplers.NO_DUPLICATES
elif loss_type == LossType.MARGIN_MSE:
loss = MarginMSELoss(model)
elif loss_type == LossType.KLDIV:
loss = DistillKLDivLoss(model, temperature=kldiv_temperature)
elif loss_type == LossType.MARGIN_MSE_KLDIV:
loss = MarginMSEKLDivLoss(
model, temperature=kldiv_temperature, margin_mse_weight=margin_mse_weight, kldiv_weight=kldiv_weight
)
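    else:
        # Fail fast if a LossType member is ever added without a matching branch.
        raise ValueError(f"Unsupported loss type: {loss_type!r}")
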
# 5. (Optional) Specify training arguments
short_model_name_or_path = model_name_or_path.split("/")[-1]
run_name = f"{short_model_name_or_path}-{loss_type}-lr{learning_rate}-bs{per_device_batch_size}"
column_names_to_prompts = {
column_name: "query" if column_name == "query" else "document"
for column_name in dataset.column_names
if column_name != "scores"
}
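    # e.g. {"query": "query", "positive": "document", "negative_1": "document", ...},
    # assuming text columns "query", "positive", and "negative_*"; "scores" is
    # excluded because it holds labels rather than text.
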
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=0.05 if quick else num_train_epochs,
per_device_train_batch_size=per_device_batch_size,
per_device_eval_batch_size=per_device_batch_size,
learning_rate=learning_rate,
warmup_ratio=warmup_ratio,
fp16=fp16, # Set to False if you get an error that your GPU can't run on FP16
bf16=bf16, # Set to True if you have a GPU that supports BF16
batch_sampler=batch_sampler, # (C)MNRL benefits from no duplicate samples in a batch
prompts=column_names_to_prompts, # Let's incorporate prompts for a ~1% improvement
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=eval_save_steps,
save_strategy="steps",
save_steps=eval_save_steps,
save_total_limit=save_total_limit,
logging_steps=logging_steps,
run_name=run_name,
)

    # 6. (Optional) Create evaluator & evaluate the base model
nano_beir_evaluator = NanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"],
batch_size=evaluator_batch_size,
query_prompts=model.prompts["query"], # This will be done automatically starting from the next version
corpus_prompts=model.prompts["document"], # This will be done automatically starting from the next version
)
eval_queries = {}
eval_documents = {}
eval_relevant_docs = defaultdict(set)
for query, positive in zip(eval_dataset["query"], eval_dataset["positive"]):
query_id = len(eval_queries)
eval_queries[query_id] = query
document_id = len(eval_documents)
eval_documents[document_id] = positive
eval_relevant_docs[query_id].add(document_id)
    for column_name in eval_dataset.column_names:
        if column_name.startswith("negative"):
            for negative in eval_dataset[column_name]:
                document_id = len(eval_documents)
                eval_documents[document_id] = negative
eval_ir_evaluator = InformationRetrievalEvaluator(
queries=eval_queries,
corpus=eval_documents,
relevant_docs=eval_relevant_docs,
name="rlhn-msmarco-eval",
batch_size=evaluator_batch_size,
query_prompt_name="query", # This will be done automatically starting from the next version
corpus_prompt_name="document", # This will be done automatically starting from the next version
)
eval_evaluator = SequentialEvaluator([nano_beir_evaluator, eval_ir_evaluator])
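    # SequentialEvaluator runs each evaluator in turn; by default the last
    # evaluator's main score is reported as the combined main score.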
if not quick:
eval_evaluator(model)

    test_queries = {}
test_documents = {}
test_relevant_docs = defaultdict(set)
for query, positive in zip(test_dataset["query"], test_dataset["positive"]):
query_id = len(test_queries)
test_queries[query_id] = query
document_id = len(test_documents)
test_documents[document_id] = positive
test_relevant_docs[query_id].add(document_id)
for column_name in test_dataset.column_names:
if column_name.startswith("negative"):
for negative in test_dataset[column_name]:
document_id = len(test_documents)
test_documents[document_id] = negative
test_ir_evaluator = InformationRetrievalEvaluator(
queries=test_queries,
corpus=test_documents,
relevant_docs=test_relevant_docs,
name="rlhn-msmarco-test",
batch_size=evaluator_batch_size,
query_prompt_name="query", # This will be done automatically starting from the next version
corpus_prompt_name="document", # This will be done automatically starting from the next version
)
test_evaluator = SequentialEvaluator([test_ir_evaluator])
if not quick:
test_evaluator(model)

    # 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=eval_evaluator,
)
trainer.train()

    # (Optional) Evaluate the trained model on the eval & test sets again
eval_evaluator(model)
test_evaluator(model)

    # 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

    # 9. (Optional) Save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
model.push_to_hub(run_name, private=True)
except Exception:
logging.error(
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
f"and saving it using `model.push_to_hub('{run_name}')`."
)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Sentence Transformers model on the RLHN MS MARCO dataset")
parser.add_argument(
"--model_name_or_path", type=str, default="jhu-clsp/ettin-encoder-17m", help="Model name or path to load"
)
parser.add_argument(
"--loss_type",
type=lambda x: LossType[x.upper()],
default=LossType.CMNRL,
choices=list(LossType),
help="Loss function to use",
)
parser.add_argument("--kldiv_temperature", type=float, default=1.0, help="Temperature for KL divergence loss")
parser.add_argument("--margin_mse_weight", type=float, default=1.0, help="Weight for margin MSE in combined loss")
parser.add_argument("--kldiv_weight", type=float, default=1.0, help="Weight for KL divergence in combined loss")
parser.add_argument("--mini_batch_size", type=int, default=16, help="Mini-batch size for cached MNRL")
parser.add_argument("--mnrl_scale", type=float, default=20.0, help="Scale factor for MNRL loss")
parser.add_argument("--num_train_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--per_device_batch_size", type=int, default=128, help="Batch size per device")
parser.add_argument("--evaluator_batch_size", type=int, default=32, help="Batch size for the evaluators")
parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps")
parser.add_argument("--fp16", action="store_true", help="Use FP16 precision")
parser.add_argument("--bf16", action="store_true", default=True, help="Use BF16 precision")
parser.add_argument(
"--eval_save_steps",
type=float,
default=0.2,
help="Steps between evaluations and checkpoint saves. If less than 1, "
"it will be treated as a fraction of the total steps.",
)
parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep")
parser.add_argument("--logging_steps", type=int, default=100, help="Steps between logging")
parser.add_argument("--quick", action="store_true", help="Run with only 5% of training data for quick testing")
args = parser.parse_args()
main(
model_name_or_path=args.model_name_or_path,
loss_type=args.loss_type,
kldiv_temperature=args.kldiv_temperature,
margin_mse_weight=args.margin_mse_weight,
kldiv_weight=args.kldiv_weight,
mini_batch_size=args.mini_batch_size,
mnrl_scale=args.mnrl_scale,
num_train_epochs=args.num_train_epochs,
per_device_batch_size=args.per_device_batch_size,
learning_rate=args.learning_rate,
warmup_ratio=args.warmup_ratio,
fp16=args.fp16,
bf16=args.bf16,
eval_save_steps=args.eval_save_steps,
save_total_limit=args.save_total_limit,
logging_steps=args.logging_steps,
evaluator_batch_size=args.evaluator_batch_size,
quick=args.quick,
)
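
# Example invocations (flag values are illustrative):
#   python demo_train.py --quick
#   python demo_train.py --loss_type margin_mse_kldiv --kldiv_temperature 2.0 --per_device_batch_size 64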