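"""Training script for the MEye segmentation model.

Loads the eye datasets, splits them into train/val/test (randomly or by held-out
subjects), trains a U-Net that predicts a segmentation 'mask' plus auxiliary binary
'tags', logs metrics to CSV, checkpoints the best weights, and exports the best model
as a Keras model, a TF SavedModel, and a TensorFlow.js model.
"""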
import argparse
import math
import os

os.sys.path += ['expman']  # add the bundled expman directory to the import path

import matplotlib

matplotlib.use('Agg')  # headless backend: render figures without a display

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflowjs as tfjs

from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_curve, auc, precision_recall_curve, average_precision_score
from adabelief_tf import AdaBeliefOptimizer
from tqdm.keras import TqdmCallback
from tqdm import tqdm
from functools import partial

from dataloader import get_loader, load_datasets, validate_data_files
from models.unet import build_model
from utils import visualize
from expman import Experiment
import evaluate

def boundary_loss(y_true, y_pred):
    """Additional loss focusing on boundaries."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # image gradients highlight mask edges; penalize their L1 difference
    dy_true, dx_true = tf.image.image_gradients(y_true)
    dy_pred, dx_pred = tf.image.image_gradients(y_pred)

    loss = tf.reduce_mean(tf.abs(dy_pred - dy_true) + tf.abs(dx_pred - dx_true))
    return loss * 0.5


def enhanced_binary_crossentropy(y_true, y_pred):
    """Combine standard BCE with the boundary loss."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # per-pixel BCE (shape: batch x H x W) plus a scalar boundary term broadcast over the map
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    boundary = boundary_loss(y_true, y_pred)
    return bce + boundary


def cosine_decay_with_warmup(epoch, total_epochs, warmup_epochs=5, initial_lr=0.001):
    """Linear warm-up for `warmup_epochs` epochs, then cosine decay of the learning rate towards zero."""
    if epoch < warmup_epochs:
        return initial_lr * (epoch + 1) / warmup_epochs

    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return initial_lr * (1 + math.cos(math.pi * progress)) / 2
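
# Example schedule values (using the script defaults: initial_lr=0.001, warmup_epochs=5,
# total_epochs=1500): epoch 0 -> 2e-4, epoch 4 -> 1e-3 (end of warm-up), epoch 5 -> 1e-3,
# then cosine decay approaching 0 near the final epoch.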

def main(args):
    try:
        for data_dir in args.data:
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f"Data directory not found: {data_dir}")

        exp = Experiment(args, ignore=('epochs', 'resume'))
        print(exp)

        # reproducibility
        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

        # load and sanity-check the datasets
        data = load_datasets(args.data)
        if len(data) == 0:
            raise ValueError("No valid data found after loading datasets")

        validate_data_files(data)

        # train/val/test split: either by held-out subjects or at random
        if args.split == 'subjects':
            val_subjects = (6, 9, 11, 13, 16, 28, 30, 48, 49)
            test_subjects = (3, 4, 19, 38, 45, 46, 51, 52)
            train_data = data[~data['sub'].isin(val_subjects + test_subjects)]
            val_data = data[data['sub'].isin(val_subjects)]
            test_data = data[data['sub'].isin(test_subjects)]

        elif args.split == 'random':
            train_data, valtest_data = train_test_split(data, test_size=.3, shuffle=True)
            val_data, test_data = train_test_split(valtest_data, test_size=.33)

        lengths = map(len, (data, train_data, val_data, test_data))
        print("Total: {} - Train / Val / Test: {} / {} / {}".format(*lengths))

        x_shape = (args.resolution, args.resolution, 1)
        y_shape = (args.resolution, args.resolution, 1)

        train_gen, _ = get_loader(train_data, batch_size=args.batch_size, shuffle=True, augment=True, x_shape=x_shape)
        val_gen, val_categories = get_loader(val_data, batch_size=args.batch_size, x_shape=x_shape)

        log = exp.path_to('log.csv')

        # weights-only checkpoints (best overall loss and best mask dice)
        best_weights_path = exp.path_to('best_weights.weights.h5')
        best_mask_weights_path = exp.path_to('best_weights_mask.weights.h5')

        # full-model checkpoints
        best_ckpt_path = exp.path_to('best_model.keras')
        last_ckpt_path = exp.path_to('last_model.keras')

        if args.resume and os.path.exists(last_ckpt_path):
            # resume from the last full checkpoint, restoring the custom losses/metrics
            custom_objects = {
                'iou_coef': evaluate.iou_coef,
                'dice_coef': evaluate.dice_coef,
                'enhanced_binary_crossentropy': enhanced_binary_crossentropy,
                'boundary_loss': boundary_loss
            }
            model = tf.keras.models.load_model(last_ckpt_path, custom_objects=custom_objects)
            optimizer = model.optimizer
            initial_epoch = len(pd.read_csv(log)) if os.path.exists(log) else 0
        else:
            config = vars(args)
            model = build_model(x_shape, y_shape, config)

            optimizer = tf.keras.optimizers.Adam(
                learning_rate=float(args.lr),
                beta_1=0.9,
                beta_2=0.999,
                epsilon=1e-7
            )
            initial_epoch = 0

        # the model is expected to expose two named outputs: 'mask' (segmentation) and 'tags' (binary attributes)
        model.compile(
            optimizer=optimizer,
            loss={
                'mask': enhanced_binary_crossentropy,
                'tags': 'binary_crossentropy'
            },
            metrics={
                'mask': [evaluate.iou_coef, evaluate.dice_coef],
                'tags': 'binary_accuracy'
            }
        )

        # skip training if it was already early-stopped or the epoch budget is exhausted
        model_stopped_file = exp.path_to('early_stopped.txt')
        need_training = not os.path.exists(model_stopped_file) and initial_epoch < args.epochs
        if need_training:
            lr_schedule = partial(cosine_decay_with_warmup,
                                  total_epochs=args.epochs,
                                  warmup_epochs=5,
                                  initial_lr=args.lr)

            # track the best model by overall validation loss
            best_checkpointer = ModelCheckpoint(
                best_weights_path,
                monitor='val_loss',
                save_best_only=True,
                save_weights_only=True,
                mode='min'
            )

            # track the best model by validation dice of the mask output
            best_mask_checkpointer = ModelCheckpoint(
                best_mask_weights_path,
                monitor='val_mask_dice_coef',
                mode='max',
                save_best_only=True,
                save_weights_only=True
            )

            # always keep the last full model to allow resuming
            last_checkpointer = ModelCheckpoint(
                last_ckpt_path,
                save_best_only=False,
                save_weights_only=False
            )

            logger = CSVLogger(log, append=args.resume)
            progress = TqdmCallback(verbose=1, initial=initial_epoch, dynamic_ncols=True)

            early_stop = tf.keras.callbacks.EarlyStopping(
                monitor='val_mask_dice_coef',
                mode='max',
                patience=100,
                restore_best_weights=True
            )

            lr_scheduler = LearningRateScheduler(lr_schedule)

            callbacks = [
                best_checkpointer,
                best_mask_checkpointer,
                last_checkpointer,
                logger,
                progress,
                early_stop,
                lr_scheduler
            ]
            try:
                model.fit(
                    train_gen,
                    epochs=args.epochs,
                    callbacks=callbacks,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=len(train_gen),
                    validation_data=val_gen,
                    validation_steps=len(val_gen),
                    verbose=False
                )
            except Exception as e:
                print(f"Training failed: {str(e)}")
                raise

            # mark that training ended via early stopping so future runs can skip it
            if model.stop_training:
                open(model_stopped_file, 'w').close()

            # save the final model without optimizer state
            tf.keras.models.save_model(model, best_ckpt_path, include_optimizer=False)

            evaluate.evaluate(exp, force=need_training)

            # export the checkpoint with the best mask dice as SavedModel and TensorFlow.js model
            model.load_weights(best_mask_weights_path)
            best_savedmodel_path = exp.path_to('best_savedmodel')
            model.save(best_savedmodel_path, save_traces=True)

            tfjs_model_dir = exp.path_to('tfjs')
            tfjs.converters.save_keras_model(model, tfjs_model_dir)
        else:
            print("No training needed: the model already exists and training has completed.")
            evaluate.evaluate(exp, force=False)

    except Exception as e:
        print(f"Error in main: {str(e)}")
        raise

if __name__ == '__main__':
    default_data = ['data/NN_human_mouse_eyes']

    parser = argparse.ArgumentParser(description='MEye Training Script')

    # data
    parser.add_argument('-d', '--data', nargs='+', default=default_data, help='Data directories (one or more)')
    parser.add_argument('--split', default='random', choices=('random', 'subjects'), help='How to split data')
    parser.add_argument('-r', '--resolution', type=int, default=128, help='Input image resolution')

    # model architecture
    parser.add_argument('--num-stages', type=int, default=5, help='number of down-up sample stages')
    parser.add_argument('--num-conv', type=int, default=1, help='number of convolutions per stage')
    parser.add_argument('--num-filters', type=int, default=16, help='number of conv filters at the first stage')
    parser.add_argument('--grow-factor', type=float, default=1.5,
                        help='# filters at stage i = num-filters * grow-factor ** i')
    parser.add_argument('--up-activation', default='relu', choices=('relu', 'lrelu'),
                        help='activation in upsample stages')
    parser.add_argument('--conv-type', default='conv', choices=('conv', 'bn-conv', 'sep-conv', 'sep-bn-conv'),
                        help='convolution type')
    parser.add_argument('--use-aspp', default=False, action='store_true', help='Use Atrous Spatial Pyramid Pooling')

    # training
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('-b', '--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('-e', '--epochs', type=int, default=1500, help='Number of training epochs')
    parser.add_argument('-s', '--seed', type=int, default=23, help='Random seed')
    parser.add_argument('--resume', default=False, action='store_true', help='Resume training')

    args = parser.parse_args()
    main(args)
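
# Example invocation (assuming this file is saved as train.py):
#   python train.py -d data/NN_human_mouse_eyes --split subjects --resume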