madhavanvvs committed
Commit a794aec · 1 Parent(s): c6d04a6

Refactor MTL: DDP NCCL support

geneformer/mtl/__init__.py CHANGED
@@ -1 +1,4 @@
1
- # ruff: noqa: F401
 
 
 
 
1
+ # ruff: noqa: F401
2
+
3
+ from . import eval_utils
4
+ from . import utils
geneformer/mtl/collators.py CHANGED
@@ -1,8 +1,8 @@
1
  # imports
2
  import torch
3
  import pickle
4
- from ..collator_for_classification import DataCollatorForGeneClassification
5
- from .. import TOKEN_DICTIONARY_FILE
6
 
7
  """Geneformer collator for multi-task cell classification."""
8
 
 
1
  # imports
2
  import torch
3
  import pickle
4
+ from geneformer.collator_for_classification import DataCollatorForGeneClassification
5
+ from geneformer import TOKEN_DICTIONARY_FILE
6
 
7
  """Geneformer collator for multi-task cell classification."""
8
 
geneformer/mtl/data.py CHANGED
@@ -1,126 +1,190 @@
1
  import os
2
- from .collators import DataCollatorForMultitaskCellClassification
3
- from .imports import *
4
-
5
- def validate_columns(dataset, required_columns, dataset_type):
6
- """Ensures required columns are present in the dataset."""
7
- missing_columns = [col for col in required_columns if col not in dataset.column_names]
8
- if missing_columns:
9
- raise KeyError(
10
- f"Missing columns in {dataset_type} dataset: {missing_columns}. "
11
- f"Available columns: {dataset.column_names}"
12
- )
13
-
14
-
15
- def create_label_mappings(dataset, task_to_column):
16
- """Creates label mappings for the dataset."""
17
- task_label_mappings = {}
18
- num_labels_list = []
19
- for task, column in task_to_column.items():
20
- unique_values = sorted(set(dataset[column]))
21
- mapping = {label: idx for idx, label in enumerate(unique_values)}
22
- task_label_mappings[task] = mapping
23
- num_labels_list.append(len(unique_values))
24
- return task_label_mappings, num_labels_list
25
-
26
-
27
- def save_label_mappings(mappings, path):
28
- """Saves label mappings to a pickle file."""
29
- with open(path, "wb") as f:
30
- pickle.dump(mappings, f)
31
-
32
-
33
- def load_label_mappings(path):
34
- """Loads label mappings from a pickle file."""
35
- with open(path, "rb") as f:
36
- return pickle.load(f)
37
-
38
-
39
- def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
40
- """Transforms the dataset to the required format."""
41
- transformed_dataset = []
42
- cell_id_mapping = {}
43
-
44
- for idx, record in enumerate(dataset):
45
- transformed_record = {
46
- "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
47
- "cell_id": idx, # Index-based cell ID
48
- }
49
-
50
- if not is_test:
51
- label_dict = {
52
- task: task_label_mappings[task][record[column]]
53
- for task, column in task_to_column.items()
54
- }
55
- else:
56
- label_dict = {task: -1 for task in config["task_names"]}
57
-
58
- transformed_record["label"] = label_dict
59
- transformed_dataset.append(transformed_record)
60
- cell_id_mapping[idx] = record.get("unique_cell_id", idx)
61
-
62
- return transformed_dataset, cell_id_mapping
63
 
 
64
 
65
- def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
66
- """Main function to load and preprocess data."""
67
- try:
68
- dataset = load_from_disk(dataset_path)
69
 
 
 
 
 
 
 
 
 
 
 
70
  # Setup task and column mappings
71
- task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
72
- task_to_column = dict(zip(task_names, config["task_columns"]))
73
- config["task_names"] = task_names
74
-
75
- label_mappings_path = os.path.join(
 
 
 
 
 
76
  config["results_dir"],
77
  f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
78
  )
79
-
80
  if not is_test:
81
- validate_columns(dataset, task_to_column.values(), dataset_type)
82
-
83
- # Create and save label mappings
84
- task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
85
- save_label_mappings(task_label_mappings, label_mappings_path)
86
  else:
87
  # Load existing mappings for test data
88
- task_label_mappings = load_label_mappings(label_mappings_path)
89
- num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
90
-
91
- # Transform dataset
92
- transformed_dataset, cell_id_mapping = transform_dataset(
93
- dataset, task_to_column, task_label_mappings, config, is_test
94
- )
95
-
96
- return transformed_dataset, cell_id_mapping, num_labels_list
97
-
98
- except KeyError as e:
99
- raise ValueError(f"Configuration error or dataset key missing: {e}")
100
- except Exception as e:
101
- raise RuntimeError(f"Error during data loading or preprocessing: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
 
104
- def preload_and_process_data(config):
105
- """Preloads and preprocesses train and validation datasets."""
106
- # Process train data and save mappings
107
- train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
 
 
 
 
 
 
 
108
 
109
- # Process validation data and save mappings
110
- val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
111
 
112
- # Validate that the mappings match
113
  validate_label_mappings(config)
114
-
115
- return (*train_data[:2], *val_data) # Return train and val data along with mappings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def validate_label_mappings(config):
119
  """Ensures train and validation label mappings are consistent."""
120
  train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
121
  val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
122
- train_mappings = load_label_mappings(train_mappings_path)
123
- val_mappings = load_label_mappings(val_mappings_path)
 
 
 
 
124
 
125
  for task_name in config["task_names"]:
126
  if train_mappings[task_name] != val_mappings[task_name]:
@@ -131,32 +195,43 @@ def validate_label_mappings(config):
131
  )
132
 
133
 
134
- def get_data_loader(preprocessed_dataset, batch_size):
135
- """Creates a DataLoader with optimal settings."""
136
- return DataLoader(
137
- preprocessed_dataset,
138
- batch_size=batch_size,
139
- shuffle=True,
140
- collate_fn=DataCollatorForMultitaskCellClassification(),
141
- num_workers=os.cpu_count(),
142
- pin_memory=True,
 
 
143
  )
144
 
145
 
146
  def preload_data(config):
147
  """Preprocesses train and validation data for trials."""
148
- train_loader = get_data_loader(*preload_and_process_data(config)[:2], config["batch_size"])
149
- val_loader = get_data_loader(*preload_and_process_data(config)[2:4], config["batch_size"])
150
- return train_loader, val_loader
151
 
152
 
153
  def load_and_preprocess_test_data(config):
154
  """Loads and preprocesses test data."""
155
- return load_and_preprocess_data(config["test_path"], config, is_test=True)
 
 
 
 
 
 
 
 
 
 
 
156
 
157
 
158
  def prepare_test_loader(config):
159
  """Prepares DataLoader for test data."""
160
- test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
161
- test_loader = get_data_loader(test_dataset, config["batch_size"])
162
- return test_loader, cell_id_mapping, num_labels_list
 
1
  import os
2
+ import pickle
3
+ import torch
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from datasets import load_from_disk
6
 
7
+ from .collators import DataCollatorForMultitaskCellClassification
8
 
 
 
 
 
9
 
10
+ class StreamingMultiTaskDataset(Dataset):
11
+
12
+ def __init__(self, dataset_path, config, is_test=False, dataset_type=""):
13
+ """Initialize the streaming dataset."""
14
+ self.dataset = load_from_disk(dataset_path)
15
+ self.config = config
16
+ self.is_test = is_test
17
+ self.dataset_type = dataset_type
18
+ self.cell_id_mapping = {}
19
+
20
  # Setup task and column mappings
21
+ self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
22
+ self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
23
+ config["task_names"] = self.task_names
24
+
25
+ # Check if unique_cell_id column exists in the dataset
26
+ self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
27
+ print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")
28
+
29
+ # Setup label mappings
30
+ self.label_mappings_path = os.path.join(
31
  config["results_dir"],
32
  f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
33
  )
34
+
35
  if not is_test:
36
+ self._validate_columns()
37
+ self.task_label_mappings, self.num_labels_list = self._create_label_mappings()
38
+ self._save_label_mappings()
 
 
39
  else:
40
  # Load existing mappings for test data
41
+ self.task_label_mappings = self._load_label_mappings()
42
+ self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()]
43
+
44
+ def _validate_columns(self):
45
+ """Ensures required columns are present in the dataset."""
46
+ missing_columns = [col for col in self.task_to_column.values()
47
+ if col not in self.dataset.column_names]
48
+ if missing_columns:
49
+ raise KeyError(
50
+ f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
51
+ f"Available columns: {self.dataset.column_names}"
52
+ )
53
+
54
+ def _create_label_mappings(self):
55
+ """Creates label mappings for the dataset."""
56
+ task_label_mappings = {}
57
+ num_labels_list = []
58
+
59
+ for task, column in self.task_to_column.items():
60
+ unique_values = sorted(set(self.dataset[column]))
61
+ mapping = {label: idx for idx, label in enumerate(unique_values)}
62
+ task_label_mappings[task] = mapping
63
+ num_labels_list.append(len(unique_values))
64
+
65
+ return task_label_mappings, num_labels_list
66
+
67
+ def _save_label_mappings(self):
68
+ """Saves label mappings to a pickle file."""
69
+ with open(self.label_mappings_path, "wb") as f:
70
+ pickle.dump(self.task_label_mappings, f)
71
+
72
+ def _load_label_mappings(self):
73
+ """Loads label mappings from a pickle file."""
74
+ with open(self.label_mappings_path, "rb") as f:
75
+ return pickle.load(f)
76
+
77
+ def __len__(self):
78
+ return len(self.dataset)
79
+
80
+ def __getitem__(self, idx):
81
+ record = self.dataset[idx]
82
+
83
+ # Store cell ID mapping
84
+ if self.has_unique_cell_ids:
85
+ unique_cell_id = record["unique_cell_id"]
86
+ self.cell_id_mapping[idx] = unique_cell_id
87
+ else:
88
+ self.cell_id_mapping[idx] = f"cell_{idx}"
89
+
90
+ # Create transformed record
91
+ transformed_record = {
92
+ "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
93
+ "cell_id": idx,
94
+ }
95
+
96
+ # Add labels
97
+ if not self.is_test:
98
+ label_dict = {
99
+ task: self.task_label_mappings[task][record[column]]
100
+ for task, column in self.task_to_column.items()
101
+ }
102
+ else:
103
+ label_dict = {task: -1 for task in self.config["task_names"]}
104
+
105
+ transformed_record["label"] = label_dict
106
+
107
+ return transformed_record
108
 
109
 
110
+ def get_data_loader(dataset, batch_size, sampler=None, shuffle=True):
111
+ """Create a DataLoader with the given dataset and parameters."""
112
+ return DataLoader(
113
+ dataset,
114
+ batch_size=batch_size,
115
+ sampler=sampler,
116
+ shuffle=shuffle if sampler is None else False,
117
+ num_workers=0,
118
+ pin_memory=True,
119
+ collate_fn=DataCollatorForMultitaskCellClassification(),
120
+ )
121
 
 
 
122
 
123
+ def prepare_data_loaders(config, include_test=False):
124
+ """Prepare data loaders for training, validation, and optionally test."""
125
+ result = {}
126
+
127
+ # Process train data
128
+ train_dataset = StreamingMultiTaskDataset(
129
+ config["train_path"],
130
+ config,
131
+ dataset_type="train"
132
+ )
133
+ result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])
134
+
135
+ # Store the cell ID mapping from the dataset
136
+ result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()}
137
+ print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")
138
+
139
+ result["num_labels_list"] = train_dataset.num_labels_list
140
+
141
+ # Process validation data
142
+ val_dataset = StreamingMultiTaskDataset(
143
+ config["val_path"],
144
+ config,
145
+ dataset_type="validation"
146
+ )
147
+ result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])
148
+
149
+ # Store the complete cell ID mapping for validation
150
+ for idx in range(len(val_dataset)):
151
+ _ = val_dataset[idx]
152
+
153
+ result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()}
154
+ print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")
155
+
156
+ # Validate label mappings
157
  validate_label_mappings(config)
158
+
159
+ # Process test data if requested
160
+ if include_test and "test_path" in config:
161
+ test_dataset = StreamingMultiTaskDataset(
162
+ config["test_path"],
163
+ config,
164
+ is_test=True,
165
+ dataset_type="test"
166
+ )
167
+ result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])
168
+
169
+ for idx in range(len(test_dataset)):
170
+ _ = test_dataset[idx]
171
+
172
+ result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()}
173
+ print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")
174
+
175
+ return result
176
 
177
 
178
  def validate_label_mappings(config):
179
  """Ensures train and validation label mappings are consistent."""
180
  train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
181
  val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
182
+
183
+ with open(train_mappings_path, "rb") as f:
184
+ train_mappings = pickle.load(f)
185
+
186
+ with open(val_mappings_path, "rb") as f:
187
+ val_mappings = pickle.load(f)
188
 
189
  for task_name in config["task_names"]:
190
  if train_mappings[task_name] != val_mappings[task_name]:
 
195
  )
196
 
197
 
198
+ # Legacy functions for backward compatibility
199
+ def preload_and_process_data(config):
200
+ """Preloads and preprocesses train and validation datasets."""
201
+ data = prepare_data_loaders(config)
202
+
203
+ return (
204
+ data["train_loader"].dataset,
205
+ data["train_cell_mapping"],
206
+ data["val_loader"].dataset,
207
+ data["val_cell_mapping"],
208
+ data["num_labels_list"]
209
  )
210
 
211
 
212
  def preload_data(config):
213
  """Preprocesses train and validation data for trials."""
214
+ data = prepare_data_loaders(config)
215
+ return data["train_loader"], data["val_loader"]
 
216
 
217
 
218
  def load_and_preprocess_test_data(config):
219
  """Loads and preprocesses test data."""
220
+ test_dataset = StreamingMultiTaskDataset(
221
+ config["test_path"],
222
+ config,
223
+ is_test=True,
224
+ dataset_type="test"
225
+ )
226
+
227
+ return (
228
+ test_dataset,
229
+ test_dataset.cell_id_mapping,
230
+ test_dataset.num_labels_list
231
+ )
232
 
233
 
234
  def prepare_test_loader(config):
235
  """Prepares DataLoader for test data."""
236
+ data = prepare_data_loaders(config, include_test=True)
237
+ return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"]
 
geneformer/mtl/eval_utils.py CHANGED
@@ -1,19 +1,16 @@
 
 
 
1
  import pandas as pd
2
 
3
- from .imports import * # noqa # isort:skip
4
- from .data import prepare_test_loader # noqa # isort:skip
5
  from .model import GeneformerMultiTask
6
 
7
-
8
  def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
9
  task_pred_labels = {task_name: [] for task_name in config["task_names"]}
10
  task_pred_probs = {task_name: [] for task_name in config["task_names"]}
11
  cell_ids = []
12
 
13
- # # Load task label mappings from pickle file
14
- # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
15
- # task_label_mappings = pickle.load(f)
16
-
17
  model.eval()
18
  with torch.no_grad():
19
  for batch in test_loader:
@@ -85,4 +82,4 @@ def load_and_evaluate_test_model(config):
85
  best_model.to(device)
86
 
87
  evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
88
- print("Evaluation completed.")
 
1
+ import os
2
+ import json
3
+ import torch
4
  import pandas as pd
5
 
6
+ from .data import prepare_test_loader
 
7
  from .model import GeneformerMultiTask
8
 
 
9
  def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
10
  task_pred_labels = {task_name: [] for task_name in config["task_names"]}
11
  task_pred_probs = {task_name: [] for task_name in config["task_names"]}
12
  cell_ids = []
13
 
 
 
 
 
14
  model.eval()
15
  with torch.no_grad():
16
  for batch in test_loader:
 
82
  best_model.to(device)
83
 
84
  evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
85
+ print("Evaluation completed.")
geneformer/mtl/imports.py DELETED
@@ -1,43 +0,0 @@
1
- import functools
2
- import gc
3
- import json
4
- import os
5
- import pickle
6
- import sys
7
- import warnings
8
- from enum import Enum
9
- from itertools import chain
10
- from typing import Dict, List, Optional, Union
11
-
12
- import numpy as np
13
- import optuna
14
- import pandas as pd
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- import torch.optim as optim
19
- from datasets import load_from_disk
20
- from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
21
- from sklearn.model_selection import train_test_split
22
- from sklearn.preprocessing import LabelEncoder
23
- from torch.utils.data import DataLoader
24
- from transformers import (
25
- AdamW,
26
- BatchEncoding,
27
- BertConfig,
28
- BertModel,
29
- DataCollatorForTokenClassification,
30
- SpecialTokensMixin,
31
- get_cosine_schedule_with_warmup,
32
- get_linear_schedule_with_warmup,
33
- get_scheduler,
34
- )
35
- from transformers.utils import logging, to_py_obj
36
-
37
- from .collators import DataCollatorForMultitaskCellClassification
38
-
39
- # local modules
40
- from .data import get_data_loader, preload_and_process_data
41
- from .model import GeneformerMultiTask
42
- from .optuna_utils import create_optuna_study
43
- from .utils import save_model
geneformer/mtl/model.py CHANGED
@@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module):
118
  f"Error during loss computation for task {task_id}: {e}"
119
  )
120
 
121
- return total_loss, logits, losses if labels is not None else logits
 
118
  f"Error during loss computation for task {task_id}: {e}"
119
  )
120
 
121
+ return total_loss, logits, losses if labels is not None else logits
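
One note on the (re-added) return line above: the trailing conditional binds only to the last tuple element, so this statement always produces a 3-tuple; when `labels` is None the third element is `logits` rather than the per-task losses. A quick, purely illustrative check of how the expression parses:

```python
# Stand-in values; this only demonstrates Python's parsing of the return expression.
total_loss, logits, losses, labels = 1.0, "logits", "losses", None

result = total_loss, logits, losses if labels is not None else logits
assert result == (1.0, "logits", "logits")  # parsed as (total_loss, logits, <conditional>)
```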
geneformer/mtl/optuna_utils.py DELETED
@@ -1,27 +0,0 @@
1
- import optuna
2
- from optuna.integration import TensorBoardCallback
3
-
4
-
5
- def save_trial_callback(study, trial, trials_result_path):
6
- with open(trials_result_path, "a") as f:
7
- f.write(
8
- f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
9
- )
10
-
11
-
12
- def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
13
- study = optuna.create_study(direction="maximize")
14
-
15
- # init TensorBoard callback
16
- tensorboard_callback = TensorBoardCallback(
17
- dirname=tensorboard_log_dir, metric_name="F1 Macro"
18
- )
19
-
20
- # callback and TensorBoard callback
21
- callbacks = [
22
- lambda study, trial: save_trial_callback(study, trial, trials_result_path),
23
- tensorboard_callback,
24
- ]
25
-
26
- study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
27
- return study
geneformer/mtl/train.py CHANGED
@@ -1,380 +1,707 @@
1
  import os
2
- import random
3
-
4
- import numpy as np
5
  import pandas as pd
6
  import torch
 
 
 
7
  from torch.utils.tensorboard import SummaryWriter
8
  from tqdm import tqdm
 
 
 
9
 
10
- from .imports import *
11
  from .model import GeneformerMultiTask
12
- from .utils import calculate_task_specific_metrics, get_layer_freeze_range
13
-
14
-
15
- def set_seed(seed):
16
- random.seed(seed)
17
- np.random.seed(seed)
18
- torch.manual_seed(seed)
19
- torch.cuda.manual_seed_all(seed)
20
- torch.backends.cudnn.deterministic = True
21
- torch.backends.cudnn.benchmark = False
22
-
23
-
24
- def initialize_wandb(config):
25
- if config.get("use_wandb", False):
26
- import wandb
27
-
28
- wandb.init(project=config["wandb_project"], config=config)
29
- print("Weights & Biases (wandb) initialized and will be used for logging.")
30
- else:
31
- print(
32
- "Weights & Biases (wandb) is not enabled. Logging will use other methods."
33
- )
34
-
35
-
36
- def create_model(config, num_labels_list, device):
37
- model = GeneformerMultiTask(
38
- config["pretrained_path"],
39
- num_labels_list,
40
- dropout_rate=config["dropout_rate"],
41
- use_task_weights=config["use_task_weights"],
42
- task_weights=config["task_weights"],
43
- max_layers_to_freeze=config["max_layers_to_freeze"],
44
- use_attention_pooling=config["use_attention_pooling"],
45
- )
46
- if config["use_data_parallel"]:
47
- model = nn.DataParallel(model)
48
- return model.to(device)
49
-
50
-
51
- def setup_optimizer_and_scheduler(model, config, total_steps):
52
- optimizer = AdamW(
53
- model.parameters(),
54
- lr=config["learning_rate"],
55
- weight_decay=config["weight_decay"],
56
- )
57
- warmup_steps = int(config["warmup_ratio"] * total_steps)
58
-
59
- if config["lr_scheduler_type"] == "linear":
60
- scheduler = get_linear_schedule_with_warmup(
61
- optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
62
- )
63
- elif config["lr_scheduler_type"] == "cosine":
64
- scheduler = get_cosine_schedule_with_warmup(
65
- optimizer,
66
- num_warmup_steps=warmup_steps,
67
- num_training_steps=total_steps,
68
- num_cycles=0.5,
69
- )
70
-
71
- return optimizer, scheduler
72
-
73
-
74
- def train_epoch(
75
- model, train_loader, optimizer, scheduler, device, config, writer, epoch
76
- ):
77
- model.train()
78
- progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
79
- for batch_idx, batch in enumerate(progress_bar):
80
- optimizer.zero_grad()
81
- input_ids = batch["input_ids"].to(device)
82
- attention_mask = batch["attention_mask"].to(device)
83
- labels = [
84
- batch["labels"][task_name].to(device) for task_name in config["task_names"]
85
- ]
86
-
87
- loss, _, _ = model(input_ids, attention_mask, labels)
88
- loss.backward()
89
-
90
- if config["gradient_clipping"]:
91
- torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
92
-
93
- optimizer.step()
94
- scheduler.step()
95
-
96
- writer.add_scalar(
97
- "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
98
- )
99
- if config.get("use_wandb", False):
100
- import wandb
101
-
102
- wandb.log({"Training Loss": loss.item()})
103
-
104
- # Update progress bar
105
- progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
106
-
107
- return loss.item() # Return the last batch loss
108
-
109
-
110
- def validate_model(model, val_loader, device, config):
111
- model.eval()
112
- val_loss = 0.0
113
- task_true_labels = {task_name: [] for task_name in config["task_names"]}
114
- task_pred_labels = {task_name: [] for task_name in config["task_names"]}
115
- task_pred_probs = {task_name: [] for task_name in config["task_names"]}
116
-
117
- with torch.no_grad():
118
- for batch in val_loader:
119
- input_ids = batch["input_ids"].to(device)
120
- attention_mask = batch["attention_mask"].to(device)
121
  labels = [
122
- batch["labels"][task_name].to(device)
123
- for task_name in config["task_names"]
124
  ]
125
- loss, logits, _ = model(input_ids, attention_mask, labels)
126
- val_loss += loss.item()
127
-
128
- for sample_idx in range(len(batch["input_ids"])):
129
- for i, task_name in enumerate(config["task_names"]):
130
- true_label = batch["labels"][task_name][sample_idx].item()
131
- pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
132
- pred_prob = (
133
- torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
134
- )
135
- task_true_labels[task_name].append(true_label)
136
- task_pred_labels[task_name].append(pred_label)
137
- task_pred_probs[task_name].append(pred_prob)
138
-
139
- val_loss /= len(val_loader)
140
- return val_loss, task_true_labels, task_pred_labels, task_pred_probs
141
-
142
-
143
- def log_metrics(task_metrics, val_loss, config, writer, epochs):
144
- for task_name, metrics in task_metrics.items():
145
- print(
146
- f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
147
- )
148
- if config.get("use_wandb", False):
149
- import wandb
150
 
151
- wandb.log(
152
- {
153
- f"{task_name} Validation F1 Macro": metrics["f1"],
154
- f"{task_name} Validation Accuracy": metrics["accuracy"],
155
- }
156
  )
157
-
158
- writer.add_scalar("Validation Loss", val_loss, epochs)
159
- for task_name, metrics in task_metrics.items():
160
- writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
161
- writer.add_scalar(
162
- f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
 
 
 
 
 
 
 
 
163
  )
164
 
165
-
166
- def save_validation_predictions(
167
- val_cell_id_mapping,
168
- task_true_labels,
169
- task_pred_labels,
170
- task_pred_probs,
171
- config,
172
- trial_number=None,
173
- ):
174
- if trial_number is not None:
175
- trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
176
- os.makedirs(trial_results_dir, exist_ok=True)
177
- val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
178
- else:
179
- val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
180
-
181
- rows = []
182
- for sample_idx in range(len(val_cell_id_mapping)):
183
- row = {"Cell ID": val_cell_id_mapping[sample_idx]}
184
- for task_name in config["task_names"]:
185
- row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
186
- row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
187
- row[f"{task_name} Probabilities"] = ",".join(
188
- map(str, task_pred_probs[task_name][sample_idx])
189
  )
190
- rows.append(row)
191
-
192
- df = pd.DataFrame(rows)
193
- df.to_csv(val_preds_file, index=False)
194
- print(f"Validation predictions saved to {val_preds_file}")
 
 
 
 
 
195
 
196
 
197
- def train_model(
198
- config,
199
- device,
200
  train_loader,
201
  val_loader,
202
  train_cell_id_mapping,
203
  val_cell_id_mapping,
204
  num_labels_list,
 
 
205
  ):
 
206
  set_seed(config["seed"])
207
  initialize_wandb(config)
208
 
209
- model = create_model(config, num_labels_list, device)
210
- total_steps = len(train_loader) * config["epochs"]
211
- optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
212
-
213
- log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
214
- writer = SummaryWriter(log_dir=log_dir)
215
-
216
- epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
217
- for epoch in epoch_progress:
218
- last_loss = train_epoch(
219
- model, train_loader, optimizer, scheduler, device, config, writer, epoch
220
  )
221
- epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
222
 
223
- val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
224
- model, val_loader, device, config
225
- )
226
- task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
227
 
228
- log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
229
- writer.close()
 
 
230
 
231
- save_validation_predictions(
232
- val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- if config.get("use_wandb", False):
236
- import wandb
237
 
238
- wandb.finish()
239
 
240
- print(f"\nFinal Validation Loss: {val_loss:.4f}")
241
- return val_loss, model # Return both the validation loss and the trained model
 
 
 
 
 
 
 
 
242
 
 
 
 
 
243
 
244
- def objective(
245
- trial,
246
- train_loader,
247
- val_loader,
248
- train_cell_id_mapping,
249
- val_cell_id_mapping,
250
- num_labels_list,
251
- config,
252
- device,
253
- ):
254
- set_seed(config["seed"]) # Set the seed before each trial
255
- initialize_wandb(config)
256
 
257
- # Hyperparameters
258
- config["learning_rate"] = trial.suggest_float(
259
- "learning_rate",
260
- config["hyperparameters"]["learning_rate"]["low"],
261
- config["hyperparameters"]["learning_rate"]["high"],
262
- log=config["hyperparameters"]["learning_rate"]["log"],
263
- )
264
- config["warmup_ratio"] = trial.suggest_float(
265
- "warmup_ratio",
266
- config["hyperparameters"]["warmup_ratio"]["low"],
267
- config["hyperparameters"]["warmup_ratio"]["high"],
268
- )
269
- config["weight_decay"] = trial.suggest_float(
270
- "weight_decay",
271
- config["hyperparameters"]["weight_decay"]["low"],
272
- config["hyperparameters"]["weight_decay"]["high"],
273
- )
274
- config["dropout_rate"] = trial.suggest_float(
275
- "dropout_rate",
276
- config["hyperparameters"]["dropout_rate"]["low"],
277
- config["hyperparameters"]["dropout_rate"]["high"],
278
- )
279
- config["lr_scheduler_type"] = trial.suggest_categorical(
280
- "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
281
- )
282
- config["use_attention_pooling"] = trial.suggest_categorical(
283
- "use_attention_pooling", [False]
284
  )
285
 
286
- if config["use_task_weights"]:
287
- config["task_weights"] = [
288
- trial.suggest_float(
289
- f"task_weight_{i}",
290
- config["hyperparameters"]["task_weights"]["low"],
291
- config["hyperparameters"]["task_weights"]["high"],
292
- )
293
- for i in range(len(num_labels_list))
294
- ]
295
- weight_sum = sum(config["task_weights"])
296
- config["task_weights"] = [
297
- weight / weight_sum for weight in config["task_weights"]
298
- ]
299
- else:
300
- config["task_weights"] = None
301
-
302
- # Dynamic range for max_layers_to_freeze
303
- freeze_range = get_layer_freeze_range(config["pretrained_path"])
304
- config["max_layers_to_freeze"] = trial.suggest_int(
305
- "max_layers_to_freeze",
306
- freeze_range["min"],
307
- freeze_range["max"]
308
- )
309
 
310
- model = create_model(config, num_labels_list, device)
311
- total_steps = len(train_loader) * config["epochs"]
312
- optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
315
- writer = SummaryWriter(log_dir=log_dir)
316
 
317
- for epoch in range(config["epochs"]):
318
- train_epoch(
319
- model, train_loader, optimizer, scheduler, device, config, writer, epoch
320
- )
321
 
322
- val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
323
- model, val_loader, device, config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  )
325
- task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
326
 
327
- log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
328
- writer.close()
 
 
 
 
 
329
 
330
- save_validation_predictions(
331
- val_cell_id_mapping,
332
- task_true_labels,
333
- task_pred_labels,
334
- task_pred_probs,
335
- config,
336
- trial.number,
 
 
 
 
 
 
 
 
 
337
  )
338
 
339
- trial.set_user_attr("model_state_dict", model.state_dict())
340
- trial.set_user_attr("task_weights", config["task_weights"])
341
 
342
- trial.report(val_loss, config["epochs"])
 
 
343
 
344
- if trial.should_prune():
345
- raise optuna.TrialPruned()
346
 
347
- if config.get("use_wandb", False):
348
- import wandb
 
 
349
 
350
- wandb.log(
351
- {
352
- "trial_number": trial.number,
353
- "val_loss": val_loss,
354
- **{
355
- f"{task_name}_f1": metrics["f1"]
356
- for task_name, metrics in task_metrics.items()
357
- },
358
- **{
359
- f"{task_name}_accuracy": metrics["accuracy"]
360
- for task_name, metrics in task_metrics.items()
361
- },
362
- **{
363
- k: v
364
- for k, v in config.items()
365
- if k
366
- in [
367
- "learning_rate",
368
- "warmup_ratio",
369
- "weight_decay",
370
- "dropout_rate",
371
- "lr_scheduler_type",
372
- "use_attention_pooling",
373
- "max_layers_to_freeze",
374
- ]
375
- },
376
- }
377
- )
378
- wandb.finish()
379
 
380
- return val_loss
 
1
  import os
 
 
 
2
  import pandas as pd
3
  import torch
4
+ import torch.distributed as dist
5
+ import torch.multiprocessing as mp
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
  from torch.utils.tensorboard import SummaryWriter
8
  from tqdm import tqdm
9
+ import optuna
10
+ import functools
11
+ import time
12
 
 
13
  from .model import GeneformerMultiTask
14
+ from .utils import (
15
+ calculate_metrics,
16
+ get_layer_freeze_range,
17
+ set_seed,
18
+ initialize_wandb,
19
+ create_model,
20
+ setup_optimizer_and_scheduler,
21
+ save_model,
22
+ save_hyperparameters,
23
+ prepare_training_environment,
24
+ log_training_step,
25
+ log_validation_metrics,
26
+ save_validation_predictions,
27
+ setup_logging,
28
+ setup_distributed_environment,
29
+ train_distributed
30
+ )
31
+
32
+
33
+ class Trainer:
34
+ """Trainer class for multi-task learning"""
35
+
36
+ def __init__(self, config):
37
+ self.config = config
38
+ self.device = None
39
+ self.model = None
40
+ self.optimizer = None
41
+ self.scheduler = None
42
+ self.writer = None
43
+ self.is_distributed = config.get("distributed_training", False)
44
+ self.local_rank = config.get("local_rank", 0)
45
+ self.is_main_process = not self.is_distributed or self.local_rank == 0
46
+
47
+ def train_epoch(self, train_loader, epoch):
48
+ """Train the model for one epoch."""
49
+ epoch_start = time.time()
50
+ self.model.train()
51
+
52
+ # For distributed training, we need to be aware of the global batch count
53
+ if self.is_distributed:
54
+ # Get world size for reporting
55
+ world_size = dist.get_world_size()
56
+ # Calculate total batches across all GPUs
57
+ total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader)
58
+ else:
59
+ world_size = 1
60
+ total_batches_global = len(train_loader)
61
+
62
+ progress_bar = None
63
+ if self.is_main_process:
64
+ # Use the global batch count for progress reporting in distributed mode
65
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}",
66
+ total=len(train_loader))
67
+ iterator = progress_bar
68
+
69
+ # Report distributed training information
70
+ if self.is_distributed:
71
+ print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, "
72
+ f"{total_batches_global} total batches globally")
73
+ else:
74
+ iterator = train_loader
75
+
76
+ batch_times = []
77
+ forward_times = []
78
+ backward_times = []
79
+ optimizer_times = []
80
+
81
+ # Get gradient accumulation steps from config (default to 1 if not specified)
82
+ accumulation_steps = self.config.get("gradient_accumulation_steps", 1)
83
+
84
+ # Zero gradients at the beginning
85
+ self.optimizer.zero_grad()
86
+
87
+ # Track loss for the entire epoch
88
+ total_loss = 0.0
89
+ num_batches = 0
90
+ accumulated_loss = 0.0
91
+
92
+ for batch_idx, batch in enumerate(iterator):
93
+ batch_start = time.time()
94
+
95
+ input_ids = batch["input_ids"].to(self.device)
96
+ attention_mask = batch["attention_mask"].to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  labels = [
98
+ batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"]
 
99
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ forward_start = time.time()
102
+ loss, _, _ = self.model(input_ids, attention_mask, labels)
103
+
104
+ # Scale loss by accumulation steps for gradient accumulation
105
+ if accumulation_steps > 1:
106
+ loss = loss / accumulation_steps
107
+
108
+ forward_end = time.time()
109
+ forward_times.append(forward_end - forward_start)
110
+
111
+ # Track loss - store the unscaled loss for reporting
112
+ unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps)
113
+ total_loss += unscaled_loss
114
+ num_batches += 1
115
+ accumulated_loss += loss.item() # For gradient accumulation tracking
116
+
117
+ backward_start = time.time()
118
+
119
+ # Use no_sync() for all but the last accumulation step to avoid unnecessary communication
120
+ if self.is_distributed and accumulation_steps > 1:
121
+ # If this is not the last accumulation step or the last batch
122
+ if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader):
123
+ with self.model.no_sync():
124
+ loss.backward()
125
+ else:
126
+ loss.backward()
127
+ else:
128
+ # Non-distributed training or accumulation_steps=1
129
+ loss.backward()
130
+
131
+ backward_end = time.time()
132
+ backward_times.append(backward_end - backward_start)
133
+
134
+ # Only update weights after accumulation_steps or at the end of the epoch
135
+ if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
136
+ if self.config["gradient_clipping"]:
137
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
138
+
139
+ optimizer_start = time.time()
140
+ self.optimizer.step()
141
+ self.scheduler.step()
142
+ self.optimizer.zero_grad()
143
+ optimizer_end = time.time()
144
+ optimizer_times.append(optimizer_end - optimizer_start)
145
+
146
+ # Log after optimizer step
147
+ if self.is_main_process:
148
+ # Calculate running average loss
149
+ avg_loss = total_loss / num_batches
150
+
151
+ log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx)
152
+
153
+ # Update progress bar with just the running average loss
154
+ progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
155
+
156
+ accumulated_loss = 0.0
157
+ else:
158
+ optimizer_times.append(0) # No optimizer step taken
159
+
160
+ batch_end = time.time()
161
+ batch_times.append(batch_end - batch_start)
162
+
163
+ epoch_end = time.time()
164
+
165
+ # Calculate average loss for the epoch
166
+ epoch_avg_loss = total_loss / num_batches
167
+
168
+ # If distributed, gather losses from all processes to compute global average
169
+ if self.is_distributed:
170
+ # Create a tensor to hold the loss
171
+ loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
172
+ # Gather losses from all processes
173
+ all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
174
+ dist.all_gather(all_losses, loss_tensor)
175
+ # Compute the global average loss across all processes
176
+ epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()
177
+
178
+ if self.is_main_process:
179
+ # Double-check whether batch_size has already been adjusted for world_size in the config
180
+ # This avoids double-counting the effective batch size
181
+ per_gpu_batch_size = self.config['batch_size']
182
+ total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size
183
+
184
+ print(f"Epoch {epoch+1} timing:")
185
+ print(f" Total epoch time: {epoch_end - epoch_start:.2f}s")
186
+ print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
187
+ print(f" Average forward time: {sum(forward_times)/len(forward_times):.4f}s")
188
+ print(f" Average backward time: {sum(backward_times)/len(backward_times):.4f}s")
189
+ print(f" Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s")
190
+ print(f" Gradient accumulation steps: {accumulation_steps}")
191
+ print(f" Batch size per GPU: {per_gpu_batch_size}")
192
+ print(f" Effective global batch size: {total_effective_batch}")
193
+ print(f" Average training loss: {epoch_avg_loss:.4f}")
194
+ if self.is_distributed:
195
+ print(f" Total batches processed across all GPUs: {total_batches_global}")
196
+ print(f" Communication optimization: Using no_sync() for gradient accumulation")
197
+
198
+ return epoch_avg_loss # Return the average loss for the epoch
199
+
200
+ def validate_model(self, val_loader):
201
+ val_start = time.time()
202
+ self.model.eval()
203
+ val_loss = 0.0
204
+ task_true_labels = {task_name: [] for task_name in self.config["task_names"]}
205
+ task_pred_labels = {task_name: [] for task_name in self.config["task_names"]}
206
+ task_pred_probs = {task_name: [] for task_name in self.config["task_names"]}
207
+
208
+ val_cell_ids = {}
209
+ sample_counter = 0
210
+
211
+ batch_times = []
212
+
213
+ # Print validation dataset size
214
+ if self.is_main_process:
215
+ print(f"Validation dataset size: {len(val_loader.dataset)} samples")
216
+ print(f"Number of validation batches: {len(val_loader)}")
217
+
218
+ if self.is_distributed:
219
+ world_size = dist.get_world_size()
220
+ print(f"Distributed validation: {world_size} GPUs")
221
+ if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
222
+ samples_per_gpu = val_loader.sampler.num_samples
223
+ print(f"Each GPU processes {samples_per_gpu} validation samples")
224
+ print(f"Total validation samples processed: {samples_per_gpu * world_size}")
225
+
226
+ with torch.no_grad():
227
+ for batch in val_loader:
228
+ batch_start = time.time()
229
+ input_ids = batch["input_ids"].to(self.device)
230
+ attention_mask = batch["attention_mask"].to(self.device)
231
+ labels = [
232
+ batch["labels"][task_name].to(self.device)
233
+ for task_name in self.config["task_names"]
234
+ ]
235
+ loss, logits, _ = self.model(input_ids, attention_mask, labels)
236
+ val_loss += loss.item()
237
+
238
+ if "cell_id" in batch:
239
+ for i, cell_id in enumerate(batch["cell_id"]):
240
+ # Store the actual index for later mapping to unique_cell_id
241
+ val_cell_ids[sample_counter + i] = cell_id.item()
242
+
243
+ for sample_idx in range(len(batch["input_ids"])):
244
+ for i, task_name in enumerate(self.config["task_names"]):
245
+ true_label = batch["labels"][task_name][sample_idx].item()
246
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
247
+ # Store the full probability distribution
248
+ pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
249
+ task_true_labels[task_name].append(true_label)
250
+ task_pred_labels[task_name].append(pred_label)
251
+ task_pred_probs[task_name].append(pred_prob)
252
+
253
+ # Update current index for cell ID tracking
254
+ sample_counter += len(batch["input_ids"])
255
+
256
+ batch_end = time.time()
257
+ batch_times.append(batch_end - batch_start)
258
+
259
+ # Normalize the validation loss by the number of batches
260
+ val_loss /= len(val_loader)
261
+
262
+ # If distributed, gather results from all processes
263
+ if self.is_distributed:
264
+ # Create tensors to hold the local results
265
+ loss_tensor = torch.tensor([val_loss], device=self.device)
266
+ gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
267
+ dist.all_gather(gathered_losses, loss_tensor)
268
+
269
+ # Compute average loss across all processes
270
+ val_loss = torch.mean(torch.cat(gathered_losses)).item()
271
+
272
+ world_size = dist.get_world_size()
273
+
274
+ if self.is_main_process:
275
+ print(f"Collected predictions from rank {self.local_rank}")
276
+ print(f"Number of samples processed by this rank: {sample_counter}")
277
+
278
+ val_end = time.time()
279
+
280
+ if self.is_main_process:
281
+ print(f"Validation timing:")
282
+ print(f" Total validation time: {val_end - val_start:.2f}s")
283
+ print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
284
+ print(f" Collected {len(val_cell_ids)} cell indices from validation data")
285
+ print(f" Processed {sample_counter} total samples during validation")
286
+
287
+ # Print number of samples per task
288
+ for task_name in self.config["task_names"]:
289
+ print(f" Task {task_name}: {len(task_true_labels[task_name])} samples")
290
+
291
+ return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids
292
+
293
+ def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
294
+ """Train the model and return validation loss and trained model."""
295
+ if self.config.get("use_wandb", False) and self.is_main_process:
296
+ initialize_wandb(self.config)
297
+
298
+ # Create model
299
+ self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)
300
+
301
+ # Setup optimizer and scheduler
302
+ total_steps = len(train_loader) * self.config["epochs"]
303
+ self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)
304
+
305
+ # Training loop
306
+ if self.is_main_process:
307
+ epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
308
+ else:
309
+ epoch_progress = range(self.config["epochs"])
310
+
311
+ best_val_loss = float('inf')
312
+ train_losses = []
313
+
314
+ with setup_logging(self.config) as self.writer:
315
+ for epoch in epoch_progress:
316
+ if self.is_distributed:
317
+ train_loader.sampler.set_epoch(epoch)
318
+
319
+ train_loss = self.train_epoch(train_loader, epoch)
320
+ train_losses.append(train_loss)
321
+
322
+ # Run validation after each epoch if configured to do so
323
+ if self.config.get("validate_each_epoch", False):
324
+ val_loss, _, _, _, _ = self.validate_model(val_loader)
325
+ if val_loss < best_val_loss:
326
+ best_val_loss = val_loss
327
+
328
+ if self.is_main_process:
329
+ epoch_progress.set_postfix({
330
+ "train_loss": f"{train_loss:.4f}",
331
+ "val_loss": f"{val_loss:.4f}",
332
+ "best_val_loss": f"{best_val_loss:.4f}"
333
+ })
334
+ else:
335
+ if self.is_main_process:
336
+ epoch_progress.set_postfix({
337
+ "train_loss": f"{train_loss:.4f}"
338
+ })
339
+
340
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
341
+ task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
342
+
343
+ if self.is_main_process:
344
+ log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])
345
+
346
+ # Save validation predictions
347
+ save_validation_predictions(
348
+ val_cell_ids,
349
+ task_true_labels,
350
+ task_pred_labels,
351
+ task_pred_probs,
352
+ {**self.config, "val_cell_mapping": val_cell_id_mapping} # Include the mapping
353
+ )
354
+
355
+ if self.config.get("use_wandb", False):
356
+ import wandb
357
+ wandb.finish()
358
+
359
+ print(f"\nTraining Summary:")
360
+ print(f" Final Training Loss: {train_losses[-1]:.4f}")
361
+ print(f" Final Validation Loss: {val_loss:.4f}")
362
+ for task_name, metrics in task_metrics.items():
363
+ print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
364
+
365
+ return val_loss, self.model # Return both the validation loss and the trained model
366
+
367
+ def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
368
+ if self.is_distributed:
369
+ self.device = torch.device(f"cuda:{self.local_rank}")
370
+ else:
371
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
372
+
373
+ self.model = create_model(self.config, num_labels_list, self.device)
374
+
375
+ # Wrap the model with DDP
376
+ if self.is_distributed:
377
+ self.model = DDP(self.model, device_ids=[self.local_rank])
378
+
379
+ # Register a communication hook to optimize gradient synchronization
380
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
381
+
382
+ # Use the default allreduce hook, which maintains full precision
383
+ self.model.register_comm_hook(
384
+ state=None,
385
+ hook=comm_hooks.allreduce_hook
386
  )
387
+
388
+ print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")
389
+
390
+ print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
391
+ print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
392
+ if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
393
+ print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")
394
+
395
+ if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
396
+ print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")
397
+
398
+ # Set up optimizer and scheduler
399
+ self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
400
+ self.model, self.config, len(train_loader)
401
  )
402
 
403
+ if self.is_main_process and self.config.get("use_wandb", False):
404
+ initialize_wandb(self.config)
405
+
406
+ return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
407
+
408
+
409
+ def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
410
+ """Train a model with the given configuration and data."""
411
+ # Check if distributed training is enabled
412
+ if config.get("distributed_training", False):
413
+ # Check if we have multiple GPUs
414
+ if torch.cuda.device_count() > 1:
415
+ result = train_distributed(
416
+ Trainer,
417
+ config,
418
+ train_loader,
419
+ val_loader,
420
+ train_cell_id_mapping,
421
+ val_cell_id_mapping,
422
+ num_labels_list
 
 
 
 
423
  )
424
+ if result is not None:
425
+ return result
426
+ else:
427
+ print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
428
+ config["distributed_training"] = False
429
+
430
+ # Non-distributed training
431
+ trainer = Trainer(config)
432
+ trainer.device = device
433
+ return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
434
 
435
 
436
+ def objective(
437
+ trial,
 
438
  train_loader,
439
  val_loader,
440
  train_cell_id_mapping,
441
  val_cell_id_mapping,
442
  num_labels_list,
443
+ config,
444
+ device,
445
  ):
446
+ """Objective function for Optuna hyperparameter optimization."""
447
  set_seed(config["seed"])
448
  initialize_wandb(config)
449
 
450
+ trial_config = config.copy()
451
+
452
+ # Suggest hyperparameters for this trial
453
+ for param_name, param_config in config["hyperparameters"].items():
454
+ if param_name == "lr_scheduler_type":
455
+ trial_config[param_name] = trial.suggest_categorical(
456
+ param_name, param_config["choices"]
457
+ )
458
+ elif param_name == "task_weights" and config["use_task_weights"]:
459
+ weights = [
460
+ trial.suggest_float(
461
+ f"task_weight_{i}",
462
+ param_config["low"],
463
+ param_config["high"],
464
+ )
465
+ for i in range(len(num_labels_list))
466
+ ]
467
+ weight_sum = sum(weights)
468
+ trial_config[param_name] = [w / weight_sum for w in weights]
469
+ elif "log" in param_config and param_config["log"]:
470
+ trial_config[param_name] = trial.suggest_float(
471
+ param_name, param_config["low"], param_config["high"], log=True
472
+ )
473
+ else:
474
+ trial_config[param_name] = trial.suggest_float(
475
+ param_name, param_config["low"], param_config["high"]
476
+ )
477
+
478
+ # Set appropriate max layers to freeze based on pretrained model
479
+ if "max_layers_to_freeze" in trial_config:
480
+ freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
481
+ trial_config["max_layers_to_freeze"] = int(trial.suggest_int(
482
+ "max_layers_to_freeze",
483
+ freeze_range["min"],
484
+ freeze_range["max"]
485
+ ))
486
+
487
+ trial_config["run_name"] = f"trial_{trial.number}"
488
+
489
+ # Handle distributed training for this trial
490
+ if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
491
+ manager = mp.Manager()
492
+ shared_dict = manager.dict()
493
+
494
+ train_distributed(
495
+ Trainer,
496
+ trial_config,
497
+ train_loader,
498
+ val_loader,
499
+ train_cell_id_mapping,
500
+ val_cell_id_mapping,
501
+ num_labels_list,
502
+ trial.number,
503
+ shared_dict
504
+ )
505
+
506
+ val_loss = shared_dict.get('val_loss', float('inf'))
507
+ task_metrics = shared_dict.get('task_metrics', {})
508
+
509
+ trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
510
+ trial.set_user_attr("task_weights", trial_config["task_weights"])
511
+
512
+ if config.get("use_wandb", False):
513
+ import wandb
514
+ wandb.log({
515
+ "trial_number": trial.number,
516
+ "val_loss": val_loss,
517
+ **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
518
+ **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
519
+ })
520
+ wandb.finish()
521
+
522
+ return val_loss
523
+
524
+ with setup_logging(trial_config) as writer:
525
+ trainer = Trainer(trial_config)
526
+ trainer.device = device
527
+ trainer.writer = writer
528
+
529
+ # Create model with trial hyperparameters
530
+ trainer.model = create_model(trial_config, num_labels_list, device)
531
+ total_steps = len(train_loader) * config["epochs"]
532
+ trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)
533
+
534
+ # Training loop
535
+ for epoch in range(config["epochs"]):
536
+ trainer.train_epoch(train_loader, epoch)
537
+
538
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
539
+ task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
540
+
541
+ # Log metrics
542
+ log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])
543
+
544
+ # Save validation predictions
545
+ save_validation_predictions(
546
+ val_cell_ids,
547
+ task_true_labels,
548
+ task_pred_labels,
549
+ task_pred_probs,
550
+ {**trial_config, "val_cell_mapping": val_cell_id_mapping},
551
+ trial.number,
552
  )
 
553
 
554
+ # Store model state dict and task weights in trial user attributes
555
+ trial.set_user_attr("model_state_dict", trainer.model.state_dict())
556
+ trial.set_user_attr("task_weights", trial_config["task_weights"])
 
557
 
558
+ # Report intermediate value to Optuna
559
+ trial.report(val_loss, config["epochs"])
560
+ if trial.should_prune():
561
+ raise optuna.TrialPruned()
562
 
563
+ if config.get("use_wandb", False):
564
+ import wandb
565
+ wandb.log(
566
+ {
567
+ "trial_number": trial.number,
568
+ "val_loss": val_loss,
569
+ **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
570
+ **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
571
+ **{k: v for k, v in trial_config.items() if k in [
572
+ "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
573
+ "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
574
+ ]},
575
+ }
576
+ )
577
+ wandb.finish()
578
 
579
+ return val_loss
 
580
 
 
581
 
582
+ def run_manual_tuning(config):
583
+ """Run training with manually specified hyperparameters."""
584
+ (
585
+ device,
586
+ train_loader,
587
+ val_loader,
588
+ train_cell_id_mapping,
589
+ val_cell_id_mapping,
590
+ num_labels_list,
591
+ ) = prepare_training_environment(config)
592
 
593
+ print("\nManual hyperparameters being used:")
594
+ for key, value in config["manual_hyperparameters"].items():
595
+ print(f"{key}: {value}")
596
+ print()
597
 
598
+ # Update config with manual hyperparameters
599
+ for key, value in config["manual_hyperparameters"].items():
600
+ config[key] = value
 
 
 
 
 
 
 
 
 
601
 
602
+ # Train the model
603
+ val_loss, trained_model = train_model(
604
+ config,
605
+ device,
606
+ train_loader,
607
+ val_loader,
608
+ train_cell_id_mapping,
609
+ val_cell_id_mapping,
610
+ num_labels_list,
611
  )
612
 
613
+ print(f"\nValidation loss with manual hyperparameters: {val_loss}")
614
 
615
+ # Save the trained model - only if not using distributed training
616
+ # (distributed training saves the model in the worker)
617
+ if not config.get("distributed_training", False):
618
+ model_save_directory = os.path.join(
619
+ config["model_save_path"], "GeneformerMultiTask"
620
+ )
621
+ save_model(trained_model, model_save_directory)
622
+
623
+ # Save the hyperparameters
624
+ hyperparams_to_save = {
625
+ **config["manual_hyperparameters"],
626
+ "dropout_rate": config["dropout_rate"],
627
+ "use_task_weights": config["use_task_weights"],
628
+ "task_weights": config["task_weights"],
629
+ "max_layers_to_freeze": config["max_layers_to_freeze"],
630
+ "use_attention_pooling": config["use_attention_pooling"],
631
+ }
632
+ save_hyperparameters(model_save_directory, hyperparams_to_save)
633
 
634
+ return val_loss
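
For orientation, a minimal sketch of the kind of config dict run_manual_tuning consumes. Only keys referenced in this file are shown; the values are placeholders and the data-loading keys handled by prepare_data_loaders (dataset paths, task columns, results_dir, and so on) are omitted, so this is not a complete configuration.

# Hypothetical config sketch for run_manual_tuning (placeholder values, data keys omitted).
config = {
    "pretrained_path": "/path/to/pretrained_geneformer",
    "model_save_path": "/path/to/results/models",
    "seed": 42,
    "epochs": 1,
    "batch_size": 8,
    "distributed_training": False,
    "use_task_weights": True,
    "task_weights": [1.0, 1.0],
    "manual_hyperparameters": {
        "learning_rate": 1e-4,
        "warmup_ratio": 0.1,
        "weight_decay": 0.01,
        "dropout_rate": 0.1,
        "lr_scheduler_type": "linear",
        "use_attention_pooling": True,
        "max_layers_to_freeze": 2,
    },
}
val_loss = run_manual_tuning(config)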
 
635
 
 
 
 
 
636
 
637
+ def run_optuna_study(config):
638
+ """Run hyperparameter optimization using Optuna."""
639
+ # Prepare training environment
640
+ (
641
+ device,
642
+ train_loader,
643
+ val_loader,
644
+ train_cell_id_mapping,
645
+ val_cell_id_mapping,
646
+ num_labels_list,
647
+ ) = prepare_training_environment(config)
648
+
649
+ # If manual hyperparameters are specified, use them instead of running Optuna
650
+ if config.get("use_manual_hyperparameters", False):
651
+ return run_manual_tuning(config)
652
+
653
+ # Create a partial function with fixed arguments for the objective
654
+ objective_with_config_and_data = functools.partial(
655
+ objective,
656
+ train_loader=train_loader,
657
+ val_loader=val_loader,
658
+ train_cell_id_mapping=train_cell_id_mapping,
659
+ val_cell_id_mapping=val_cell_id_mapping,
660
+ num_labels_list=num_labels_list,
661
+ config=config,
662
+ device=device,
663
  )
 
664
 
665
+ # Create and run the Optuna study
666
+ study = optuna.create_study(
667
+ direction="minimize", # Minimize validation loss
668
+ study_name=config["study_name"],
669
+ # storage=config["storage"],
670
+ load_if_exists=True,
671
+ )
672
 
673
+ study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
674
+
675
+ # After finding the best trial
676
+ best_params = study.best_trial.params
677
+ best_task_weights = study.best_trial.user_attrs["task_weights"]
678
+ print("Saving the best model and its hyperparameters...")
679
+
680
+ # Create a model with the best hyperparameters
681
+ best_model = GeneformerMultiTask(
682
+ config["pretrained_path"],
683
+ num_labels_list,
684
+ dropout_rate=best_params["dropout_rate"],
685
+ use_task_weights=config["use_task_weights"],
686
+ task_weights=best_task_weights,
687
+ max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
688
+ use_attention_pooling=best_params.get("use_attention_pooling", False),
689
  )
690
 
691
+ # Get the best model state dictionary
692
+ best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
693
 
694
+ best_model_state_dict = {
695
+ k.replace("module.", ""): v for k, v in best_model_state_dict.items()
696
+ }
697
 
698
+ best_model.load_state_dict(best_model_state_dict, strict=False)
 
699
 
700
+ model_save_directory = os.path.join(
701
+ config["model_save_path"], "GeneformerMultiTask"
702
+ )
703
+ save_model(best_model, model_save_directory)
704
 
705
+ save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights})
706
 
707
+ return study.best_trial.value # Return the best validation loss
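
Once the study finishes, the best weights and hyperparameters are written to <model_save_path>/GeneformerMultiTask as pytorch_model.bin, config.json, and hyperparameters.json. Below is a hedged sketch of reloading them for inference; the import path, placeholder paths, and label counts are assumptions, while the constructor arguments mirror the call above.

import json
import os
import torch
from geneformer.mtl.model import GeneformerMultiTask  # module path assumed from this package layout

model_dir = "/path/to/results/models/GeneformerMultiTask"  # placeholder
with open(os.path.join(model_dir, "hyperparameters.json")) as f:
    hparams = json.load(f)

model = GeneformerMultiTask(
    "/path/to/pretrained_geneformer",  # placeholder pretrained path
    [3, 2],                            # num_labels_list must match the fine-tuning setup
    dropout_rate=hparams.get("dropout_rate", 0.1),
    use_task_weights=True,
    task_weights=hparams.get("task_weights"),
)
state_dict = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.eval()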
geneformer/mtl/train_utils.py DELETED
@@ -1,161 +0,0 @@
1
- import random
2
-
3
- from .data import get_data_loader, preload_and_process_data
4
- from .imports import *
5
- from .model import GeneformerMultiTask
6
- from .train import objective, train_model
7
- from .utils import save_model
8
-
9
-
10
- def set_seed(seed):
11
- random.seed(seed)
12
- np.random.seed(seed)
13
- torch.manual_seed(seed)
14
- torch.cuda.manual_seed_all(seed)
15
- torch.backends.cudnn.deterministic = True
16
- torch.backends.cudnn.benchmark = False
17
-
18
-
19
- def run_manual_tuning(config):
20
- # Set seed for reproducibility
21
- set_seed(config["seed"])
22
-
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- (
25
- train_dataset,
26
- train_cell_id_mapping,
27
- val_dataset,
28
- val_cell_id_mapping,
29
- num_labels_list,
30
- ) = preload_and_process_data(config)
31
- train_loader = get_data_loader(train_dataset, config["batch_size"])
32
- val_loader = get_data_loader(val_dataset, config["batch_size"])
33
-
34
- # Print the manual hyperparameters being used
35
- print("\nManual hyperparameters being used:")
36
- for key, value in config["manual_hyperparameters"].items():
37
- print(f"{key}: {value}")
38
- print() # Add an empty line for better readability
39
-
40
- # Use the manual hyperparameters
41
- for key, value in config["manual_hyperparameters"].items():
42
- config[key] = value
43
-
44
- # Train the model
45
- val_loss, trained_model = train_model(
46
- config,
47
- device,
48
- train_loader,
49
- val_loader,
50
- train_cell_id_mapping,
51
- val_cell_id_mapping,
52
- num_labels_list,
53
- )
54
-
55
- print(f"\nValidation loss with manual hyperparameters: {val_loss}")
56
-
57
- # Save the trained model
58
- model_save_directory = os.path.join(
59
- config["model_save_path"], "GeneformerMultiTask"
60
- )
61
- save_model(trained_model, model_save_directory)
62
-
63
- # Save the hyperparameters
64
- hyperparams_to_save = {
65
- **config["manual_hyperparameters"],
66
- "dropout_rate": config["dropout_rate"],
67
- "use_task_weights": config["use_task_weights"],
68
- "task_weights": config["task_weights"],
69
- "max_layers_to_freeze": config["max_layers_to_freeze"],
70
- "use_attention_pooling": config["use_attention_pooling"],
71
- }
72
- hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
73
- with open(hyperparams_path, "w") as f:
74
- json.dump(hyperparams_to_save, f)
75
- print(f"Manual hyperparameters saved to {hyperparams_path}")
76
-
77
- return val_loss
78
-
79
-
80
- def run_optuna_study(config):
81
- # Set seed for reproducibility
82
- set_seed(config["seed"])
83
-
84
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
- (
86
- train_dataset,
87
- train_cell_id_mapping,
88
- val_dataset,
89
- val_cell_id_mapping,
90
- num_labels_list,
91
- ) = preload_and_process_data(config)
92
- train_loader = get_data_loader(train_dataset, config["batch_size"])
93
- val_loader = get_data_loader(val_dataset, config["batch_size"])
94
-
95
- if config["use_manual_hyperparameters"]:
96
- train_model(
97
- config,
98
- device,
99
- train_loader,
100
- val_loader,
101
- train_cell_id_mapping,
102
- val_cell_id_mapping,
103
- num_labels_list,
104
- )
105
- else:
106
- objective_with_config_and_data = functools.partial(
107
- objective,
108
- train_loader=train_loader,
109
- val_loader=val_loader,
110
- train_cell_id_mapping=train_cell_id_mapping,
111
- val_cell_id_mapping=val_cell_id_mapping,
112
- num_labels_list=num_labels_list,
113
- config=config,
114
- device=device,
115
- )
116
-
117
- study = optuna.create_study(
118
- direction="minimize", # Minimize validation loss
119
- study_name=config["study_name"],
120
- # storage=config["storage"],
121
- load_if_exists=True,
122
- )
123
-
124
- study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
125
-
126
- # After finding the best trial
127
- best_params = study.best_trial.params
128
- best_task_weights = study.best_trial.user_attrs["task_weights"]
129
- print("Saving the best model and its hyperparameters...")
130
-
131
- # Saving model as before
132
- best_model = GeneformerMultiTask(
133
- config["pretrained_path"],
134
- num_labels_list,
135
- dropout_rate=best_params["dropout_rate"],
136
- use_task_weights=config["use_task_weights"],
137
- task_weights=best_task_weights,
138
- )
139
-
140
- # Get the best model state dictionary
141
- best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
142
-
143
- # Remove the "module." prefix from the state dictionary keys if present
144
- best_model_state_dict = {
145
- k.replace("module.", ""): v for k, v in best_model_state_dict.items()
146
- }
147
-
148
- # Load the modified state dictionary into the model, skipping unexpected keys
149
- best_model.load_state_dict(best_model_state_dict, strict=False)
150
-
151
- model_save_directory = os.path.join(
152
- config["model_save_path"], "GeneformerMultiTask"
153
- )
154
- save_model(best_model, model_save_directory)
155
-
156
- # Additionally, save the best hyperparameters and task weights
157
- hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
158
-
159
- with open(hyperparams_path, "w") as f:
160
- json.dump({**best_params, "task_weights": best_task_weights}, f)
161
- print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py CHANGED
@@ -1,129 +1,641 @@
 
 
1
  import os
2
- import shutil
3
-
 
 
 
 
4
  from sklearn.metrics import accuracy_score, f1_score
5
  from sklearn.preprocessing import LabelEncoder
6
- from transformers import AutoConfig, BertConfig, BertModel
 
 
 
 
 
 
 
7
 
8
- from .imports import *
9
10
 
11
- def save_model(model, model_save_directory):
12
- if not os.path.exists(model_save_directory):
13
- os.makedirs(model_save_directory)
14
-
15
- # Get the state dict
16
- if isinstance(model, nn.DataParallel):
17
- model_state_dict = (
18
- model.module.state_dict()
19
- ) # Use model.module to access the underlying model
20
- else:
21
- model_state_dict = model.state_dict()
22
 
23
- # Remove the "module." prefix from the keys if present
24
- model_state_dict = {
25
- k.replace("module.", ""): v for k, v in model_state_dict.items()
26
  }
27
 
28
  model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
29
  torch.save(model_state_dict, model_save_path)
30
 
31
  # Save the model configuration
32
- if isinstance(model, nn.DataParallel):
33
- model.module.config.to_json_file(
34
- os.path.join(model_save_directory, "config.json")
35
- )
36
- else:
37
- model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
38
 
39
  print(f"Model and configuration saved to {model_save_directory}")
40
 
41
 
42
- def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
43
- task_metrics = {}
44
- for task_name in task_true_labels.keys():
45
- true_labels = task_true_labels[task_name]
46
- pred_labels = task_pred_labels[task_name]
47
- f1 = f1_score(true_labels, pred_labels, average="macro")
48
- accuracy = accuracy_score(true_labels, pred_labels)
49
- task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
50
- return task_metrics
51
 
52
 
53
- def calculate_combined_f1(combined_labels, combined_preds):
54
- # Initialize the LabelEncoder
55
56
 
57
- # Fit and transform combined labels and predictions to numerical values
58
- le.fit(combined_labels + combined_preds)
59
- encoded_true_labels = le.transform(combined_labels)
60
- encoded_pred_labels = le.transform(combined_preds)
61
 
62
- # Print out the mapping for sanity check
63
- print("\nLabel Encoder Mapping:")
64
- for index, class_label in enumerate(le.classes_):
65
- print(f"'{class_label}': {index}")
 
 
 
66
 
67
- # Calculate accuracy
68
- accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
69
 
70
- # Calculate F1 Macro score
71
- f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
 
 
 
 
 
 
 
 
 
 
72
 
73
- return f1, accuracy
74
 
75
 
76
- # def save_model_without_heads(original_model_save_directory):
77
- # # Create a new directory for the model without heads
78
- # new_model_save_directory = original_model_save_directory + "_No_Heads"
79
- # if not os.path.exists(new_model_save_directory):
80
- # os.makedirs(new_model_save_directory)
 
 
 
 
 
 
 
 
 
81
 
82
- # # Load the model state dictionary
83
- # model_state_dict = torch.load(
84
- # os.path.join(original_model_save_directory, "pytorch_model.bin")
85
- # )
86
 
87
- # # Initialize a new BERT model without the classification heads
88
- # config = BertConfig.from_pretrained(
89
- # os.path.join(original_model_save_directory, "config.json")
90
- # )
91
- # model_without_heads = BertModel(config)
 
 
 
 
 
 
 
 
 
92
 
93
- # # Filter the state dict to exclude classification heads
94
- # model_without_heads_state_dict = {
95
- # k: v
96
- # for k, v in model_state_dict.items()
97
- # if not k.startswith("classification_heads")
98
- # }
99
 
100
- # # Load the filtered state dict into the model
101
- # model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
 
 
 
 
 
 
 
102
 
103
- # # Save the model without heads
104
- # model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
105
- # torch.save(model_without_heads.state_dict(), model_save_path)
106
 
107
- # # Copy the configuration file
108
- # shutil.copy(
109
- # os.path.join(original_model_save_directory, "config.json"),
110
- # new_model_save_directory,
111
- # )
 
 
 
112
 
113
- # print(f"Model without classification heads saved to {new_model_save_directory}")
114
115
 
116
- def get_layer_freeze_range(pretrained_path):
117
  """
118
- Dynamically determines the number of layers to freeze based on the model depth from its configuration.
 
119
  Args:
120
- pretrained_path (str): Path to the pretrained model directory or model identifier.
121
- Returns:
122
- dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
123
  """
124
- if pretrained_path:
125
- config = AutoConfig.from_pretrained(pretrained_path)
126
- total_layers = config.num_hidden_layers
127
- return {"min": 0, "max": total_layers - 1}
128
- else:
129
- return {"min": 0, "max": 0}
 
1
+ from typing import Dict, List, Optional, Union
2
+ import json
3
  import os
4
+ import pickle
5
+ import random
6
+ import torch
7
+ import numpy as np
8
+ import wandb
9
+ import optuna
10
  from sklearn.metrics import accuracy_score, f1_score
11
  from sklearn.preprocessing import LabelEncoder
12
+ from torch.utils.tensorboard import SummaryWriter
13
+ from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
14
+ from torch.optim import AdamW
15
+ import pandas as pd
16
+ import torch.distributed as dist
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
18
+ import torch.multiprocessing as mp
19
+ from contextlib import contextmanager
20
 
 
21
 
22
+ def set_seed(seed):
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+ torch.cuda.manual_seed_all(seed)
27
+ torch.backends.cudnn.deterministic = True
28
+ torch.backends.cudnn.benchmark = False
29
+
30
+
31
+ def initialize_wandb(config):
32
+ if config.get("use_wandb", False):
33
+ wandb.init(
34
+ project=config.get("wandb_project", "geneformer_multitask"),
35
+ name=config.get("run_name", "experiment"),
36
+ config=config,
37
+ reinit=True,
38
+ )
39
+
40
+
41
+ def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0):
42
+ """Create and initialize the model based on configuration."""
43
+ from .model import GeneformerMultiTask
44
+
45
+ model = GeneformerMultiTask(
46
+ config["pretrained_path"],
47
+ num_labels_list,
48
+ dropout_rate=config.get("dropout_rate", 0.1),
49
+ use_task_weights=config.get("use_task_weights", False),
50
+ task_weights=config.get("task_weights", None),
51
+ max_layers_to_freeze=config.get("max_layers_to_freeze", 0),
52
+ use_attention_pooling=config.get("use_attention_pooling", False),
53
+ )
54
+
55
+ # Move model to device
56
+ model.to(device)
57
+
58
+ if is_distributed:
59
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
60
+
61
+ return model
62
 
63
 
64
+ def setup_optimizer_and_scheduler(model, config, total_steps):
65
+ """Set up optimizer and learning rate scheduler."""
66
+ no_decay = ["bias", "LayerNorm.weight"]
67
+ optimizer_grouped_parameters = [
68
+ {
69
+ "params": [p for n, p in model.named_parameters()
70
+ if not any(nd in n for nd in no_decay) and p.requires_grad],
71
+ "weight_decay": config["weight_decay"],
72
+ },
73
+ {
74
+ "params": [p for n, p in model.named_parameters()
75
+ if any(nd in n for nd in no_decay) and p.requires_grad],
76
+ "weight_decay": 0.0,
77
+ },
78
+ ]
79
+
80
+ optimizer = AdamW(
81
+ optimizer_grouped_parameters,
82
+ lr=config["learning_rate"],
83
+ eps=config.get("adam_epsilon", 1e-8)
84
+ )
85
+
86
+ # Prepare scheduler
87
+ warmup_steps = int(total_steps * config["warmup_ratio"])
88
+
89
+ scheduler_map = {
90
+ "linear": get_linear_schedule_with_warmup,
91
+ "cosine": get_cosine_schedule_with_warmup
92
  }
93
+
94
+ scheduler_fn = scheduler_map.get(config["lr_scheduler_type"])
95
+ if not scheduler_fn:
96
+ raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}")
97
+
98
+ scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
99
+
100
+ return optimizer, scheduler
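
A short usage sketch for the helper above. The config keys match the ones read inside the function; compute_loss stands in for the real forward/loss logic that lives in the Trainer class, and the loop simply illustrates that the scheduler is stepped once per optimizer update.

# Illustrative only; model, train_loader, num_epochs, and compute_loss are assumed to exist.
cfg = {"learning_rate": 1e-4, "weight_decay": 0.01, "warmup_ratio": 0.1, "lr_scheduler_type": "linear"}
total_steps = len(train_loader) * num_epochs
optimizer, scheduler = setup_optimizer_and_scheduler(model, cfg, total_steps)

for epoch in range(num_epochs):
    for batch in train_loader:
        loss = compute_loss(model, batch)  # hypothetical helper
        loss.backward()
        optimizer.step()
        scheduler.step()                   # advance warmup/decay once per update
        optimizer.zero_grad()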
101
+
102
+
103
+ def save_model(model, model_save_directory):
104
+ """Save model weights and configuration."""
105
+ os.makedirs(model_save_directory, exist_ok=True)
106
+
107
+ # Handle DDP model
108
+ if isinstance(model, DDP):
109
+ model_to_save = model.module
110
+ else:
111
+ model_to_save = model
112
+
113
+ model_state_dict = model_to_save.state_dict()
114
 
115
  model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
116
  torch.save(model_state_dict, model_save_path)
117
 
118
  # Save the model configuration
119
+ model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json"))
 
 
 
 
 
120
 
121
  print(f"Model and configuration saved to {model_save_directory}")
122
 
123
 
124
+ def save_hyperparameters(model_save_directory, hyperparams):
125
+ """Save hyperparameters to a JSON file."""
126
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
127
+ with open(hyperparams_path, "w") as f:
128
+ json.dump(hyperparams, f)
129
+ print(f"Hyperparameters saved to {hyperparams_path}")
 
 
 
130
 
131
 
132
+ def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"):
+ """Compute macro F1 and accuracy for a single task, per task, or combined across tasks, depending on metric_type."""
133
+ if metric_type == "single":
134
+ # Calculate metrics for a single task
135
+ if labels is None or preds is None:
136
+ raise ValueError("Labels and predictions must be provided for single task metrics")
137
+
138
+ task_name = None
139
+ if isinstance(labels, dict) and len(labels) == 1:
140
+ task_name = list(labels.keys())[0]
141
+ labels = labels[task_name]
142
+ preds = preds[task_name]
143
+
144
+ f1 = f1_score(labels, preds, average="macro")
145
+ accuracy = accuracy_score(labels, preds)
146
+
147
+ if return_format == "tuple":
148
+ return f1, accuracy
149
+
150
+ result = {"f1": f1, "accuracy": accuracy}
151
+ if task_name:
152
+ return {task_name: result}
153
+ return result
154
+
155
+ elif metric_type == "task_specific":
156
+ # Calculate metrics for multiple tasks
157
+ if task_data:
158
+ result = {}
159
+ for task_name, (task_labels, task_preds) in task_data.items():
160
+ f1 = f1_score(task_labels, task_preds, average="macro")
161
+ accuracy = accuracy_score(task_labels, task_preds)
162
+ result[task_name] = {"f1": f1, "accuracy": accuracy}
163
+ return result
164
+ elif isinstance(labels, dict) and isinstance(preds, dict):
165
+ result = {}
166
+ for task_name in labels:
167
+ if task_name in preds:
168
+ f1 = f1_score(labels[task_name], preds[task_name], average="macro")
169
+ accuracy = accuracy_score(labels[task_name], preds[task_name])
170
+ result[task_name] = {"f1": f1, "accuracy": accuracy}
171
+ return result
172
+ else:
173
+ raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided")
174
+
175
+ elif metric_type == "combined":
176
+ # Calculate combined metrics across all tasks
177
+ if labels is None or preds is None:
178
+ raise ValueError("Labels and predictions must be provided for combined metrics")
179
+
180
+ # Handle label encoding for non-numeric labels
181
+ if not all(isinstance(x, (int, float)) for x in labels + preds):
182
+ le = LabelEncoder()
183
+ le.fit(labels + preds)
184
+ labels = le.transform(labels)
185
+ preds = le.transform(preds)
186
+
187
+ f1 = f1_score(labels, preds, average="macro")
188
+ accuracy = accuracy_score(labels, preds)
189
+
190
+ if return_format == "tuple":
191
+ return f1, accuracy
192
+ return {"f1": f1, "accuracy": accuracy}
193
+
194
+ else:
195
+ raise ValueError(f"Unknown metric_type: {metric_type}")
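
To make the three calling conventions concrete, a small sketch with toy labels (values are illustrative only):

# Per-task metrics from dictionaries keyed by task name.
true = {"task1": [0, 1, 1, 0], "task2": [2, 2, 0, 1]}
pred = {"task1": [0, 1, 0, 0], "task2": [2, 1, 0, 1]}
per_task = calculate_metrics(labels=true, preds=pred, metric_type="task_specific")
# -> {"task1": {"f1": ..., "accuracy": ...}, "task2": {...}}

# Single-task metrics returned as an (f1, accuracy) tuple.
f1, acc = calculate_metrics(labels=[0, 1, 1], preds=[0, 1, 0], metric_type="single", return_format="tuple")

# Combined metrics across tasks; string labels are label-encoded internally.
combined = calculate_metrics(labels=["a", "b", "a"], preds=["a", "a", "a"], metric_type="combined")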
196
 
 
 
 
 
197
 
198
+ def get_layer_freeze_range(pretrained_path):
+ """Return a {"min", "max"} range of freezable encoder layers based on the pretrained model's num_hidden_layers."""
199
+ if not pretrained_path:
200
+ return {"min": 0, "max": 0}
201
+
202
+ config = AutoConfig.from_pretrained(pretrained_path)
203
+ total_layers = config.num_hidden_layers
204
+ return {"min": 0, "max": total_layers - 1}
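
The returned range is convenient for bounding a hyperparameter search; a hedged sketch assuming an optuna.trial.Trial named trial and a placeholder pretrained path:

freeze_range = get_layer_freeze_range("/path/to/pretrained_geneformer")  # e.g. {"min": 0, "max": 11} for a 12-layer model
max_layers_to_freeze = trial.suggest_int("max_layers_to_freeze", freeze_range["min"], freeze_range["max"])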
205
 
 
 
206
 
207
+ def prepare_training_environment(config):
208
+ """
209
+ Prepare the training environment by setting seed and loading data.
210
+
211
+ Returns:
212
+ tuple: (device, train_loader, val_loader, train_cell_id_mapping,
213
+ val_cell_id_mapping, num_labels_list)
214
+ """
215
+ from .data import prepare_data_loaders
216
+
217
+ # Set seed for reproducibility
218
+ set_seed(config["seed"])
219
 
220
+ # Set up device - for non-distributed training
221
+ if not config.get("distributed_training", False):
222
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223
+ else:
224
+ # For distributed training, device will be set per process
225
+ device = None
226
+
227
+ # Load data using the streaming dataset
228
+ data = prepare_data_loaders(config)
229
+
230
+ # For distributed training, we'll set up samplers later in the distributed worker
231
+ # Don't create DistributedSampler here as process group isn't initialized yet
232
+
233
+ return (
234
+ device,
235
+ data["train_loader"],
236
+ data["val_loader"],
237
+ data["train_cell_mapping"],
238
+ data["val_cell_mapping"],
239
+ data["num_labels_list"],
240
+ )
241
 
242
 
243
+ # Optuna hyperparameter optimization utilities
244
+ def save_trial_callback(study, trial, trials_result_path):
245
+ """
246
+ Callback to save Optuna trial results to a file.
247
+
248
+ Args:
249
+ study: Optuna study object
250
+ trial: Current trial object
251
+ trials_result_path: Path to save trial results
252
+ """
253
+ with open(trials_result_path, "a") as f:
254
+ f.write(
255
+ f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
256
+ )
257
 
 
 
 
 
258
 
259
+ def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study:
260
+ """Create and run an Optuna study with TensorBoard logging."""
261
+ from optuna.integration import TensorBoardCallback
262
+
263
+ study = optuna.create_study(direction="maximize")
264
+ study.optimize(
265
+ objective,
266
+ n_trials=n_trials,
267
+ callbacks=[
268
+ lambda study, trial: save_trial_callback(study, trial, trials_result_path),
269
+ TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
270
+ ]
271
+ )
272
+ return study
273
 
 
 
 
 
 
 
274
 
275
+ @contextmanager
276
+ def setup_logging(config):
277
+ run_name = config.get("run_name", "manual_run")
278
+ log_dir = os.path.join(config["tensorboard_log_dir"], run_name)
279
+ writer = SummaryWriter(log_dir=log_dir)
280
+ try:
281
+ yield writer
282
+ finally:
283
+ writer.close()
284
 
 
 
 
285
 
286
+ def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx):
287
+ """Log training step metrics to TensorBoard and optionally W&B."""
288
+ writer.add_scalar(
289
+ "Training Loss", loss, epoch * steps_per_epoch + batch_idx
290
+ )
291
+ if config.get("use_wandb", False):
292
+ import wandb
293
+ wandb.log({"Training Loss": loss})
294
 
 
295
 
296
+ def log_validation_metrics(task_metrics, val_loss, config, writer, epoch):
297
+ """Log validation metrics to console, TensorBoard, and optionally W&B."""
298
+ for task_name, metrics in task_metrics.items():
299
+ print(
300
+ f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
301
+ )
302
+ if config.get("use_wandb", False):
303
+ import wandb
304
+ wandb.log(
305
+ {
306
+ f"{task_name} Validation F1 Macro": metrics["f1"],
307
+ f"{task_name} Validation Accuracy": metrics["accuracy"],
308
+ }
309
+ )
310
 
311
+ writer.add_scalar("Validation Loss", val_loss, epoch)
312
+ for task_name, metrics in task_metrics.items():
313
+ writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch)
314
+ writer.add_scalar(
315
+ f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch
316
+ )
317
+
318
+
319
+ def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]:
320
+ """Load or initialize task label mappings."""
321
+ label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl")
322
+ if os.path.exists(label_mappings_path):
323
+ with open(label_mappings_path, 'rb') as f:
324
+ return pickle.load(f)
325
+ return {task_name: {} for task_name in task_names}
326
+
327
+
328
+ def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict,
329
+ task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str],
330
+ inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict:
331
+ """Create a row for validation predictions."""
332
+ batch_cell_idx = val_cell_indices.get(sample_idx)
333
+ cell_id = val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}") if batch_cell_idx is not None else f"unknown_cell_{sample_idx}"
334
+
335
+ row = {"Cell ID": cell_id}
336
+ for task_name in task_names:
337
+ if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]):
338
+ true_idx = task_true_labels[task_name][sample_idx]
339
+ pred_idx = task_pred_labels[task_name][sample_idx]
340
+ true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}")
341
+ pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}")
342
+
343
+ row.update({
344
+ f"{task_name}_true_idx": true_idx,
345
+ f"{task_name}_pred_idx": pred_idx,
346
+ f"{task_name}_true_label": true_label,
347
+ f"{task_name}_pred_label": pred_label
348
+ })
349
+
350
+ if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]):
351
+ probs = task_pred_probs[task_name][sample_idx]
352
+ if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)):
353
+ prob_list = list(probs) if not isinstance(probs, list) else probs
354
+ row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list))
355
+ for class_idx, prob in enumerate(prob_list):
356
+ class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}")
357
+ row[f"{task_name}_prob_{class_label}"] = prob
358
+ else:
359
+ row[f"{task_name}_all_probs"] = str(probs)
360
+
361
+ return row
362
+
363
+
364
+ def save_validation_predictions(
365
+ val_cell_indices,
366
+ task_true_labels,
367
+ task_pred_labels,
368
+ task_pred_probs,
369
+ config,
370
+ trial_number=None,
371
+ ):
372
+ """Save validation predictions to a CSV file with class labels and probabilities."""
373
+ os.makedirs(config["results_dir"], exist_ok=True)
374
+
375
+ if trial_number is not None:
376
+ os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True)
377
+ val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv")
378
+ else:
379
+ val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
380
+
381
+ if not val_cell_indices or not task_true_labels:
382
+ pd.DataFrame().to_csv(val_preds_file, index=False)
383
+ return
384
+
385
+ try:
386
+ label_mappings = load_label_mappings(config["results_dir"], config["task_names"])
387
+ inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()}
388
+ val_cell_mapping = config.get("val_cell_mapping", {})
389
+
390
+ # Determine maximum number of samples
391
+ max_samples = max(
392
+ [len(val_cell_indices)] +
393
+ [len(task_true_labels[task]) for task in task_true_labels]
394
+ )
395
+
396
+ rows = [
397
+ create_prediction_row(
398
+ sample_idx, val_cell_indices, task_true_labels, task_pred_labels,
399
+ task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping
400
+ )
401
+ for sample_idx in range(max_samples)
402
+ ]
403
+
404
+ pd.DataFrame(rows).to_csv(val_preds_file, index=False)
405
+ except Exception as e:
406
+ pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False)
407
+
408
+
409
+ def setup_distributed_environment(rank, world_size, config):
410
  """
411
+ Setup the distributed training environment.
412
+
413
  Args:
414
+ rank (int): The rank of the current process
415
+ world_size (int): Total number of processes
416
+ config (dict): Configuration dictionary
417
  """
418
+ os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost')
419
+ os.environ['MASTER_PORT'] = config.get('master_port', '12355')
420
+
421
+ # Initialize the process group
422
+ dist.init_process_group(
423
+ backend='nccl',
424
+ init_method='env://',
425
+ world_size=world_size,
426
+ rank=rank
427
+ )
428
+
429
+ # Set the device for this process
430
+ torch.cuda.set_device(rank)
431
+
432
+
433
+ def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
434
+ """Run distributed training across multiple GPUs with fallback to single GPU."""
435
+ world_size = torch.cuda.device_count()
436
+
437
+ if world_size <= 1:
438
+ print("Distributed training requested but fewer than two GPUs are available. Falling back to single-device training.")
439
+ config["distributed_training"] = False
440
+ trainer = trainer_class(config)
441
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
442
+ trainer.device = device
443
+ train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
444
+ train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
445
+ )
446
+ val_loss, model = trainer.train(
447
+ train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
448
+ )
449
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
450
+ save_model(model, model_save_directory)
451
+ save_hyperparameters(model_save_directory, {
452
+ **get_config_value(config, "manual_hyperparameters", {}),
453
+ "dropout_rate": config["dropout_rate"],
454
+ "use_task_weights": config["use_task_weights"],
455
+ "task_weights": config["task_weights"],
456
+ "max_layers_to_freeze": config["max_layers_to_freeze"],
457
+ "use_attention_pooling": config["use_attention_pooling"],
458
+ })
459
+
460
+ if shared_dict is not None:
461
+ shared_dict['val_loss'] = val_loss
462
+ task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config)
463
+ shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
464
+ shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()}
465
+
466
+ return val_loss, model
467
+
468
+ print(f"Using distributed training with {world_size} GPUs")
469
+ mp.spawn(
470
+ _distributed_worker,
471
+ args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict),
472
+ nprocs=world_size,
473
+ join=True
474
+ )
475
+
476
+ if trial_number is None and shared_dict is None:
477
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
478
+ model_path = os.path.join(model_save_directory, "pytorch_model.bin")
479
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
480
+ model = create_model(config, num_labels_list, device)
481
+ model.load_state_dict(torch.load(model_path))
482
+ return 0.0, model
483
+
484
+ return None
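
Because mp.spawn runs each rank in a separate process, anything written to shared_dict is only visible to the parent when it is a multiprocessing manager proxy; a plain dict would stay empty. A hedged invocation sketch (Trainer, config, and the loaders are assumed to come from the surrounding training code):

import multiprocessing

manager = multiprocessing.Manager()
shared = manager.dict()  # proxy dict that spawned workers can update

train_distributed(
    Trainer, config,
    train_loader, val_loader,
    train_cell_id_mapping, val_cell_id_mapping,
    num_labels_list,
    trial_number=None,
    shared_dict=shared,
)
val_loss = shared.get("val_loss")          # written by rank 0
task_metrics = shared.get("task_metrics")  # per-task F1/accuracy written by rank 0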
485
+
486
+
487
+ def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
488
+ """Worker function for distributed training."""
489
+ setup_distributed_environment(rank, world_size, config)
490
+ config["local_rank"] = rank
491
+
492
+ # Set up distributed samplers
493
+ from torch.utils.data import DistributedSampler
494
+ from .data import get_data_loader
495
+
496
+ train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
497
+ val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
498
+
499
+ train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False)
500
+ val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False)
501
+
502
+ if rank == 0:
503
+ print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples")
504
+ print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}")
505
+
506
+ # Create and setup trainer
507
+ trainer = trainer_class(config)
508
+ train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
509
+ train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
510
+ )
511
+
512
+ # Train the model
513
+ val_loss, model = trainer.train(
514
+ train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
515
+ )
516
+
517
+ # Save model only from the main process
518
+ if rank == 0:
519
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
520
+ save_model(model, model_save_directory)
521
+
522
+ save_hyperparameters(model_save_directory, {
523
+ **get_config_value(config, "manual_hyperparameters", {}),
524
+ "dropout_rate": config["dropout_rate"],
525
+ "use_task_weights": config["use_task_weights"],
526
+ "task_weights": config["task_weights"],
527
+ "max_layers_to_freeze": config["max_layers_to_freeze"],
528
+ "use_attention_pooling": config["use_attention_pooling"],
529
+ })
530
+
531
+ # For Optuna trials, store results in shared dictionary
532
+ if shared_dict is not None:
533
+ shared_dict['val_loss'] = val_loss
534
+
535
+ # Run validation on full dataset from rank 0 for consistent metrics
536
+ full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False)
537
+
538
+ # Get validation predictions using our utility function
539
+ task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
540
+ model, full_val_loader, trainer.device, config
541
+ )
542
+
543
+ # Calculate metrics
544
+ task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
545
+ shared_dict['task_metrics'] = task_metrics
546
+
547
+ # Store model state dict
548
+ if isinstance(model, DDP):
549
+ model_state_dict = model.module.state_dict()
550
+ else:
551
+ model_state_dict = model.state_dict()
552
+
553
+ shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()}
554
+
555
+ # Clean up distributed environment
556
+ dist.destroy_process_group()
557
+
558
+
559
+ def save_model_without_heads(model_directory):
560
+ """
561
+ Save a version of the fine-tuned model without classification heads.
562
+
563
+ Args:
564
+ model_directory (str): Path to the directory containing the fine-tuned model
565
+ """
566
+ import torch
567
+ from transformers import BertConfig, BertModel
568
+
569
+ # Load the full model
570
+ model_path = os.path.join(model_directory, "pytorch_model.bin")
571
+ config_path = os.path.join(model_directory, "config.json")
572
+
573
+ if not os.path.exists(model_path) or not os.path.exists(config_path):
574
+ raise FileNotFoundError(f"Model files not found in {model_directory}")
575
+
576
+ # Load the configuration
577
+ config = BertConfig.from_json_file(config_path)
578
+
579
+ # Load the model state dict
580
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
581
+
582
+ # Create a new model without heads
583
+ base_model = BertModel(config)
584
+
585
+ # Filter out the classification head parameters
586
+ base_model_state_dict = {}
587
+ for key, value in state_dict.items():
588
+ # Only keep parameters that belong to the base model (not classification heads)
589
+ if not key.startswith('classification_heads') and not key.startswith('attention_pool'):
590
+ base_model_state_dict[key] = value
591
+
592
+ # Load the filtered state dict into the base model
593
+ base_model.load_state_dict(base_model_state_dict, strict=False)
594
+
595
+ # Save the model without heads
596
+ output_dir = os.path.join(model_directory, "model_without_heads")
597
+ os.makedirs(output_dir, exist_ok=True)
598
+
599
+ # Save the model weights
600
+ torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
601
+
602
+ # Save the configuration
603
+ base_model.config.to_json_file(os.path.join(output_dir, "config.json"))
604
+
605
+ print(f"Model without classification heads saved to {output_dir}")
606
+ return output_dir
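
The stripped-down checkpoint written here contains only config.json and pytorch_model.bin, so it can be reloaded as a plain BertModel; a brief usage sketch with a placeholder path:

from transformers import BertModel

base_dir = save_model_without_heads("/path/to/results/models/GeneformerMultiTask")  # placeholder path
encoder = BertModel.from_pretrained(base_dir)  # loads config.json + pytorch_model.bin from base_dir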
607
+
608
+
609
+ def get_config_value(config: Dict, key: str, default=None):
610
+ """Return config[key] if present, otherwise the provided default."""
611
+ return config.get(key, default)
612
+
613
+
614
+ def collect_validation_predictions(model, val_loader, device, config) -> tuple:
+ """Run the model over val_loader and return per-task dictionaries of true labels, predicted labels, and prediction probabilities."""
615
+ task_true_labels = {}
616
+ task_pred_labels = {}
617
+ task_pred_probs = {}
618
+
619
+ with torch.no_grad():
620
+ for batch in val_loader:
621
+ input_ids = batch["input_ids"].to(device)
622
+ attention_mask = batch["attention_mask"].to(device)
623
+ labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]]
624
+ _, logits, _ = model(input_ids, attention_mask, labels)
625
+
626
+ for sample_idx in range(len(batch["input_ids"])):
627
+ for i, task_name in enumerate(config["task_names"]):
628
+ if task_name not in task_true_labels:
629
+ task_true_labels[task_name] = []
630
+ task_pred_labels[task_name] = []
631
+ task_pred_probs[task_name] = []
632
+
633
+ true_label = batch["labels"][task_name][sample_idx].item()
634
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
635
+ pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
636
+
637
+ task_true_labels[task_name].append(true_label)
638
+ task_pred_labels[task_name].append(pred_label)
639
+ task_pred_probs[task_name].append(pred_prob)
640
+
641
+ return task_true_labels, task_pred_labels, task_pred_probs
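
These helpers compose into a small evaluation pipeline; a hedged sketch assuming a fine-tuned model, a validation DataLoader, and a config carrying task_names, results_dir, and val_cell_mapping as used above:

model.eval()  # ensure dropout is disabled before collecting predictions
task_true, task_pred, task_probs = collect_validation_predictions(model, val_loader, device, config)

task_metrics = calculate_metrics(labels=task_true, preds=task_pred, metric_type="task_specific")
for task_name, m in task_metrics.items():
    print(f"{task_name}: F1={m['f1']:.4f}, accuracy={m['accuracy']:.4f}")

# val_cell_indices maps row index -> cell index; here rows are assumed to follow loader order.
val_cell_indices = {i: i for i in range(len(next(iter(task_true.values()))))}
save_validation_predictions(val_cell_indices, task_true, task_pred, task_probs, config)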
geneformer/mtl_classifier.py CHANGED
@@ -29,7 +29,8 @@ Geneformer multi-task cell classifier.
29
  import logging
30
  import os
31
 
32
- from .mtl import eval_utils, train_utils, utils
 
33
 
34
  logger = logging.getLogger(__name__)
35
 
@@ -49,7 +50,9 @@ class MTLClassifier:
49
  "max_layers_to_freeze": {None, dict},
50
  "epochs": {None, int},
51
  "tensorboard_log_dir": {None, str},
52
- "use_data_parallel": {None, bool},
 
 
53
  "use_attention_pooling": {None, bool},
54
  "use_task_weights": {None, bool},
55
  "hyperparameters": {None, dict},
@@ -61,6 +64,7 @@ class MTLClassifier:
61
  "max_grad_norm": {None, int, float},
62
  "seed": {None, int},
63
  "trials_result_path": {None, str},
 
64
  }
65
 
66
  def __init__(
@@ -79,7 +83,9 @@ class MTLClassifier:
79
  max_layers_to_freeze=None,
80
  epochs=1,
81
  tensorboard_log_dir="/results/tblogdir",
82
- use_data_parallel=False,
 
 
83
  use_attention_pooling=True,
84
  use_task_weights=True,
85
  hyperparameters=None, # Default is None
@@ -89,6 +95,7 @@ class MTLClassifier:
89
  wandb_project=None,
90
  gradient_clipping=False,
91
  max_grad_norm=None,
 
92
  seed=42, # Default seed value
93
  ):
94
  """
@@ -117,8 +124,12 @@ class MTLClassifier:
117
  | Path to directory to save results
118
  tensorboard_log_dir : None, str
119
  | Path to directory for Tensorboard logging results
120
- use_data_parallel : None, bool
121
- | Whether to use data parallelization
 
 
 
 
122
  use_attention_pooling : None, bool
123
  | Whether to use attention pooling
124
  use_task_weights : None, bool
@@ -150,6 +161,8 @@ class MTLClassifier:
150
  | Whether to use gradient clipping
151
  max_grad_norm : None, int, float
152
  | Maximum norm for gradient clipping
 
 
153
  seed : None, int
154
  | Random seed
155
  """
@@ -165,6 +178,7 @@ class MTLClassifier:
165
  self.batch_size = batch_size
166
  self.n_trials = n_trials
167
  self.study_name = study_name
 
168
 
169
  if max_layers_to_freeze is None:
170
  # Dynamically determine the range of layers to freeze
@@ -175,7 +189,9 @@ class MTLClassifier:
175
 
176
  self.epochs = epochs
177
  self.tensorboard_log_dir = tensorboard_log_dir
178
- self.use_data_parallel = use_data_parallel
 
 
179
  self.use_attention_pooling = use_attention_pooling
180
  self.use_task_weights = use_task_weights
181
  self.hyperparameters = (
@@ -293,7 +309,7 @@ class MTLClassifier:
293
  self.config["manual_hyperparameters"] = self.manual_hyperparameters
294
  self.config["use_manual_hyperparameters"] = True
295
 
296
- train_utils.run_manual_tuning(self.config)
297
 
298
  def validate_additional_options(self, req_var_dict):
299
  missing_variable = False
@@ -330,7 +346,7 @@ class MTLClassifier:
330
  req_var_dict = dict(zip(required_variable_names, required_variables))
331
  self.validate_additional_options(req_var_dict)
332
 
333
- train_utils.run_optuna_study(self.config)
334
 
335
  def load_and_evaluate_test_model(
336
  self,
 
29
  import logging
30
  import os
31
 
32
+ from .mtl import eval_utils, utils
33
+ from .mtl.train import run_manual_tuning, run_optuna_study
34
 
35
  logger = logging.getLogger(__name__)
36
 
 
50
  "max_layers_to_freeze": {None, dict},
51
  "epochs": {None, int},
52
  "tensorboard_log_dir": {None, str},
53
+ "distributed_training": {None, bool},
54
+ "master_addr": {None, str},
55
+ "master_port": {None, str},
56
  "use_attention_pooling": {None, bool},
57
  "use_task_weights": {None, bool},
58
  "hyperparameters": {None, dict},
 
64
  "max_grad_norm": {None, int, float},
65
  "seed": {None, int},
66
  "trials_result_path": {None, str},
67
+ "gradient_accumulation_steps": {None, int},
68
  }
69
 
70
  def __init__(
 
83
  max_layers_to_freeze=None,
84
  epochs=1,
85
  tensorboard_log_dir="/results/tblogdir",
86
+ distributed_training=False,
87
+ master_addr="localhost",
88
+ master_port="12355",
89
  use_attention_pooling=True,
90
  use_task_weights=True,
91
  hyperparameters=None, # Default is None
 
95
  wandb_project=None,
96
  gradient_clipping=False,
97
  max_grad_norm=None,
98
+ gradient_accumulation_steps=1,  # default: no gradient accumulation (update every batch)
99
  seed=42, # Default seed value
100
  ):
101
  """
 
124
  | Path to directory to save results
125
  tensorboard_log_dir : None, str
126
  | Path to directory for Tensorboard logging results
127
+ distributed_training : None, bool
128
+ | Whether to use distributed data parallel training across multiple GPUs
129
+ master_addr : None, str
130
+ | Master address for distributed training (default: localhost)
131
+ master_port : None, str
132
+ | Master port for distributed training (default: 12355)
133
  use_attention_pooling : None, bool
134
  | Whether to use attention pooling
135
  use_task_weights : None, bool
 
161
  | Whether to use gradient clipping
162
  max_grad_norm : None, int, float
163
  | Maximum norm for gradient clipping
164
+ gradient_accumulation_steps : None, int
165
+ | Number of steps to accumulate gradients before performing a backward/update pass
166
  seed : None, int
167
  | Random seed
168
  """
 
178
  self.batch_size = batch_size
179
  self.n_trials = n_trials
180
  self.study_name = study_name
181
+ self.gradient_accumulation_steps = gradient_accumulation_steps
182
 
183
  if max_layers_to_freeze is None:
184
  # Dynamically determine the range of layers to freeze
 
189
 
190
  self.epochs = epochs
191
  self.tensorboard_log_dir = tensorboard_log_dir
192
+ self.distributed_training = distributed_training
193
+ self.master_addr = master_addr
194
+ self.master_port = master_port
195
  self.use_attention_pooling = use_attention_pooling
196
  self.use_task_weights = use_task_weights
197
  self.hyperparameters = (
 
309
  self.config["manual_hyperparameters"] = self.manual_hyperparameters
310
  self.config["use_manual_hyperparameters"] = True
311
 
312
+ run_manual_tuning(self.config)
313
 
314
  def validate_additional_options(self, req_var_dict):
315
  missing_variable = False
 
346
  req_var_dict = dict(zip(required_variable_names, required_variables))
347
  self.validate_additional_options(req_var_dict)
348
 
349
+ run_optuna_study(self.config)
350
 
351
  def load_and_evaluate_test_model(
352
  self,