madhavanvvs committed
Commit · a794aec · Parent(s): c6d04a6

Refactor MTL: DDP NCCL support

Browse files
- geneformer/mtl/__init__.py +4 -1
- geneformer/mtl/collators.py +2 -2
- geneformer/mtl/data.py +192 -117
- geneformer/mtl/eval_utils.py +5 -8
- geneformer/mtl/imports.py +0 -43
- geneformer/mtl/model.py +1 -1
- geneformer/mtl/optuna_utils.py +0 -27
- geneformer/mtl/train.py +656 -329
- geneformer/mtl/train_utils.py +0 -161
- geneformer/mtl/utils.py +603 -91
- geneformer/mtl_classifier.py +24 -8
geneformer/mtl/__init__.py
CHANGED
@@ -1 +1,4 @@
-# ruff: noqa: F401
+# ruff: noqa: F401
+
+from . import eval_utils
+from . import utils
geneformer/mtl/collators.py
CHANGED
@@ -1,8 +1,8 @@
 # imports
 import torch
 import pickle
-from
-from
+from geneformer.collator_for_classification import DataCollatorForGeneClassification
+from geneformer import TOKEN_DICTIONARY_FILE
 
 """Geneformer collator for multi-task cell classification."""
 
geneformer/mtl/data.py
CHANGED
@@ -1,126 +1,190 @@ and @@ -131,32 +195,43 @@ def validate_label_mappings(config):

Before (old version, partially recoverable from the diff):

import os
...
    """Ensures required columns are present in the dataset."""
    missing_columns = [col for col in required_columns if col not in dataset.column_names]
    if missing_columns:
        raise KeyError(
            f"Missing columns in {dataset_type} dataset: {missing_columns}. "
            f"Available columns: {dataset.column_names}"
        )


def create_label_mappings(dataset, task_to_column):
    """Creates label mappings for the dataset."""
    task_label_mappings = {}
    num_labels_list = []
    for task, column in task_to_column.items():
        unique_values = sorted(set(dataset[column]))
        mapping = {label: idx for idx, label in enumerate(unique_values)}
        task_label_mappings[task] = mapping
        num_labels_list.append(len(unique_values))
    return task_label_mappings, num_labels_list


def save_label_mappings(mappings, path):
    """Saves label mappings to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(mappings, f)


def load_label_mappings(path):
    """Loads label mappings from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
    """Transforms the dataset to the required format."""
    transformed_dataset = []
    cell_id_mapping = {}

    for idx, record in enumerate(dataset):
        transformed_record = {
            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
            "cell_id": idx,  # Index-based cell ID
        }

        if not is_test:
            label_dict = {
                task: task_label_mappings[task][record[column]]
                for task, column in task_to_column.items()
            }
        else:
            label_dict = {task: -1 for task in config["task_names"]}

        transformed_record["label"] = label_dict
        transformed_dataset.append(transformed_record)
        cell_id_mapping[idx] = record.get("unique_cell_id", idx)

    return transformed_dataset, cell_id_mapping


def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    """Main function to load and preprocess data."""
    try:
        dataset = load_from_disk(dataset_path)
        # Setup task and column mappings
        task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        task_to_column = dict(zip(task_names, config["task_columns"]))
        config["task_names"] = task_names
        ...
            config["results_dir"],
            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
        )
        if not is_test:
            ...
            task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
            save_label_mappings(task_label_mappings, label_mappings_path)
        else:
            # Load existing mappings for test data
            task_label_mappings =
            num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
        ...


def
    """
    ...
    # Process validation data and save mappings
    val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
    ...
    validate_label_mappings(config)
    ...


def validate_label_mappings(config):
    """Ensures train and validation label mappings are consistent."""
    train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
    val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
    ...
    for task_name in config["task_names"]:
        if train_mappings[task_name] != val_mappings[task_name]:
            ...
            )
    ...
    )


def preload_data(config):
    """Preprocesses train and validation data for trials."""
    ...
    return train_loader, val_loader


def load_and_preprocess_test_data(config):
    """Loads and preprocesses test data."""
    ...


def prepare_test_loader(config):
    """Prepares DataLoader for test data."""
    ...
    test_loader
    return test_loader, cell_id_mapping, num_labels_list

After (new version):

import os
import pickle
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk

from .collators import DataCollatorForMultitaskCellClassification


class StreamingMultiTaskDataset(Dataset):

    def __init__(self, dataset_path, config, is_test=False, dataset_type=""):
        """Initialize the streaming dataset."""
        self.dataset = load_from_disk(dataset_path)
        self.config = config
        self.is_test = is_test
        self.dataset_type = dataset_type
        self.cell_id_mapping = {}

        # Setup task and column mappings
        self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
        config["task_names"] = self.task_names

        # Check if unique_cell_id column exists in the dataset
        self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
        print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")

        # Setup label mappings
        self.label_mappings_path = os.path.join(
            config["results_dir"],
            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
        )

        if not is_test:
            self._validate_columns()
            self.task_label_mappings, self.num_labels_list = self._create_label_mappings()
            self._save_label_mappings()
        else:
            # Load existing mappings for test data
            self.task_label_mappings = self._load_label_mappings()
            self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()]

    def _validate_columns(self):
        """Ensures required columns are present in the dataset."""
        missing_columns = [col for col in self.task_to_column.values()
                           if col not in self.dataset.column_names]
        if missing_columns:
            raise KeyError(
                f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
                f"Available columns: {self.dataset.column_names}"
            )

    def _create_label_mappings(self):
        """Creates label mappings for the dataset."""
        task_label_mappings = {}
        num_labels_list = []

        for task, column in self.task_to_column.items():
            unique_values = sorted(set(self.dataset[column]))
            mapping = {label: idx for idx, label in enumerate(unique_values)}
            task_label_mappings[task] = mapping
            num_labels_list.append(len(unique_values))

        return task_label_mappings, num_labels_list

    def _save_label_mappings(self):
        """Saves label mappings to a pickle file."""
        with open(self.label_mappings_path, "wb") as f:
            pickle.dump(self.task_label_mappings, f)

    def _load_label_mappings(self):
        """Loads label mappings from a pickle file."""
        with open(self.label_mappings_path, "rb") as f:
            return pickle.load(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]

        # Store cell ID mapping
        if self.has_unique_cell_ids:
            unique_cell_id = record["unique_cell_id"]
            self.cell_id_mapping[idx] = unique_cell_id
        else:
            self.cell_id_mapping[idx] = f"cell_{idx}"

        # Create transformed record
        transformed_record = {
            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
            "cell_id": idx,
        }

        # Add labels
        if not self.is_test:
            label_dict = {
                task: self.task_label_mappings[task][record[column]]
                for task, column in self.task_to_column.items()
            }
        else:
            label_dict = {task: -1 for task in self.config["task_names"]}

        transformed_record["label"] = label_dict

        return transformed_record


def get_data_loader(dataset, batch_size, sampler=None, shuffle=True):
    """Create a DataLoader with the given dataset and parameters."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=shuffle if sampler is None else False,
        num_workers=0,
        pin_memory=True,
        collate_fn=DataCollatorForMultitaskCellClassification(),
    )


def prepare_data_loaders(config, include_test=False):
    """Prepare data loaders for training, validation, and optionally test."""
    result = {}

    # Process train data
    train_dataset = StreamingMultiTaskDataset(
        config["train_path"],
        config,
        dataset_type="train"
    )
    result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])

    # Store the cell ID mapping from the dataset
    result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()}
    print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")

    result["num_labels_list"] = train_dataset.num_labels_list

    # Process validation data
    val_dataset = StreamingMultiTaskDataset(
        config["val_path"],
        config,
        dataset_type="validation"
    )
    result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])

    # Store the complete cell ID mapping for validation
    for idx in range(len(val_dataset)):
        _ = val_dataset[idx]

    result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()}
    print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")

    # Validate label mappings
    validate_label_mappings(config)

    # Process test data if requested
    if include_test and "test_path" in config:
        test_dataset = StreamingMultiTaskDataset(
            config["test_path"],
            config,
            is_test=True,
            dataset_type="test"
        )
        result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])

        for idx in range(len(test_dataset)):
            _ = test_dataset[idx]

        result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()}
        print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")

    return result


def validate_label_mappings(config):
    """Ensures train and validation label mappings are consistent."""
    train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
    val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")

    with open(train_mappings_path, "rb") as f:
        train_mappings = pickle.load(f)

    with open(val_mappings_path, "rb") as f:
        val_mappings = pickle.load(f)

    for task_name in config["task_names"]:
        if train_mappings[task_name] != val_mappings[task_name]:
            ...
            )


# Legacy functions for backward compatibility
def preload_and_process_data(config):
    """Preloads and preprocesses train and validation datasets."""
    data = prepare_data_loaders(config)

    return (
        data["train_loader"].dataset,
        data["train_cell_mapping"],
        data["val_loader"].dataset,
        data["val_cell_mapping"],
        data["num_labels_list"]
    )


def preload_data(config):
    """Preprocesses train and validation data for trials."""
    data = prepare_data_loaders(config)
    return data["train_loader"], data["val_loader"]


def load_and_preprocess_test_data(config):
    """Loads and preprocesses test data."""
    test_dataset = StreamingMultiTaskDataset(
        config["test_path"],
        config,
        is_test=True,
        dataset_type="test"
    )

    return (
        test_dataset,
        test_dataset.cell_id_mapping,
        test_dataset.num_labels_list
    )


def prepare_test_loader(config):
    """Prepares DataLoader for test data."""
    data = prepare_data_loaders(config, include_test=True)
    return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"]

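For orientation, a minimal, hypothetical sketch of how the refactored loaders above might be driven. The function name and config keys (train_path, val_path, results_dir, task_columns, batch_size) come from the diff itself; the concrete paths and column names are placeholders, not part of the commit.

# Hypothetical usage sketch of prepare_data_loaders; values are illustrative only.
from geneformer.mtl.data import prepare_data_loaders

config = {
    "train_path": "/path/to/train.dataset",        # placeholder dataset path
    "val_path": "/path/to/val.dataset",            # placeholder dataset path
    "results_dir": "/path/to/results",             # where label mappings are pickled
    "task_columns": ["cell_type", "disease_state"],  # illustrative task columns
    "batch_size": 8,
}

loaders = prepare_data_loaders(config)
train_loader = loaders["train_loader"]
val_loader = loaders["val_loader"]
num_labels_list = loaders["num_labels_list"]
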
geneformer/mtl/eval_utils.py
CHANGED
@@ -1,19 +1,16 @@ and @@ -85,4 +82,4 @@ def load_and_evaluate_test_model(config):

Before (old version):

import pandas as pd

from .
from .data import prepare_test_loader  # noqa # isort:skip
from .model import GeneformerMultiTask


def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_probs = {task_name: [] for task_name in config["task_names"]}
    cell_ids = []

    # # Load task label mappings from pickle file
    # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
    #     task_label_mappings = pickle.load(f)

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
...
    best_model.to(device)

    evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
    print("Evaluation completed.")

After (new version):

import os
import json
import torch
import pandas as pd

from .data import prepare_test_loader
from .model import GeneformerMultiTask

def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_probs = {task_name: [] for task_name in config["task_names"]}
    cell_ids = []

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
...
    best_model.to(device)

    evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
    print("Evaluation completed.")

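A brief, hypothetical driver for the evaluation entry point shown above. load_and_evaluate_test_model(config) and the test_path/results_dir keys appear in this commit; the path values and the comment about remaining keys are assumptions for illustration.

# Hypothetical evaluation driver; paths are placeholders, not from the commit.
from geneformer.mtl.eval_utils import load_and_evaluate_test_model

config = {
    "test_path": "/path/to/test.dataset",   # placeholder path
    "results_dir": "/path/to/results",      # placeholder path
    # ...remaining keys (task_columns, batch_size, model paths) as used for training
}
load_and_evaluate_test_model(config)
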
geneformer/mtl/imports.py
DELETED
@@ -1,43 +0,0 @@

import functools
import gc
import json
import os
import pickle
import sys
import warnings
from enum import Enum
from itertools import chain
from typing import Dict, List, Optional, Union

import numpy as np
import optuna
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    BatchEncoding,
    BertConfig,
    BertModel,
    DataCollatorForTokenClassification,
    SpecialTokensMixin,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_scheduler,
)
from transformers.utils import logging, to_py_obj

from .collators import DataCollatorForMultitaskCellClassification

# local modules
from .data import get_data_loader, preload_and_process_data
from .model import GeneformerMultiTask
from .optuna_utils import create_optuna_study
from .utils import save_model

geneformer/mtl/model.py
CHANGED
@@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module):
                 f"Error during loss computation for task {task_id}: {e}"
             )
 
-        return total_loss, logits, losses if labels is not None else logits
+        return total_loss, logits, losses if labels is not None else logits
geneformer/mtl/optuna_utils.py
DELETED
@@ -1,27 +0,0 @@

import optuna
from optuna.integration import TensorBoardCallback


def save_trial_callback(study, trial, trials_result_path):
    with open(trials_result_path, "a") as f:
        f.write(
            f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
        )


def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
    study = optuna.create_study(direction="maximize")

    # init TensorBoard callback
    tensorboard_callback = TensorBoardCallback(
        dirname=tensorboard_log_dir, metric_name="F1 Macro"
    )

    # callback and TensorBoard callback
    callbacks = [
        lambda study, trial: save_trial_callback(study, trial, trials_result_path),
        tensorboard_callback,
    ]

    study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
    return study

geneformer/mtl/train.py
CHANGED
@@ -1,380 +1,707 @@

Before (old version, partially recoverable from the diff):

import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from .imports import *
from .model import GeneformerMultiTask
from .utils import
...
        writer.add_scalar(
            "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log({"Training Loss": loss.item()})

        # Update progress bar
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    return loss.item()  # Return the last batch loss


def validate_model(model, val_loader, device, config):
    model.eval()
    val_loss = 0.0
    task_true_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_probs = {task_name: [] for task_name in config["task_names"]}

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = [
                batch["labels"][task_name].to(device)
                for task_name in config["task_names"]
            ]
            loss, logits, _ = model(input_ids, attention_mask, labels)
            val_loss += loss.item()

            for sample_idx in range(len(batch["input_ids"])):
                for i, task_name in enumerate(config["task_names"]):
                    true_label = batch["labels"][task_name][sample_idx].item()
                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                    pred_prob = (
                        torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
                    )
                    task_true_labels[task_name].append(true_label)
                    task_pred_labels[task_name].append(pred_label)
                    task_pred_probs[task_name].append(pred_prob)

    val_loss /= len(val_loader)
    return val_loss, task_true_labels, task_pred_labels, task_pred_probs


def log_metrics(task_metrics, val_loss, config, writer, epochs):
    for task_name, metrics in task_metrics.items():
        print(
            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
        )
        if config.get("use_wandb", False):
            import wandb
            ...
...
                row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
                row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
                row[f"{task_name} Probabilities"] = ",".join(
                    map(str, task_pred_probs[task_name][sample_idx])
                )
...


def
    ...
    device,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
):
    set_seed(config["seed"])
    initialize_wandb(config)
    ...
        epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
    ...
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
    ...
        import wandb
        wandb.finish()
    ...


def
    ...
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    set_seed(config["seed"])  # Set the seed before each trial
    initialize_wandb(config)

    # ...
        config["hyperparameters"]["warmup_ratio"]["low"],
        config["hyperparameters"]["warmup_ratio"]["high"],
    )
    config["weight_decay"] = trial.suggest_float(
        "weight_decay",
        config["hyperparameters"]["weight_decay"]["low"],
        config["hyperparameters"]["weight_decay"]["high"],
    )
    config["dropout_rate"] = trial.suggest_float(
        "dropout_rate",
        config["hyperparameters"]["dropout_rate"]["low"],
        config["hyperparameters"]["dropout_rate"]["high"],
    )
    config["lr_scheduler_type"] = trial.suggest_categorical(
        "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
    )
    config["use_attention_pooling"] = trial.suggest_categorical(
        "use_attention_pooling", [False]
    )
    ...
        config["task_weights"] = [
            trial.suggest_float(
                f"task_weight_{i}",
                config["hyperparameters"]["task_weights"]["low"],
                config["hyperparameters"]["task_weights"]["high"],
            )
            for i in range(len(num_labels_list))
        ]
        weight_sum = sum(config["task_weights"])
        config["task_weights"] = [
            weight / weight_sum for weight in config["task_weights"]
        ]
    else:
        config["task_weights"] = None

    # Dynamic range for max_layers_to_freeze
    freeze_range = get_layer_freeze_range(config["pretrained_path"])
    config["max_layers_to_freeze"] = trial.suggest_int(
        "max_layers_to_freeze",
        freeze_range["min"],
        freeze_range["max"]
    )

    model
    ...
    writer = SummaryWriter(log_dir=log_dir)

    for epoch in range(config["epochs"]):
        train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )
    ...
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
    ...
        raise optuna.TrialPruned()
    ...
        {
            "trial_number": trial.number,
            "val_loss": val_loss,
            **{
                f"{task_name}_f1": metrics["f1"]
                for task_name, metrics in task_metrics.items()
            },
            **{
                f"{task_name}_accuracy": metrics["accuracy"]
                for task_name, metrics in task_metrics.items()
            },
            **{
                k: v
                for k, v in config.items()
                if k
                in [
                    "learning_rate",
                    "warmup_ratio",
                    "weight_decay",
                    "dropout_rate",
                    "lr_scheduler_type",
                    "use_attention_pooling",
                    "max_layers_to_freeze",
                ]
            },
        }
    )
    wandb.finish()

    return

After (new version, as shown in the diff; the extract ends mid-file):

import os
import pandas as pd
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import optuna
import functools
import time

from .model import GeneformerMultiTask
from .utils import (
    calculate_metrics,
    get_layer_freeze_range,
    set_seed,
    initialize_wandb,
    create_model,
    setup_optimizer_and_scheduler,
    save_model,
    save_hyperparameters,
    prepare_training_environment,
    log_training_step,
    log_validation_metrics,
    save_validation_predictions,
    setup_logging,
    setup_distributed_environment,
    train_distributed
)


class Trainer:
    """Trainer class for multi-task learning"""

    def __init__(self, config):
        self.config = config
        self.device = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.writer = None
        self.is_distributed = config.get("distributed_training", False)
        self.local_rank = config.get("local_rank", 0)
        self.is_main_process = not self.is_distributed or self.local_rank == 0

    def train_epoch(self, train_loader, epoch):
        """Train the model for one epoch."""
        epoch_start = time.time()
        self.model.train()

        # For distributed training, we need to be aware of the global batch count
        if self.is_distributed:
            # Get world size for reporting
            world_size = dist.get_world_size()
            # Calculate total batches across all GPUs
            total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader)
        else:
            world_size = 1
            total_batches_global = len(train_loader)

        progress_bar = None
        if self.is_main_process:
            # Use the global batch count for progress reporting in distributed mode
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}",
                                total=len(train_loader))
            iterator = progress_bar

            # Report distributed training information
            if self.is_distributed:
                print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, "
                      f"{total_batches_global} total batches globally")
        else:
            iterator = train_loader

        batch_times = []
        forward_times = []
        backward_times = []
        optimizer_times = []

        # Get gradient accumulation steps from config (default to 1 if not specified)
        accumulation_steps = self.config.get("gradient_accumulation_steps", 1)

        # Zero gradients at the beginning
        self.optimizer.zero_grad()

        # Track loss for the entire epoch
        total_loss = 0.0
        num_batches = 0
        accumulated_loss = 0.0

        for batch_idx, batch in enumerate(iterator):
            batch_start = time.time()

            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = [
                batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"]
            ]

            forward_start = time.time()
            loss, _, _ = self.model(input_ids, attention_mask, labels)

            # Scale loss by accumulation steps for gradient accumulation
            if accumulation_steps > 1:
                loss = loss / accumulation_steps

            forward_end = time.time()
            forward_times.append(forward_end - forward_start)

            # Track loss - store the unscaled loss for reporting
            unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps)
            total_loss += unscaled_loss
            num_batches += 1
            accumulated_loss += loss.item()  # For gradient accumulation tracking

            backward_start = time.time()

            # Use no_sync() for all but the last accumulation step to avoid unnecessary communication
            if self.is_distributed and accumulation_steps > 1:
                # If this is not the last accumulation step or the last batch
                if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader):
                    with self.model.no_sync():
                        loss.backward()
                else:
                    loss.backward()
            else:
                # Non-distributed training or accumulation_steps=1
                loss.backward()

            backward_end = time.time()
            backward_times.append(backward_end - backward_start)

            # Only update weights after accumulation_steps or at the end of the epoch
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                if self.config["gradient_clipping"]:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])

                optimizer_start = time.time()
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                optimizer_end = time.time()
                optimizer_times.append(optimizer_end - optimizer_start)

                # Log after optimizer step
                if self.is_main_process:
                    # Calculate running average loss
                    avg_loss = total_loss / num_batches

                    log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx)

                    # Update progress bar with just the running average loss
                    progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})

                accumulated_loss = 0.0
            else:
                optimizer_times.append(0)  # No optimizer step taken

            batch_end = time.time()
            batch_times.append(batch_end - batch_start)

        epoch_end = time.time()

        # Calculate average loss for the epoch
        epoch_avg_loss = total_loss / num_batches

        # If distributed, gather losses from all processes to compute global average
        if self.is_distributed:
            # Create a tensor to hold the loss
            loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
            # Gather losses from all processes
            all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(all_losses, loss_tensor)
            # Compute the global average loss across all processes
            epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()

        if self.is_main_process:
            # douhble check if batch_size has already been adjusted for world_size in the config
            # This avoids double-counting the effective batch size
            per_gpu_batch_size = self.config['batch_size']
            total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size

            print(f"Epoch {epoch+1} timing:")
            print(f"  Total epoch time: {epoch_end - epoch_start:.2f}s")
            print(f"  Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
            print(f"  Average forward time: {sum(forward_times)/len(forward_times):.4f}s")
            print(f"  Average backward time: {sum(backward_times)/len(backward_times):.4f}s")
            print(f"  Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s")
            print(f"  Gradient accumulation steps: {accumulation_steps}")
            print(f"  Batch size per GPU: {per_gpu_batch_size}")
            print(f"  Effective global batch size: {total_effective_batch}")
            print(f"  Average training loss: {epoch_avg_loss:.4f}")
            if self.is_distributed:
                print(f"  Total batches processed across all GPUs: {total_batches_global}")
                print(f"  Communication optimization: Using no_sync() for gradient accumulation")

        return epoch_avg_loss  # Return the average loss for the epoch

    def validate_model(self, val_loader):
        val_start = time.time()
        self.model.eval()
        val_loss = 0.0
        task_true_labels = {task_name: [] for task_name in self.config["task_names"]}
        task_pred_labels = {task_name: [] for task_name in self.config["task_names"]}
        task_pred_probs = {task_name: [] for task_name in self.config["task_names"]}

        val_cell_ids = {}
        sample_counter = 0

        batch_times = []

        # Print validation dataset size
        if self.is_main_process:
            print(f"Validation dataset size: {len(val_loader.dataset)} samples")
            print(f"Number of validation batches: {len(val_loader)}")

            if self.is_distributed:
                world_size = dist.get_world_size()
                print(f"Distributed validation: {world_size} GPUs")
                if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
                    samples_per_gpu = val_loader.sampler.num_samples
                    print(f"Each GPU processes {samples_per_gpu} validation samples")
                    print(f"Total validation samples processed: {samples_per_gpu * world_size}")

        with torch.no_grad():
            for batch in val_loader:
                batch_start = time.time()
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = [
                    batch["labels"][task_name].to(self.device)
                    for task_name in self.config["task_names"]
                ]
                loss, logits, _ = self.model(input_ids, attention_mask, labels)
                val_loss += loss.item()

                if "cell_id" in batch:
                    for i, cell_id in enumerate(batch["cell_id"]):
                        # Store the actual index for later mapping to unique_cell_id
                        val_cell_ids[sample_counter + i] = cell_id.item()

                for sample_idx in range(len(batch["input_ids"])):
                    for i, task_name in enumerate(self.config["task_names"]):
                        true_label = batch["labels"][task_name][sample_idx].item()
                        pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                        # Store the full probability distribution
                        pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
                        task_true_labels[task_name].append(true_label)
                        task_pred_labels[task_name].append(pred_label)
                        task_pred_probs[task_name].append(pred_prob)

                # Update current index for cell ID tracking
                sample_counter += len(batch["input_ids"])

                batch_end = time.time()
                batch_times.append(batch_end - batch_start)

        # norm validation loss by the number of batches
        val_loss /= len(val_loader)

        # distributed, gather results from all processes
        if self.is_distributed:
            # Create tensors to hold the local results
            loss_tensor = torch.tensor([val_loss], device=self.device)
            gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered_losses, loss_tensor)

            # Compute average loss across all processes
            val_loss = torch.mean(torch.cat(gathered_losses)).item()

            world_size = dist.get_world_size()

            if self.is_main_process:
                print(f"Collected predictions from rank {self.local_rank}")
                print(f"Number of samples processed by this rank: {sample_counter}")

        val_end = time.time()

        if self.is_main_process:
            print(f"Validation timing:")
            print(f"  Total validation time: {val_end - val_start:.2f}s")
            print(f"  Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
            print(f"  Collected {len(val_cell_ids)} cell indices from validation data")
            print(f"  Processed {sample_counter} total samples during validation")

            # Print number of samples per task
            for task_name in self.config["task_names"]:
                print(f"  Task {task_name}: {len(task_true_labels[task_name])} samples")

        return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids

    def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
        """Train the model and return validation loss and trained model."""
        if self.config.get("use_wandb", False) and self.is_main_process:
            initialize_wandb(self.config)

        # Create model
        self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)

        # Setup optimizer and scheduler
        total_steps = len(train_loader) * self.config["epochs"]
        self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)

        # Training loop
        if self.is_main_process:
            epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
        else:
            epoch_progress = range(self.config["epochs"])

        best_val_loss = float('inf')
        train_losses = []

        with setup_logging(self.config) as self.writer:
            for epoch in epoch_progress:
                if self.is_distributed:
                    train_loader.sampler.set_epoch(epoch)

                train_loss = self.train_epoch(train_loader, epoch)
                train_losses.append(train_loss)

                # Run validation after each epoch if configured to do so
                if self.config.get("validate_each_epoch", False):
                    val_loss, _, _, _, _ = self.validate_model(val_loader)
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss

                    if self.is_main_process:
                        epoch_progress.set_postfix({
                            "train_loss": f"{train_loss:.4f}",
                            "val_loss": f"{val_loss:.4f}",
                            "best_val_loss": f"{best_val_loss:.4f}"
                        })
                else:
                    if self.is_main_process:
                        epoch_progress.set_postfix({
                            "train_loss": f"{train_loss:.4f}"
                        })

            val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
            task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")

            if self.is_main_process:
                log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])

                # Save validation predictions
                save_validation_predictions(
                    val_cell_ids,
                    task_true_labels,
                    task_pred_labels,
                    task_pred_probs,
                    {**self.config, "val_cell_mapping": val_cell_id_mapping}  # Include the mapping
                )

                if self.config.get("use_wandb", False):
                    import wandb
                    wandb.finish()

                print(f"\nTraining Summary:")
                print(f"  Final Training Loss: {train_losses[-1]:.4f}")
                print(f"  Final Validation Loss: {val_loss:.4f}")
                for task_name, metrics in task_metrics.items():
                    print(f"  {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")

        return val_loss, self.model  # Return both the validation loss and the trained model

    def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
        if self.is_distributed:
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = create_model(self.config, num_labels_list, self.device)

        # war model w DDP
        if self.is_distributed:
            self.model = DDP(self.model, device_ids=[self.local_rank])

            # communication hook to optimize gradient synchronization
            from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

            # default hook which maintains full precision
            self.model.register_comm_hook(
                state=None,
                hook=comm_hooks.allreduce_hook
            )

            print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")

        print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
        print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
        if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
            print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")

        if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
            print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")

        # Set up optimizer and scheduler
        self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
            self.model, self.config, len(train_loader)
        )

        if self.is_main_process and self.config.get("use_wandb", False):
            initialize_wandb(self.config)

        return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list


def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
    """Train a model with the given configuration and data."""
    # Check if distributed training is enabled
    if config.get("distributed_training", False):
        # Check if we have multiple GPUs
        if torch.cuda.device_count() > 1:
            result = train_distributed(
                Trainer,
                config,
                train_loader,
                val_loader,
                train_cell_id_mapping,
                val_cell_id_mapping,
                num_labels_list
            )
            if result is not None:
                return result
        else:
            print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
            config["distributed_training"] = False

    # Non-distributed training
    trainer = Trainer(config)
    trainer.device = device
    return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)


def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    """Objective function for Optuna hyperparameter optimization."""
    set_seed(config["seed"])
    initialize_wandb(config)

    trial_config = config.copy()

    # Suggest hyperparameters for this trial
    for param_name, param_config in config["hyperparameters"].items():
        if param_name == "lr_scheduler_type":
            trial_config[param_name] = trial.suggest_categorical(
                param_name, param_config["choices"]
            )
        elif param_name == "task_weights" and config["use_task_weights"]:
            weights = [
                trial.suggest_float(
                    f"task_weight_{i}",
                    param_config["low"],
                    param_config["high"],
                )
                for i in range(len(num_labels_list))
            ]
            weight_sum = sum(weights)
            trial_config[param_name] = [w / weight_sum for w in weights]
        elif "log" in param_config and param_config["log"]:
            trial_config[param_name] = trial.suggest_float(
                param_name, param_config["low"], param_config["high"], log=True
            )
        else:
            trial_config[param_name] = trial.suggest_float(
                param_name, param_config["low"], param_config["high"]
            )

    # Set appropriate max layers to freeze based on pretrained model
    if "max_layers_to_freeze" in trial_config:
        freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
        trial_config["max_layers_to_freeze"] = int(trial.suggest_int(
            "max_layers_to_freeze",
            freeze_range["min"],
            freeze_range["max"]
        ))

    trial_config["run_name"] = f"trial_{trial.number}"

    # Handle distributed training for this trial
    if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
        manager = mp.Manager()
        shared_dict = manager.dict()

        train_distributed(
            Trainer,
            trial_config,
            train_loader,
            val_loader,
            train_cell_id_mapping,
            val_cell_id_mapping,
            num_labels_list,
            trial.number,
            shared_dict
        )

        val_loss = shared_dict.get('val_loss', float('inf'))
        task_metrics = shared_dict.get('task_metrics', {})

        trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
        trial.set_user_attr("task_weights", trial_config["task_weights"])

        if config.get("use_wandb", False):
            import wandb
            wandb.log({
                "trial_number": trial.number,
                "val_loss": val_loss,
                **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
                **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
            })
            wandb.finish()

        return val_loss

    with setup_logging(trial_config) as writer:
        trainer = Trainer(trial_config)
        trainer.device = device
        trainer.writer = writer

        # Create model with trial hyperparameters
        trainer.model = create_model(trial_config, num_labels_list, device)
        total_steps = len(train_loader) * config["epochs"]
        trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)

        # Training loop
        for epoch in range(config["epochs"]):
            trainer.train_epoch(train_loader, epoch)

        val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
        task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")

        # Log metrics
        log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])

        # Save validation predictions
        save_validation_predictions(
            val_cell_ids,
            task_true_labels,
            task_pred_labels,
            task_pred_probs,
            {**trial_config, "val_cell_mapping": val_cell_id_mapping},
            trial.number,
        )

        # Store model state dict and task weights in trial user attributes
        trial.set_user_attr("model_state_dict", trainer.model.state_dict())
        trial.set_user_attr("task_weights", trial_config["task_weights"])

        # Report intermediate value to Optuna
        trial.report(val_loss, config["epochs"])
        if trial.should_prune():
            raise optuna.TrialPruned()

        if config.get("use_wandb", False):
            import wandb
            wandb.log(
                {
                    "trial_number": trial.number,
                    "val_loss": val_loss,
                    **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
                    **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
                    **{k: v for k, v in trial_config.items() if k in [
                        "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
                        "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
                    ]},
                }
            )
            wandb.finish()

        return val_loss


def run_manual_tuning(config):
    """Run training with manually specified hyperparameters."""
    (
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    ) = prepare_training_environment(config)

    print("\nManual hyperparameters being used:")
    for key, value in config["manual_hyperparameters"].items():
        print(f"{key}: {value}")
    print()

    # Update config with manual hyperparameters
    for key, value in config["manual_hyperparameters"].items():
        config[key] = value

    # Train the model
    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )

    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    # Save the trained model - only if not using distributed training
    # (distributed training saves the model in the worker)
    if not config.get("distributed_training", False):
        model_save_directory = os.path.join(
            config["model_save_path"], "GeneformerMultiTask"
        )
        save_model(trained_model, model_save_directory)

        # Save the hyperparameters
        hyperparams_to_save = {
            **config["manual_hyperparameters"],
            "dropout_rate": config["dropout_rate"],
            "use_task_weights": config["use_task_weights"],
            "task_weights": config["task_weights"],
            "max_layers_to_freeze": config["max_layers_to_freeze"],
            "use_attention_pooling": config["use_attention_pooling"],
        }
        save_hyperparameters(model_save_directory, hyperparams_to_save)

    return val_loss


def run_optuna_study(config):
    """Run hyperparameter optimization using Optuna."""
    # Prepare training environment
    (
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    ) = prepare_training_environment(config)

    # If manual hyperparameters are specified, use them instead of running Optuna
    if config.get("use_manual_hyperparameters", False):
        return run_manual_tuning(config)

    # Create a partial function with fixed arguments for the objective
    objective_with_config_and_data = functools.partial(
        objective,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cell_id_mapping=train_cell_id_mapping,
        val_cell_id_mapping=val_cell_id_mapping,
        num_labels_list=num_labels_list,
        config=config,
        device=device,
    )

    # Create and run the Optuna study
    study = optuna.create_study(
        direction="minimize",  # Minimize validation loss
        study_name=config["study_name"],
        # storage=config["storage"],
        load_if_exists=True,
    )

    study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

    # After finding the best trial
    best_params = study.best_trial.params
    best_task_weights = study.best_trial.user_attrs["task_weights"]
    print("Saving the best model and its hyperparameters...")

    # Create a model with the best hyperparameters
    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_params["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=best_task_weights,
        max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
        use_attention_pooling=best_params.get("use_attention_pooling", False),
    )

    # Get the best model state dictionary
|
692 |
+
best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
|
693 |
|
694 |
+
best_model_state_dict = {
|
695 |
+
k.replace("module.", ""): v for k, v in best_model_state_dict.items()
|
696 |
+
}
|
697 |
|
698 |
+
best_model.load_state_dict(best_model_state_dict, strict=False)
|
|
|
699 |
|
700 |
+
model_save_directory = os.path.join(
|
701 |
+
config["model_save_path"], "GeneformerMultiTask"
|
702 |
+
)
|
703 |
+
save_model(best_model, model_save_directory)
|
704 |
|
705 |
+
save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
706 |
|
707 |
+
return study.best_trial.value # Return the best validation loss
|
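For orientation, below is a minimal sketch of how these two entry points could be driven from a user script. The config keys are assumptions pieced together from the lookups this file performs (seed, epochs, study_name, model_save_path, manual_hyperparameters, ...); the dataset-related keys consumed by prepare_training_environment are omitted, so treat this as a shape illustration rather than a working recipe.

# Hypothetical driver; key names are inferred from the config lookups above,
# and the dataset keys required by prepare_training_environment are omitted.
from geneformer.mtl.train import run_manual_tuning, run_optuna_study

config = {
    "seed": 42,
    "epochs": 1,
    "batch_size": 8,
    "n_trials": 10,
    "study_name": "geneformer_mtl",
    "pretrained_path": "/path/to/pretrained_model",   # placeholder
    "model_save_path": "/results/models",             # placeholder
    "results_dir": "/results",                        # placeholder
    "tensorboard_log_dir": "/results/tblogdir",       # placeholder
    "task_names": ["task_a", "task_b"],               # placeholder task names
    "use_manual_hyperparameters": True,
    "manual_hyperparameters": {
        "learning_rate": 1e-4,
        "warmup_ratio": 0.01,
        "weight_decay": 0.01,
        "dropout_rate": 0.1,
        "lr_scheduler_type": "linear",
        "use_attention_pooling": True,
        "use_task_weights": True,
        "task_weights": [1.0, 1.0],
        "max_layers_to_freeze": 2,
    },
}

val_loss = run_manual_tuning(config)
# With "use_manual_hyperparameters": False and an Optuna search space configured,
# run_optuna_study(config) would instead run the trial loop defined in objective().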
geneformer/mtl/train_utils.py
DELETED
@@ -1,161 +0,0 @@
-import random
-
-from .data import get_data_loader, preload_and_process_data
-from .imports import *
-from .model import GeneformerMultiTask
-from .train import objective, train_model
-from .utils import save_model
-
-
-def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
-
-def run_manual_tuning(config):
-    # Set seed for reproducibility
-    set_seed(config["seed"])
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    (
-        train_dataset,
-        train_cell_id_mapping,
-        val_dataset,
-        val_cell_id_mapping,
-        num_labels_list,
-    ) = preload_and_process_data(config)
-    train_loader = get_data_loader(train_dataset, config["batch_size"])
-    val_loader = get_data_loader(val_dataset, config["batch_size"])
-
-    # Print the manual hyperparameters being used
-    print("\nManual hyperparameters being used:")
-    for key, value in config["manual_hyperparameters"].items():
-        print(f"{key}: {value}")
-    print()  # Add an empty line for better readability
-
-    # Use the manual hyperparameters
-    for key, value in config["manual_hyperparameters"].items():
-        config[key] = value
-
-    # Train the model
-    val_loss, trained_model = train_model(
-        config,
-        device,
-        train_loader,
-        val_loader,
-        train_cell_id_mapping,
-        val_cell_id_mapping,
-        num_labels_list,
-    )
-
-    print(f"\nValidation loss with manual hyperparameters: {val_loss}")
-
-    # Save the trained model
-    model_save_directory = os.path.join(
-        config["model_save_path"], "GeneformerMultiTask"
-    )
-    save_model(trained_model, model_save_directory)
-
-    # Save the hyperparameters
-    hyperparams_to_save = {
-        **config["manual_hyperparameters"],
-        "dropout_rate": config["dropout_rate"],
-        "use_task_weights": config["use_task_weights"],
-        "task_weights": config["task_weights"],
-        "max_layers_to_freeze": config["max_layers_to_freeze"],
-        "use_attention_pooling": config["use_attention_pooling"],
-    }
-    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
-    with open(hyperparams_path, "w") as f:
-        json.dump(hyperparams_to_save, f)
-    print(f"Manual hyperparameters saved to {hyperparams_path}")
-
-    return val_loss
-
-
-def run_optuna_study(config):
-    # Set seed for reproducibility
-    set_seed(config["seed"])
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    (
-        train_dataset,
-        train_cell_id_mapping,
-        val_dataset,
-        val_cell_id_mapping,
-        num_labels_list,
-    ) = preload_and_process_data(config)
-    train_loader = get_data_loader(train_dataset, config["batch_size"])
-    val_loader = get_data_loader(val_dataset, config["batch_size"])
-
-    if config["use_manual_hyperparameters"]:
-        train_model(
-            config,
-            device,
-            train_loader,
-            val_loader,
-            train_cell_id_mapping,
-            val_cell_id_mapping,
-            num_labels_list,
-        )
-    else:
-        objective_with_config_and_data = functools.partial(
-            objective,
-            train_loader=train_loader,
-            val_loader=val_loader,
-            train_cell_id_mapping=train_cell_id_mapping,
-            val_cell_id_mapping=val_cell_id_mapping,
-            num_labels_list=num_labels_list,
-            config=config,
-            device=device,
-        )
-
-        study = optuna.create_study(
-            direction="minimize",  # Minimize validation loss
-            study_name=config["study_name"],
-            # storage=config["storage"],
-            load_if_exists=True,
-        )
-
-        study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
-
-        # After finding the best trial
-        best_params = study.best_trial.params
-        best_task_weights = study.best_trial.user_attrs["task_weights"]
-        print("Saving the best model and its hyperparameters...")
-
-        # Saving model as before
-        best_model = GeneformerMultiTask(
-            config["pretrained_path"],
-            num_labels_list,
-            dropout_rate=best_params["dropout_rate"],
-            use_task_weights=config["use_task_weights"],
-            task_weights=best_task_weights,
-        )
-
-        # Get the best model state dictionary
-        best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
-
-        # Remove the "module." prefix from the state dictionary keys if present
-        best_model_state_dict = {
-            k.replace("module.", ""): v for k, v in best_model_state_dict.items()
-        }
-
-        # Load the modified state dictionary into the model, skipping unexpected keys
-        best_model.load_state_dict(best_model_state_dict, strict=False)
-
-        model_save_directory = os.path.join(
-            config["model_save_path"], "GeneformerMultiTask"
-        )
-        save_model(best_model, model_save_directory)
-
-        # Additionally, save the best hyperparameters and task weights
-        hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
-
-        with open(hyperparams_path, "w") as f:
-            json.dump({**best_params, "task_weights": best_task_weights}, f)
-        print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py
CHANGED
@@ -1,129 +1,641 @@
 import os
-import […]
-[…]
 from sklearn.metrics import accuracy_score, f1_score
 from sklearn.preprocessing import LabelEncoder
-from […]
-
-from .imports import *
 
 
-def save_model(model, model_save_directory):
-    if not os.path.exists(model_save_directory):
-        os.makedirs(model_save_directory)
-
-    # Get the state dict
-    if isinstance(model, nn.DataParallel):
-        model_state_dict = (
-            model.module.state_dict()
-        )  # Use model.module to access the underlying model
-    else:
-        model_state_dict = model.state_dict()
 
-[…]
 }
 
 model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
 torch.save(model_state_dict, model_save_path)
 
 # Save the model configuration
-[…]
-        model.module.config.to_json_file(
-            os.path.join(model_save_directory, "config.json")
-        )
-    else:
-        model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
 
 print(f"Model and configuration saved to {model_save_directory}")
 
 
-def […]
-[…]
-    accuracy = accuracy_score(true_labels, pred_labels)
-    task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
-    return task_metrics
 
 
-def […]
-[…]
-
-    # Fit and transform combined labels and predictions to numerical values
-    le.fit(combined_labels + combined_preds)
-    encoded_true_labels = le.transform(combined_labels)
-    encoded_pred_labels = le.transform(combined_preds)
-
-[…]
-
-    # Calculate accuracy
-    accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
-
-[…]
-
-
-# […]
-
-# # Load the model state dictionary
-# model_state_dict = torch.load(
-#     os.path.join(original_model_save_directory, "pytorch_model.bin")
-# )
-
-[…]
-
-# # Filter the state dict to exclude classification heads
-# model_without_heads_state_dict = {
-#     k: v
-#     for k, v in model_state_dict.items()
-#     if not k.startswith("classification_heads")
-# }
-
-[…]
-
-# # Save the model without heads
-# model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
-# torch.save(model_without_heads.state_dict(), model_save_path)
-
-[…]
-
-# print(f"Model without classification heads saved to {new_model_save_directory}")
-
-
-[…]
 """
-[…]
 Args:
-[…]
-    dict:
 """
-[…]
+from typing import Dict, List, Optional, Union
+import json
 import os
+import pickle
+import random
+import torch
+import numpy as np
+import wandb
+import optuna
 from sklearn.metrics import accuracy_score, f1_score
 from sklearn.preprocessing import LabelEncoder
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
+from torch.optim import AdamW
+import pandas as pd
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.multiprocessing as mp
+from contextlib import contextmanager
 
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def initialize_wandb(config):
+    if config.get("use_wandb", False):
+        wandb.init(
+            project=config.get("wandb_project", "geneformer_multitask"),
+            name=config.get("run_name", "experiment"),
+            config=config,
+            reinit=True,
+        )
+
+
+def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0):
+    """Create and initialize the model based on configuration."""
+    from .model import GeneformerMultiTask
+
+    model = GeneformerMultiTask(
+        config["pretrained_path"],
+        num_labels_list,
+        dropout_rate=config.get("dropout_rate", 0.1),
+        use_task_weights=config.get("use_task_weights", False),
+        task_weights=config.get("task_weights", None),
+        max_layers_to_freeze=config.get("max_layers_to_freeze", 0),
+        use_attention_pooling=config.get("use_attention_pooling", False),
+    )
+
+    # Move model to device
+    model.to(device)
+
+    if is_distributed:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+    return model
 
 
+def setup_optimizer_and_scheduler(model, config, total_steps):
+    """Set up optimizer and learning rate scheduler."""
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters()
+                       if not any(nd in n for nd in no_decay) and p.requires_grad],
+            "weight_decay": config["weight_decay"],
+        },
+        {
+            "params": [p for n, p in model.named_parameters()
+                       if any(nd in n for nd in no_decay) and p.requires_grad],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    optimizer = AdamW(
+        optimizer_grouped_parameters,
+        lr=config["learning_rate"],
+        eps=config.get("adam_epsilon", 1e-8)
+    )
+
+    # Prepare scheduler
+    warmup_steps = int(total_steps * config["warmup_ratio"])
+
+    scheduler_map = {
+        "linear": get_linear_schedule_with_warmup,
+        "cosine": get_cosine_schedule_with_warmup
     }
+
+    scheduler_fn = scheduler_map.get(config["lr_scheduler_type"])
+    if not scheduler_fn:
+        raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}")
+
+    scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
+
+    return optimizer, scheduler
+
+
+def save_model(model, model_save_directory):
+    """Save model weights and configuration."""
+    os.makedirs(model_save_directory, exist_ok=True)
+
+    # Handle DDP model
+    if isinstance(model, DDP):
+        model_to_save = model.module
+    else:
+        model_to_save = model
+
+    model_state_dict = model_to_save.state_dict()
 
 model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
 torch.save(model_state_dict, model_save_path)
 
 # Save the model configuration
+    model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json"))
 
 print(f"Model and configuration saved to {model_save_directory}")
 
 
+def save_hyperparameters(model_save_directory, hyperparams):
+    """Save hyperparameters to a JSON file."""
+    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
+    with open(hyperparams_path, "w") as f:
+        json.dump(hyperparams, f)
+    print(f"Hyperparameters saved to {hyperparams_path}")
+
+
+def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"):
+    if metric_type == "single":
+        # Calculate metrics for a single task
+        if labels is None or preds is None:
+            raise ValueError("Labels and predictions must be provided for single task metrics")
+
+        task_name = None
+        if isinstance(labels, dict) and len(labels) == 1:
+            task_name = list(labels.keys())[0]
+            labels = labels[task_name]
+            preds = preds[task_name]
+
+        f1 = f1_score(labels, preds, average="macro")
+        accuracy = accuracy_score(labels, preds)
+
+        if return_format == "tuple":
+            return f1, accuracy
+
+        result = {"f1": f1, "accuracy": accuracy}
+        if task_name:
+            return {task_name: result}
+        return result
+
+    elif metric_type == "task_specific":
+        # Calculate metrics for multiple tasks
+        if task_data:
+            result = {}
+            for task_name, (task_labels, task_preds) in task_data.items():
+                f1 = f1_score(task_labels, task_preds, average="macro")
+                accuracy = accuracy_score(task_labels, task_preds)
+                result[task_name] = {"f1": f1, "accuracy": accuracy}
+            return result
+        elif isinstance(labels, dict) and isinstance(preds, dict):
+            result = {}
+            for task_name in labels:
+                if task_name in preds:
+                    f1 = f1_score(labels[task_name], preds[task_name], average="macro")
+                    accuracy = accuracy_score(labels[task_name], preds[task_name])
+                    result[task_name] = {"f1": f1, "accuracy": accuracy}
+            return result
+        else:
+            raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided")
+
+    elif metric_type == "combined":
+        # Calculate combined metrics across all tasks
+        if labels is None or preds is None:
+            raise ValueError("Labels and predictions must be provided for combined metrics")
+
+        # Handle label encoding for non-numeric labels
+        if not all(isinstance(x, (int, float)) for x in labels + preds):
+            le = LabelEncoder()
+            le.fit(labels + preds)
+            labels = le.transform(labels)
+            preds = le.transform(preds)
+
+        f1 = f1_score(labels, preds, average="macro")
+        accuracy = accuracy_score(labels, preds)
+
+        if return_format == "tuple":
+            return f1, accuracy
+        return {"f1": f1, "accuracy": accuracy}
+
+    else:
+        raise ValueError(f"Unknown metric_type: {metric_type}")
+
+
+def get_layer_freeze_range(pretrained_path):
+    if not pretrained_path:
+        return {"min": 0, "max": 0}
+
+    config = AutoConfig.from_pretrained(pretrained_path)
+    total_layers = config.num_hidden_layers
+    return {"min": 0, "max": total_layers - 1}
+
+
+def prepare_training_environment(config):
+    """
+    Prepare the training environment by setting seed and loading data.
+
+    Returns:
+        tuple: (device, train_loader, val_loader, train_cell_id_mapping,
+                val_cell_id_mapping, num_labels_list)
+    """
+    from .data import prepare_data_loaders
+
+    # Set seed for reproducibility
+    set_seed(config["seed"])
 
+    # Set up device - for non-distributed training
+    if not config.get("distributed_training", False):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    else:
+        # For distributed training, device will be set per process
+        device = None
+
+    # Load data using the streaming dataset
+    data = prepare_data_loaders(config)
+
+    # For distributed training, we'll set up samplers later in the distributed worker
+    # Don't create DistributedSampler here as process group isn't initialized yet
+
+    return (
+        device,
+        data["train_loader"],
+        data["val_loader"],
+        data["train_cell_mapping"],
+        data["val_cell_mapping"],
+        data["num_labels_list"],
+    )
+
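Since calculate_metrics now backs several call sites (single-task, per-task, and pooled evaluation), here is a small illustrative sketch of its three metric_type modes with made-up label vectors, assuming the module is importable as geneformer.mtl.utils.

# Made-up predictions for two tasks, just to illustrate the call shapes.
from geneformer.mtl.utils import calculate_metrics

true = {"cell_type": [0, 1, 1, 2], "disease_state": [0, 0, 1, 1]}
pred = {"cell_type": [0, 1, 2, 2], "disease_state": [0, 1, 1, 1]}

# Per-task macro F1 and accuracy (the mode used throughout train.py):
per_task = calculate_metrics(labels=true, preds=pred, metric_type="task_specific")
# e.g. {"cell_type": {"f1": ..., "accuracy": 0.75}, "disease_state": {...}}

# Single-task mode, returned as an (f1, accuracy) tuple:
f1, acc = calculate_metrics(
    labels=true["cell_type"], preds=pred["cell_type"],
    metric_type="single", return_format="tuple",
)

# Combined mode pools all labels into one vector (string labels are encoded first):
combined = calculate_metrics(
    labels=true["cell_type"] + true["disease_state"],
    preds=pred["cell_type"] + pred["disease_state"],
    metric_type="combined",
)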
+# Optuna hyperparameter optimization utilities
+def save_trial_callback(study, trial, trials_result_path):
+    """
+    Callback to save Optuna trial results to a file.
+
+    Args:
+        study: Optuna study object
+        trial: Current trial object
+        trials_result_path: Path to save trial results
+    """
+    with open(trials_result_path, "a") as f:
+        f.write(
+            f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
+        )
+
+
+def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study:
+    """Create and run an Optuna study with TensorBoard logging."""
+    from optuna.integration import TensorBoardCallback
+
+    study = optuna.create_study(direction="maximize")
+    study.optimize(
+        objective,
+        n_trials=n_trials,
+        callbacks=[
+            lambda study, trial: save_trial_callback(study, trial, trials_result_path),
+            TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
+        ]
+    )
+    return study
+
+
+@contextmanager
+def setup_logging(config):
+    run_name = config.get("run_name", "manual_run")
+    log_dir = os.path.join(config["tensorboard_log_dir"], run_name)
+    writer = SummaryWriter(log_dir=log_dir)
+    try:
+        yield writer
+    finally:
+        writer.close()
+
+
+def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx):
+    """Log training step metrics to TensorBoard and optionally W&B."""
+    writer.add_scalar(
+        "Training Loss", loss, epoch * steps_per_epoch + batch_idx
+    )
+    if config.get("use_wandb", False):
+        import wandb
+        wandb.log({"Training Loss": loss})
+
+
+def log_validation_metrics(task_metrics, val_loss, config, writer, epoch):
+    """Log validation metrics to console, TensorBoard, and optionally W&B."""
+    for task_name, metrics in task_metrics.items():
+        print(
+            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
+        )
+        if config.get("use_wandb", False):
+            import wandb
+            wandb.log(
+                {
+                    f"{task_name} Validation F1 Macro": metrics["f1"],
+                    f"{task_name} Validation Accuracy": metrics["accuracy"],
+                }
+            )
 
+    writer.add_scalar("Validation Loss", val_loss, epoch)
+    for task_name, metrics in task_metrics.items():
+        writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch)
+        writer.add_scalar(
+            f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch
+        )
+
+
+def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]:
+    """Load or initialize task label mappings."""
+    label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl")
+    if os.path.exists(label_mappings_path):
+        with open(label_mappings_path, 'rb') as f:
+            return pickle.load(f)
+    return {task_name: {} for task_name in task_names}
+
+
+def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict,
+                          task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str],
+                          inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict:
+    """Create a row for validation predictions."""
+    batch_cell_idx = val_cell_indices.get(sample_idx)
+    cell_id = val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}") if batch_cell_idx is not None else f"unknown_cell_{sample_idx}"
+
+    row = {"Cell ID": cell_id}
+    for task_name in task_names:
+        if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]):
+            true_idx = task_true_labels[task_name][sample_idx]
+            pred_idx = task_pred_labels[task_name][sample_idx]
+            true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}")
+            pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}")
+
+            row.update({
+                f"{task_name}_true_idx": true_idx,
+                f"{task_name}_pred_idx": pred_idx,
+                f"{task_name}_true_label": true_label,
+                f"{task_name}_pred_label": pred_label
+            })
+
+        if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]):
+            probs = task_pred_probs[task_name][sample_idx]
+            if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)):
+                prob_list = list(probs) if not isinstance(probs, list) else probs
+                row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list))
+                for class_idx, prob in enumerate(prob_list):
+                    class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}")
+                    row[f"{task_name}_prob_{class_label}"] = prob
+            else:
+                row[f"{task_name}_all_probs"] = str(probs)
+
+    return row
+
+
+def save_validation_predictions(
+    val_cell_indices,
+    task_true_labels,
+    task_pred_labels,
+    task_pred_probs,
+    config,
+    trial_number=None,
+):
+    """Save validation predictions to a CSV file with class labels and probabilities."""
+    os.makedirs(config["results_dir"], exist_ok=True)
+
+    if trial_number is not None:
+        os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True)
+        val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv")
+    else:
+        val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
+
+    if not val_cell_indices or not task_true_labels:
+        pd.DataFrame().to_csv(val_preds_file, index=False)
+        return
+
+    try:
+        label_mappings = load_label_mappings(config["results_dir"], config["task_names"])
+        inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()}
+        val_cell_mapping = config.get("val_cell_mapping", {})
+
+        # Determine maximum number of samples
+        max_samples = max(
+            [len(val_cell_indices)] +
+            [len(task_true_labels[task]) for task in task_true_labels]
+        )
+
+        rows = [
+            create_prediction_row(
+                sample_idx, val_cell_indices, task_true_labels, task_pred_labels,
+                task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping
+            )
+            for sample_idx in range(max_samples)
+        ]
+
+        pd.DataFrame(rows).to_csv(val_preds_file, index=False)
+    except Exception as e:
+        pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False)
+
+
+def setup_distributed_environment(rank, world_size, config):
 """
+    Setup the distributed training environment.
+
 Args:
+        rank (int): The rank of the current process
+        world_size (int): Total number of processes
+        config (dict): Configuration dictionary
 """
+    os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost')
+    os.environ['MASTER_PORT'] = config.get('master_port', '12355')
+
+    # Initialize the process group
+    dist.init_process_group(
+        backend='nccl',
+        init_method='env://',
+        world_size=world_size,
+        rank=rank
+    )
+
+    # Set the device for this process
+    torch.cuda.set_device(rank)
+
+
+def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
+    """Run distributed training across multiple GPUs with fallback to single GPU."""
+    world_size = torch.cuda.device_count()
+
+    if world_size <= 1:
+        print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
+        config["distributed_training"] = False
+        trainer = trainer_class(config)
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        trainer.device = device
+        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
+            train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
+        )
+        val_loss, model = trainer.train(
+            train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
+        )
+        model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
+        save_model(model, model_save_directory)
+        save_hyperparameters(model_save_directory, {
+            **get_config_value(config, "manual_hyperparameters", {}),
+            "dropout_rate": config["dropout_rate"],
+            "use_task_weights": config["use_task_weights"],
+            "task_weights": config["task_weights"],
+            "max_layers_to_freeze": config["max_layers_to_freeze"],
+            "use_attention_pooling": config["use_attention_pooling"],
+        })
+
+        if shared_dict is not None:
+            shared_dict['val_loss'] = val_loss
+            task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config)
+            shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
+            shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()}
+
+        return val_loss, model
+
+    print(f"Using distributed training with {world_size} GPUs")
+    mp.spawn(
+        _distributed_worker,
+        args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict),
+        nprocs=world_size,
+        join=True
+    )
+
+    if trial_number is None and shared_dict is None:
+        model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
+        model_path = os.path.join(model_save_directory, "pytorch_model.bin")
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        model = create_model(config, num_labels_list, device)
+        model.load_state_dict(torch.load(model_path))
+        return 0.0, model
+
+    return None
+
+
+def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
+    """Worker function for distributed training."""
+    setup_distributed_environment(rank, world_size, config)
+    config["local_rank"] = rank
+
+    # Set up distributed samplers
+    from torch.utils.data import DistributedSampler
+    from .data import get_data_loader
+
+    train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
+    val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
+
+    train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False)
+    val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False)
+
+    if rank == 0:
+        print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples")
+        print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}")
+
+    # Create and setup trainer
+    trainer = trainer_class(config)
+    train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
+        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
+    )
+
+    # Train the model
+    val_loss, model = trainer.train(
+        train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
+    )
+
+    # Save model only from the main process
+    if rank == 0:
+        model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
+        save_model(model, model_save_directory)
+
+        save_hyperparameters(model_save_directory, {
+            **get_config_value(config, "manual_hyperparameters", {}),
+            "dropout_rate": config["dropout_rate"],
+            "use_task_weights": config["use_task_weights"],
+            "task_weights": config["task_weights"],
+            "max_layers_to_freeze": config["max_layers_to_freeze"],
+            "use_attention_pooling": config["use_attention_pooling"],
+        })
+
+        # For Optuna trials, store results in shared dictionary
+        if shared_dict is not None:
+            shared_dict['val_loss'] = val_loss
+
+            # Run validation on full dataset from rank 0 for consistent metrics
+            full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False)
+
+            # Get validation predictions using our utility function
+            task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
+                model, full_val_loader, trainer.device, config
+            )
+
+            # Calculate metrics
+            task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
+            shared_dict['task_metrics'] = task_metrics
+
+            # Store model state dict
+            if isinstance(model, DDP):
+                model_state_dict = model.module.state_dict()
+            else:
+                model_state_dict = model.state_dict()
+
+            shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()}
+
+    # Clean up distributed environment
+    dist.destroy_process_group()
+
+
+def save_model_without_heads(model_directory):
+    """
+    Save a version of the fine-tuned model without classification heads.
+
+    Args:
+        model_directory (str): Path to the directory containing the fine-tuned model
+    """
+    import torch
+    from transformers import BertConfig, BertModel
+
+    # Load the full model
+    model_path = os.path.join(model_directory, "pytorch_model.bin")
+    config_path = os.path.join(model_directory, "config.json")
+
+    if not os.path.exists(model_path) or not os.path.exists(config_path):
+        raise FileNotFoundError(f"Model files not found in {model_directory}")
+
+    # Load the configuration
+    config = BertConfig.from_json_file(config_path)
+
+    # Load the model state dict
+    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
+
+    # Create a new model without heads
+    base_model = BertModel(config)
+
+    # Filter out the classification head parameters
+    base_model_state_dict = {}
+    for key, value in state_dict.items():
+        # Only keep parameters that belong to the base model (not classification heads)
+        if not key.startswith('classification_heads') and not key.startswith('attention_pool'):
+            base_model_state_dict[key] = value
+
+    # Load the filtered state dict into the base model
+    base_model.load_state_dict(base_model_state_dict, strict=False)
+
+    # Save the model without heads
+    output_dir = os.path.join(model_directory, "model_without_heads")
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Save the model weights
+    torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
+
+    # Save the configuration
+    base_model.config.to_json_file(os.path.join(output_dir, "config.json"))
+
+    print(f"Model without classification heads saved to {output_dir}")
+    return output_dir
+
+
+def get_config_value(config: Dict, key: str, default=None):
+    return config.get(key, default)
+
+
+def collect_validation_predictions(model, val_loader, device, config) -> tuple:
+    task_true_labels = {}
+    task_pred_labels = {}
+    task_pred_probs = {}
+
+    with torch.no_grad():
+        for batch in val_loader:
+            input_ids = batch["input_ids"].to(device)
+            attention_mask = batch["attention_mask"].to(device)
+            labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]]
+            _, logits, _ = model(input_ids, attention_mask, labels)
+
+            for sample_idx in range(len(batch["input_ids"])):
+                for i, task_name in enumerate(config["task_names"]):
+                    if task_name not in task_true_labels:
+                        task_true_labels[task_name] = []
+                        task_pred_labels[task_name] = []
+                        task_pred_probs[task_name] = []
+
+                    true_label = batch["labels"][task_name][sample_idx].item()
+                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
+                    pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
+
+                    task_true_labels[task_name].append(true_label)
+                    task_pred_labels[task_name].append(pred_label)
+                    task_pred_probs[task_name].append(pred_prob)
+
+    return task_true_labels, task_pred_labels, task_pred_probs
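The fallback and worker functions above follow the standard torch.distributed NCCL pattern: one process per GPU, env:// rendezvous, DDP wrapping, and saving from rank 0 only. A stripped-down, self-contained sketch of that pattern with a toy model, assuming the same localhost/12355 defaults, not the Geneformer trainer itself:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def _worker(rank, world_size):
    # Mirrors setup_distributed_environment(): one process per GPU, NCCL backend.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", init_method="env://", world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(16, 4).to(rank)      # stand-in for GeneformerMultiTask
    ddp_model = DDP(model, device_ids=[rank], output_device=rank)

    # ... training loop goes here; each rank sees its DistributedSampler shard ...

    if rank == 0:
        # Save the unwrapped module, as save_model() does for DDP models.
        torch.save(ddp_model.module.state_dict(), "pytorch_model.bin")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)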
geneformer/mtl_classifier.py
CHANGED
@@ -29,7 +29,8 @@ Geneformer multi-task cell classifier.
 import logging
 import os
 
-from .mtl import eval_utils, […]
+from .mtl import eval_utils, utils
+from .mtl.train import run_manual_tuning, run_optuna_study
 
 logger = logging.getLogger(__name__)
 
@@ -49,7 +50,9 @@ class MTLClassifier:
         "max_layers_to_freeze": {None, dict},
         "epochs": {None, int},
         "tensorboard_log_dir": {None, str},
-        "[…]
+        "distributed_training": {None, bool},
+        "master_addr": {None, str},
+        "master_port": {None, str},
         "use_attention_pooling": {None, bool},
         "use_task_weights": {None, bool},
         "hyperparameters": {None, dict},
@@ -61,6 +64,7 @@ class MTLClassifier:
         "max_grad_norm": {None, int, float},
         "seed": {None, int},
         "trials_result_path": {None, str},
+        "gradient_accumulation_steps": {None, int},
     }
 
     def __init__(
@@ -79,7 +83,9 @@ class MTLClassifier:
         max_layers_to_freeze=None,
         epochs=1,
         tensorboard_log_dir="/results/tblogdir",
-        […]
+        distributed_training=False,
+        master_addr="localhost",
+        master_port="12355",
         use_attention_pooling=True,
         use_task_weights=True,
         hyperparameters=None,  # Default is None
@@ -89,6 +95,7 @@ class MTLClassifier:
         wandb_project=None,
        gradient_clipping=False,
         max_grad_norm=None,
+        gradient_accumulation_steps=1,  # Add this line with default value 1
         seed=42,  # Default seed value
     ):
         """
@@ -117,8 +124,12 @@ class MTLClassifier:
            | Path to directory to save results
         tensorboard_log_dir : None, str
             | Path to directory for Tensorboard logging results
-        […]
-            | Whether to use data […]
+        distributed_training : None, bool
+            | Whether to use distributed data parallel training across multiple GPUs
+        master_addr : None, str
+            | Master address for distributed training (default: localhost)
+        master_port : None, str
+            | Master port for distributed training (default: 12355)
         use_attention_pooling : None, bool
             | Whether to use attention pooling
         use_task_weights : None, bool
@@ -150,6 +161,8 @@ class MTLClassifier:
             | Whether to use gradient clipping
         max_grad_norm : None, int, float
             | Maximum norm for gradient clipping
+        gradient_accumulation_steps : None, int
+            | Number of steps to accumulate gradients before performing a backward/update pass
         seed : None, int
             | Random seed
         """
@@ -165,6 +178,7 @@ class MTLClassifier:
         self.batch_size = batch_size
         self.n_trials = n_trials
         self.study_name = study_name
+        self.gradient_accumulation_steps = gradient_accumulation_steps
 
         if max_layers_to_freeze is None:
             # Dynamically determine the range of layers to freeze
@@ -175,7 +189,9 @@ class MTLClassifier:
 
         self.epochs = epochs
         self.tensorboard_log_dir = tensorboard_log_dir
-        self.[…]
+        self.distributed_training = distributed_training
+        self.master_addr = master_addr
+        self.master_port = master_port
         self.use_attention_pooling = use_attention_pooling
         self.use_task_weights = use_task_weights
         self.hyperparameters = (
@@ -293,7 +309,7 @@ class MTLClassifier:
         self.config["manual_hyperparameters"] = self.manual_hyperparameters
         self.config["use_manual_hyperparameters"] = True
 
-        […]
+        run_manual_tuning(self.config)
 
     def validate_additional_options(self, req_var_dict):
         missing_variable = False
@@ -330,7 +346,7 @@ class MTLClassifier:
         req_var_dict = dict(zip(required_variable_names, required_variables))
         self.validate_additional_options(req_var_dict)
 
-        […]
+        run_optuna_study(self.config)
 
     def load_and_evaluate_test_model(
         self,
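The new gradient_accumulation_steps option corresponds to the usual accumulate-then-step training pattern; below is a generic, self-contained illustration of that pattern with a toy model and random data, not Geneformer code.

import torch
from torch.optim import AdamW

model = torch.nn.Linear(16, 4)
optimizer = AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
accumulation_steps = 2  # plays the role of gradient_accumulation_steps

optimizer.zero_grad()
for step in range(8):
    x = torch.randn(4, 16)
    y = torch.randint(0, 4, (4,))
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so summed grads match one large batch
    loss.backward()                                   # gradients accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                              # one update per `accumulation_steps` micro-batches
        optimizer.zero_grad()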