tomaarsen HF Staff commited on
Commit
5bbf5da
·
verified ·
1 Parent(s): b867997

Create train_script.py

Browse files
Files changed (1) hide show
  1. train_script.py +124 -0
train_script.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import traceback
3
+
4
+ from datasets import load_dataset
5
+ from sentence_transformers import (
6
+ SentenceTransformer,
7
+ SentenceTransformerModelCardData,
8
+ SentenceTransformerTrainer,
9
+ SentenceTransformerTrainingArguments,
10
+ )
11
+ from sentence_transformers.evaluation import InformationRetrievalEvaluator
12
+ from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
13
+ from sentence_transformers.training_args import BatchSamplers
14
+
15
+ # Set the log level to INFO to get more information
16
+ logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
17
+
18
+ # 1. Load a model to finetune with 2. (Optional) model card data
19
+ model = SentenceTransformer(
20
+ "google/embeddinggemma-300M",
21
+ model_card_data=SentenceTransformerModelCardData(
22
+ language="en",
23
+ license="apache-2.0",
24
+ model_name="EmbeddingGemma-300M trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)",
25
+ ),
26
+ )
27
+
28
+ # 3. Load a dataset to finetune on
29
+ train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000))
30
+ eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000))
31
+ test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000))
32
+
33
+ # 4. Define a loss function. CachedMultipleNegativesRankingLoss (CMNRL) is a special variant of MNRL (a.k.a. InfoNCE),
34
+ # which take question-answer pairs (or triplets, etc.) as input. It will take answers from other questions in the batch
35
+ # as wrong answers, reducing the distance between the question and the true answer while increasing the distance to the
36
+ # wrong answers, in the embedding space.
37
+ # The (C)MNRL losses benefit from larger `per_device_train_batch_size` in the Training Arguments, as they can leverage
38
+ # more in-batch negative samples. At the same time, the `mini_batch_size` does not affect training performance, but it
39
+ # does limit the memory usage. A good trick is setting a high `per_device_train_batch_size` while keeping
40
+ # `mini_batch_size` small.
41
+ loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8)
42
+
43
+ # 5. (Optional) Specify training arguments
44
+ run_name = "embeddinggemma-300M-medical-100k"
45
+ args = SentenceTransformerTrainingArguments(
46
+ # Required parameter:
47
+ output_dir=f"models/{run_name}",
48
+ # Optional training parameters:
49
+ num_train_epochs=1,
50
+ per_device_train_batch_size=128,
51
+ per_device_eval_batch_size=128,
52
+ learning_rate=2e-5,
53
+ warmup_ratio=0.1,
54
+ fp16=True, # Set to False if you get an error that your GPU can't run on FP16
55
+ bf16=False, # Set to True if you have a GPU that supports BF16
56
+ batch_sampler=BatchSamplers.NO_DUPLICATES, # (Cached)MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
57
+ prompts={ # Map training column names to model prompts
58
+ "question": model.prompts["query"],
59
+ "passage_text": model.prompts["document"],
60
+ },
61
+ # Optional tracking/debugging parameters:
62
+ eval_strategy="steps",
63
+ eval_steps=100,
64
+ save_strategy="steps",
65
+ save_steps=100,
66
+ save_total_limit=2,
67
+ logging_steps=20,
68
+ run_name=run_name, # Will be used in W&B if `wandb` is installed
69
+ )
70
+
71
+ # 6. (Optional) Create an evaluator using the evaluation queries and 31k answers & evaluate the base model
72
+ queries = dict(enumerate(eval_dataset["question"]))
73
+ corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
74
+ relevant_docs = {idx: [idx] for idx in queries}
75
+ dev_evaluator = InformationRetrievalEvaluator(
76
+ queries=queries,
77
+ corpus=corpus,
78
+ relevant_docs=relevant_docs,
79
+ name="miriad-eval-1kq-31kd", # 1k questions, 31k passages
80
+ show_progress_bar=True,
81
+ )
82
+ dev_evaluator(model)
83
+
84
+ # 7. Create a trainer & train
85
+ trainer = SentenceTransformerTrainer(
86
+ model=model,
87
+ args=args,
88
+ train_dataset=train_dataset,
89
+ eval_dataset=eval_dataset,
90
+ loss=loss,
91
+ evaluator=dev_evaluator,
92
+ )
93
+ trainer.train()
94
+
95
+ # (Optional) Evaluate the trained model on the evaluation set once more, this will also log the results
96
+ # and include them in the model card
97
+ dev_evaluator(model)
98
+
99
+ queries = dict(enumerate(test_dataset["question"]))
100
+ corpus = dict(enumerate(test_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
101
+ relevant_docs = {idx: [idx] for idx in queries}
102
+ test_evaluator = InformationRetrievalEvaluator(
103
+ queries=queries,
104
+ corpus=corpus,
105
+ relevant_docs=relevant_docs,
106
+ name="miriad-test-1kq-31kd", # 1k questions, 31k passages
107
+ show_progress_bar=True,
108
+ )
109
+ test_evaluator(model)
110
+
111
+ # 8. Save the trained model
112
+ final_output_dir = f"models/{run_name}/final"
113
+ model.save_pretrained(final_output_dir)
114
+
115
+ # 9. (Optional) Push it to the Hugging Face Hub
116
+ # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
117
+ try:
118
+ model.push_to_hub(run_name)
119
+ except Exception:
120
+ logging.error(
121
+ f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
122
+ f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
123
+ f"and saving it using `model.push_to_hub('{run_name}')`."
124
+ )