# train.py
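"""Training script for the MEye pupil-segmentation model.

Builds the two-headed network from models/unet.py (a pupil `mask` output plus
auxiliary `tags` output) and trains it with a BCE + boundary loss.
Example invocation (all values shown are the argparse defaults below):

    python train.py -d data/NN_human_mouse_eyes --split random -r 128 -b 32 -e 1500
"""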
import argparse
import math
import os
import sys
from functools import partial

sys.path.append('expman')  # make the local expman package importable before `from expman import ...`

import matplotlib
matplotlib.use('Agg')  # non-interactive backend for headless training machines

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflowjs as tfjs
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger
from sklearn.model_selection import train_test_split
from tqdm.keras import TqdmCallback

from dataloader import get_loader, load_datasets, validate_data_files
from models.unet import build_model
from expman import Experiment
import evaluate

def boundary_loss(y_true, y_pred):
    """Additional loss focusing on mask boundaries."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # Compute first-order spatial gradients of both masks
    dy_true, dx_true = tf.image.image_gradients(y_true)
    dy_pred, dx_pred = tf.image.image_gradients(y_pred)
    # Compute boundary loss
    loss = tf.reduce_mean(tf.abs(dy_pred - dy_true) + tf.abs(dx_pred - dx_true))
    return loss * 0.5
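
# boundary_loss equals 0.5 * mean(|dy_pred - dy_true| + |dx_pred - dx_true|):
# an L1 penalty on the difference of first-order finite differences.
# tf.image.image_gradients zero-pads the last row of dy and the last column
# of dx, so those border pixels never contribute to the penalty.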

def enhanced_binary_crossentropy(y_true, y_pred):
    """Combine standard BCE with the boundary loss."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    boundary = boundary_loss(y_true, y_pred)
    return bce + boundary
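
# Note: tf.keras.losses.binary_crossentropy reduces only over the last
# (channel) axis, so `bce` is a per-pixel map of shape (batch, H, W), while
# `boundary` is a scalar; the sum broadcasts the boundary term over all
# pixels, and Keras then averages everything into the final scalar loss.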

def cosine_decay_with_warmup(epoch, total_epochs, warmup_epochs=5, initial_lr=0.001):
    if epoch < warmup_epochs:
        # Linear warmup
        return initial_lr * (epoch + 1) / warmup_epochs
    # Cosine decay after warmup
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return initial_lr * (1 + math.cos(math.pi * progress)) / 2
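
# With the defaults (initial_lr=0.001, warmup_epochs=5) this yields, e.g.:
#   epoch 0 -> 0.0002, epoch 4 -> 0.001 (warmup peak),
#   halfway through the post-warmup epochs -> 0.0005,
#   final epochs -> approaching 0.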

def main(args):
    try:
        # Verify that all data directories exist
        for data_dir in args.data:
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f"Data directory not found: {data_dir}")

        exp = Experiment(args, ignore=('epochs', 'resume'))
        print(exp)

        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

        data = load_datasets(args.data)
        if len(data) == 0:
            raise ValueError("No valid data found after loading datasets")

        # Validate that all referenced files exist
        validate_data_files(data)

        # TRAIN/VAL/TEST SPLIT
        if args.split == 'subjects':  # split by subject IDs
            val_subjects = (6, 9, 11, 13, 16, 28, 30, 48, 49)
            test_subjects = (3, 4, 19, 38, 45, 46, 51, 52)
            train_data = data[~data['sub'].isin(val_subjects + test_subjects)]
            val_data = data[data['sub'].isin(val_subjects)]
            test_data = data[data['sub'].isin(test_subjects)]
        elif args.split == 'random':  # ~70/20/10 % of samples
            train_data, valtest_data = train_test_split(data, test_size=.3, shuffle=True)
            val_data, test_data = train_test_split(valtest_data, test_size=.33)

        lengths = map(len, (data, train_data, val_data, test_data))
        print("Total: {} - Train / Val / Test: {} / {} / {}".format(*lengths))

        x_shape = (args.resolution, args.resolution, 1)
        y_shape = (args.resolution, args.resolution, 1)

        train_gen, _ = get_loader(train_data, batch_size=args.batch_size, shuffle=True, augment=True, x_shape=x_shape)
        val_gen, val_categories = get_loader(val_data, batch_size=args.batch_size, x_shape=x_shape)

        log = exp.path_to('log.csv')

        # weights-only checkpoints
        best_weights_path = exp.path_to('best_weights.weights.h5')
        best_mask_weights_path = exp.path_to('best_weights_mask.weights.h5')

        # whole-model checkpoints
        best_ckpt_path = exp.path_to('best_model.keras')
        last_ckpt_path = exp.path_to('last_model.keras')

        if args.resume and os.path.exists(last_ckpt_path):
            # Reload the full model (including optimizer state) together with
            # the custom losses/metrics it was compiled with
            custom_objects = {
                'iou_coef': evaluate.iou_coef,
                'dice_coef': evaluate.dice_coef,
                'enhanced_binary_crossentropy': enhanced_binary_crossentropy,
                'boundary_loss': boundary_loss
            }
            model = tf.keras.models.load_model(last_ckpt_path, custom_objects=custom_objects)
            optimizer = model.optimizer
            # Resume from the epoch after the last one recorded in log.csv
            initial_epoch = len(pd.read_csv(log)) if os.path.exists(log) else 0
        else:
            config = vars(args)
            model = build_model(x_shape, y_shape, config)
            # Use the Adam optimizer
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=float(args.lr),
                beta_1=0.9,
                beta_2=0.999,
                epsilon=1e-7
            )
            initial_epoch = 0

        model.compile(
            optimizer=optimizer,
            loss={
                'mask': enhanced_binary_crossentropy,
                'tags': 'binary_crossentropy'
            },
            metrics={
                'mask': [evaluate.iou_coef, evaluate.dice_coef],
                'tags': 'binary_accuracy'
            }
        )
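        # With a dict of per-output losses and no explicit loss_weights,
        # Keras minimizes the unweighted sum of the two head losses.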

        model_stopped_file = exp.path_to('early_stopped.txt')
        need_training = not os.path.exists(model_stopped_file) and initial_epoch < args.epochs

        if need_training:
            lr_schedule = partial(cosine_decay_with_warmup,
                                  total_epochs=args.epochs,
                                  warmup_epochs=5,
                                  initial_lr=args.lr)

            # best overall model (lowest total validation loss), weights only
            best_checkpointer = ModelCheckpoint(
                best_weights_path,
                monitor='val_loss',
                save_best_only=True,
                save_weights_only=True,
                mode='min'
            )
            # best segmentation head (highest validation Dice), weights only
            best_mask_checkpointer = ModelCheckpoint(
                best_mask_weights_path,
                monitor='val_mask_dice_coef',
                mode='max',
                save_best_only=True,
                save_weights_only=True
            )
            # full snapshot of the latest epoch, used by --resume
            last_checkpointer = ModelCheckpoint(
                last_ckpt_path,
                save_best_only=False,
                save_weights_only=False
            )
            logger = CSVLogger(log, append=args.resume)
            progress = TqdmCallback(verbose=1, initial=initial_epoch, dynamic_ncols=True)
            early_stop = tf.keras.callbacks.EarlyStopping(
                monitor='val_mask_dice_coef',
                mode='max',
                patience=100,
                restore_best_weights=True
            )
            lr_scheduler = LearningRateScheduler(lr_schedule)

            callbacks = [
                best_checkpointer,
                best_mask_checkpointer,
                last_checkpointer,
                logger,
                progress,
                early_stop,
                lr_scheduler
            ]

            try:
                model.fit(
                    train_gen,
                    epochs=args.epochs,
                    callbacks=callbacks,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=len(train_gen),
                    validation_data=val_gen,
                    validation_steps=len(val_gen),
                    verbose=False  # progress is shown by TqdmCallback instead
                )
            except Exception as e:
                print(f"Training failed: {str(e)}")
                raise

            # Record that early stopping fired, so future --resume runs skip training
            if model.stop_training:
                open(model_stopped_file, 'w').close()

            # Save the final model in .keras format (without optimizer state)
            tf.keras.models.save_model(model, best_ckpt_path, include_optimizer=False)

            # Only evaluate if training was successful
            evaluate.evaluate(exp, force=need_training)

            # Save the best-mask snapshot in SavedModel format
            model.load_weights(best_mask_weights_path)
            best_savedmodel_path = exp.path_to('best_savedmodel')
            model.save(best_savedmodel_path, save_traces=True)

            # Export to TensorFlow.js (Layers model)
            tfjs_model_dir = exp.path_to('tfjs')
            tfjs.converters.save_keras_model(model, tfjs_model_dir)
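            # The exported Layers model can be loaded in the browser with
            # tf.loadLayersModel('.../tfjs/model.json').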
        else:
            print("No training needed: the model already exists and training has completed.")
            # Optionally evaluate the existing model
            evaluate.evaluate(exp, force=False)
    except Exception as e:
        print(f"Error in main: {str(e)}")
        raise

if __name__ == '__main__':
    default_data = ['data/NN_human_mouse_eyes']

    parser = argparse.ArgumentParser(description='MEye Training Script')
    # data params
    parser.add_argument('-d', '--data', nargs='+', default=default_data, help='Data directory (may be multiple)')
    parser.add_argument('--split', default='random', choices=('random', 'subjects'), help='How to split data')
    parser.add_argument('-r', '--resolution', type=int, default=128, help='Input image resolution')
    # model params
    parser.add_argument('--num-stages', type=int, default=5, help='number of down-/up-sampling stages')
    parser.add_argument('--num-conv', type=int, default=1, help='number of convolutions per stage')
    parser.add_argument('--num-filters', type=int, default=16, help='number of conv filters at the first stage')
    parser.add_argument('--grow-factor', type=float, default=1.5,
                        help='# filters at stage i = num-filters * grow-factor ** i')
    parser.add_argument('--up-activation', default='relu', choices=('relu', 'lrelu'),
                        help='activation in upsample stages')
    parser.add_argument('--conv-type', default='conv', choices=('conv', 'bn-conv', 'sep-conv', 'sep-bn-conv'),
                        help='convolution type')
    parser.add_argument('--use-aspp', default=False, action='store_true', help='Use Atrous Spatial Pyramid Pooling')
    # train params
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('-b', '--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('-e', '--epochs', type=int, default=1500, help='Number of training epochs')
    parser.add_argument('-s', '--seed', type=int, default=23, help='Random seed')
    parser.add_argument('--resume', default=False, action='store_true', help='Resume training')

    args = parser.parse_args()
    main(args)