# -*- coding: utf-8 -*-
"""
MEye: Semantic Segmentation
"""
import argparse
import os
import tempfile

os.sys.path += ['expman', 'models/deeplab']
import expman
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_curve, auc, precision_recall_curve, average_precision_score
from adabelief_tf import AdaBeliefOptimizer
from glob import glob
from tqdm import tqdm
from PIL import Image

from deeplabv3p.models.deeplabv3p_mobilenetv3 import hard_swish
from dataloader import get_loader, load_datasets
from utils import visualize, visualizable
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph


def iou_coef(y_true, y_pred):
    """Return the batch-mean soft Intersection-over-Union.

    Sums are taken over axes 1, 2 and 3, so inputs are expected to be rank-4
    (presumably batch, height, width, channels). A 1e-6 term on numerator and
    denominator keeps empty masks from dividing by zero.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3]) - intersection
    return K.mean((intersection + 1e-6) / (union + 1e-6))


def dice_coef(y_true, y_pred):
    """Return the batch-mean soft Dice coefficient (same axes/smoothing as `iou_coef`)."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
    denominator = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    return K.mean((2. * intersection + 1e-6) / (denominator + 1e-6))


def boundary_loss(y_true, y_pred):
    """Return 0.5 * mean |grad(pred) - grad(true)| over vertical and horizontal image gradients.

    Inputs must be 4-D, as required by `tf.image.image_gradients`.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    dy_true, dx_true = tf.image.image_gradients(y_true)
    dy_pred, dx_pred = tf.image.image_gradients(y_pred)
    loss = tf.reduce_mean(tf.abs(dy_pred - dy_true) + tf.abs(dx_pred - dx_true))
    return loss * 0.5


def enhanced_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy plus `boundary_loss`.

    The scalar boundary term is broadcast onto the element-wise BCE values.
    Kept here so checkpoints trained with this loss can be deserialized.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    boundary = boundary_loss(y_true, y_pred)
    return bce + boundary


def _filter_by_closeness(a, eps=10e-3):
    """Thin a curve DataFrame by dropping rows too close to the last kept row.

    All columns except 'thr' are compared; a row is kept when any of those
    values differs from the previously kept row by more than `eps`.
    """
    values = a.drop('thr', axis=1).values
    # Seed with -1 in every column so the first row is always kept; sizing it
    # from the data makes this work for any number of columns (ROC has 2, PR has 3).
    prev = np.full(values.shape[1], -1.0)
    keep = []
    for row in values:
        if (np.abs(prev - row) > eps).any():
            keep.append(True)
            prev = row
        else:
            keep.append(False)
    # A boolean array (not a plain list) guarantees row selection even when empty.
    return a[np.asarray(keep, dtype=bool)]


def _weighted_roc_pr(y_true, y_scores, label, outdir, simplify=False):
    """Compute, save and plot class-balanced ROC and PR curves.

    Positives are weighted by N/P so both classes contribute equally.
    Files written to `outdir` (all prefixed with `label`): ROC/PR metrics CSVs,
    curve CSVs (plus gzipped full curves when `simplify` is set) and PDF plots.

    Args:
        y_true: ground-truth scores; binarized at 0.5.
        y_scores: predicted scores.
        label: prefix for file names and printed output.
        outdir: destination directory.
        simplify: if True, also save the full curves and thin the saved ones
            with `_filter_by_closeness`.
    """
    # Ground truth may be soft: convert to binary labels with a 0.5 threshold.
    y_true_binary = (y_true > 0.5).astype(int)
    npos = y_true_binary.sum()
    nneg = len(y_true_binary) - npos
    pos_weight = nneg / npos if npos > 0 else 1.0
    print(label, 'Tot:', len(y_true), 'P:', npos, 'N:', nneg, 'N/P:', pos_weight)
    sample_weight = np.where(y_true_binary, pos_weight, 1)

    # --- ROC ---
    fpr, tpr, thr = roc_curve(y_true_binary, y_scores, sample_weight=sample_weight)
    auc_score = auc(fpr, tpr)
    print(label, 'AuROC:', auc_score)

    roc_metrics = pd.Series({'npos': npos, 'nneg': nneg, 'nneg_over_npos': pos_weight, 'roc_auc': auc_score})
    roc_metrics_file = os.path.join(outdir, '{}_roc_metrics.csv'.format(label))
    # header=False keeps the metric names (the index); index=False would drop them.
    roc_metrics.to_csv(roc_metrics_file, header=False)

    roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thr': thr})
    if simplify:
        full_roc_file = os.path.join(outdir, '{}_roc_curve_full.csv.gz'.format(label))
        roc.to_csv(full_roc_file, index=False)
        roc = _filter_by_closeness(roc)

    roc_file = os.path.join(outdir, '{}_roc_curve.csv'.format(label))
    roc.to_csv(roc_file, index=False)

    roc.plot(x='fpr', y='tpr', xlim=(0, 1), ylim=(0, 1))
    roc_plot_file = os.path.join(outdir, '{}_roc.pdf'.format(label))
    plt.savefig(roc_plot_file)
    plt.close()

    # --- Precision/Recall ---
    precision, recall, thr = precision_recall_curve(y_true_binary, y_scores, sample_weight=sample_weight)
    denom = precision + recall
    # F1 is defined as 0 where precision + recall == 0 (avoids NaN and warnings).
    f1_score = np.divide(2 * precision * recall, denom,
                         out=np.zeros_like(denom, dtype=float), where=denom > 0)
    pr_auc = auc(recall, precision)

    pr_metrics = pd.Series({'npos': npos, 'nneg': nneg, 'nneg_over_npos': pos_weight, 'pr_auc': pr_auc})
    pr_metrics_file = os.path.join(outdir, '{}_pr_metrics.csv'.format(label))
    pr_metrics.to_csv(pr_metrics_file, header=False)

    # precision_recall_curve returns one threshold fewer than points; pad with the last one.
    thr = np.append(thr, [thr[-1]])
    pr = pd.DataFrame({'precision': precision, 'recall': recall, 'f1_score': f1_score, 'thr': thr})
    if simplify:
        full_pr_file = os.path.join(outdir, '{}_pr_curve_full.csv.gz'.format(label))
        pr.to_csv(full_pr_file, index=False)
        pr = _filter_by_closeness(pr)

    pr_file = os.path.join(outdir, '{}_pr_curve.csv'.format(label))
    pr.to_csv(pr_file, index=False)

    pr.plot(x='recall', y='precision', xlim=(0, 1), ylim=(0, 1))
    pr_plot_file = os.path.join(outdir, '{}_pr.pdf'.format(label))
    plt.savefig(pr_plot_file)
    plt.close()

    print(label, 'AuPR:', pr_auc, 'AvgP:',
          average_precision_score(y_true_binary, y_scores, sample_weight=sample_weight))


def get_flops(model):
    """Return the float-op count of one forward pass with batch size 1.

    The model is traced into a concrete function, frozen to constants, and
    profiled with the TF1 profiler inside a fresh graph.
    """
    concrete = tf.function(lambda inputs: model(inputs))
    concrete_func = concrete.get_concrete_function(
        [tf.TensorSpec([1, *tensor.shape[1:]]) for tensor in model.inputs])
    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)
    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name='')
        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd="op", options=opts)
    # Must be outside the `with` block: reset_default_graph raises
    # AssertionError when called inside a nested graph context.
    tf.compat.v1.reset_default_graph()
    return flops.total_float_ops


def _get_test_predictions(test_gen, model):
    """Run `model` over `test_gen`, collecting inputs, targets, predictions and losses.

    Batches that raise are reported and skipped. Every per-batch result is
    computed before anything is appended, so the returned arrays always stay
    aligned with each other.

    Returns:
        (loss_per_sample, x_masks, y_masks, y_tags, pred_masks, pred_tags)

    Raises:
        ValueError: if no batch could be processed.
    """
    x_masks = []
    y_masks, y_tags = [], []
    pred_masks, pred_tags = [], []
    loss_per_sample = []
    for x, y in tqdm(test_gen, desc='TEST'):
        try:
            p_mask, p_tags = model.predict_on_batch(x)
            sample_loss = model.compiled_loss(
                {'mask': y['mask'], 'tags': y['tags']},
                {'mask': p_mask, 'tags': p_tags}
            )
            batch_loss = sample_loss.numpy()
            batch_x = x.numpy()
            batch_y_mask = y['mask'].numpy()
            batch_y_tags = y['tags'].numpy()
        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            continue

        # Append only once the whole batch succeeded, keeping all lists aligned.
        pred_masks.append(p_mask)
        pred_tags.append(p_tags)
        y_masks.append(batch_y_mask)
        y_tags.append(batch_y_tags)
        x_masks.append(batch_x)
        loss_per_sample.append(batch_loss)

    if not pred_masks:
        raise ValueError("No predictions were collected - all batches failed")

    try:
        loss_per_sample = np.array(loss_per_sample)
        pred_masks = np.concatenate(pred_masks)
        pred_tags = np.concatenate(pred_tags)
        y_masks = np.concatenate(y_masks)
        y_tags = np.concatenate(y_tags)
        x_masks = np.concatenate(x_masks)
    except Exception as e:
        print(f"Error concatenating results: {str(e)}")
        raise

    return loss_per_sample, x_masks, y_masks, y_tags, pred_masks, pred_tags


def evaluate(exp, force=False):
    """Evaluate the best checkpoint of one experiment on its test split.

    Outputs are written under the experiment directory. Each is skipped if it
    already exists, unless `force` is set:
      * ``flops_nparams.csv``: FLOPs of one forward pass and parameter count.
      * ``test_pred/mask_metrics.csv``: IoU/Dice of the mask binarized at 101
        thresholds in [0, 1].
      * per-category ROC/PR curves and top/bottom/random sample images under
        ``test_pred``.
    Whole-test-set ROC/PR curves are produced only when `force` is set.

    Args:
        exp: expman experiment (uses ``path_to`` and ``params``).
        force: recompute every output even if it already exists.

    Returns early, after printing a message, if the checkpoint is missing or
    cannot be loaded.

    Raises:
        ValueError: on an unknown ``params.split`` value, or if the collected
            predictions do not line up with the test categories.
    """
    ckpt_path = exp.path_to('best_model.keras')
    if not os.path.exists(ckpt_path):
        print(f"Model file not found: {ckpt_path}")
        return

    custom_objects = {
        'AdaBeliefOptimizer': AdaBeliefOptimizer,
        'iou_coef': iou_coef,
        'dice_coef': dice_coef,
        'hard_swish': hard_swish,
        'enhanced_binary_crossentropy': enhanced_binary_crossentropy,
        'boundary_loss': boundary_loss
    }

    try:
        model = tf.keras.models.load_model(ckpt_path, custom_objects=custom_objects)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return

    # --- FLOPs / #params ---
    flop_params_path = exp.path_to('flops_nparams.csv')
    if force or not os.path.exists(flop_params_path):
        model.compile()
        # Round-trip through disk without the optimizer to get a stripped model.
        # A temporary directory avoids clobbering/leaking a file in the CWD, and
        # custom_objects are needed again because the model uses custom layers.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_model_path = os.path.join(tmp_dir, 'tmp_model.keras')
            tf.keras.models.save_model(model, tmp_model_path, overwrite=True, include_optimizer=False)
            stripped_model = tf.keras.models.load_model(tmp_model_path, custom_objects=custom_objects)
        flops = get_flops(stripped_model)
        nparams = stripped_model.count_params()
        del stripped_model

        print('FLOPS:', flops)
        print('#PARAMS:', nparams)
        pd.DataFrame({'flops': flops, 'nparams': nparams}, index=[0]).to_csv(flop_params_path)

    model.compile(
        loss='binary_crossentropy',
        metrics={'mask': [iou_coef, dice_coef], 'tags': 'binary_accuracy'}
    )

    params = exp.params
    np.random.seed(params.seed)
    tf.random.set_seed(params.seed)

    data = load_datasets(params.data)

    # TRAIN/VAL/TEST SPLIT
    if params.split == 'subjects':  # by SUBJECTS
        test_subjects = (3, 4, 19, 38, 45, 46, 51, 52)
        test_data = data[data['sub'].isin(test_subjects)]
    elif params.split == 'random':  # 70-20-10 %
        _, valtest_data = train_test_split(data, test_size=.3, shuffle=True)
        _, test_data = train_test_split(valtest_data, test_size=.33)
    else:
        raise ValueError(f"Unknown split: {params.split!r}")

    x_shape = (params.resolution, params.resolution, 1)
    test_gen, test_categories = get_loader(test_data, batch_size=1, x_shape=x_shape)

    prediction_dir = exp.path_to('test_pred')
    os.makedirs(prediction_dir, exist_ok=True)

    # Predictions are expensive: compute them lazily, at most once.
    cached_predictions = []

    def predictions():
        if not cached_predictions:
            cached_predictions.append(_get_test_predictions(test_gen, model))
        return cached_predictions[0]

    # --- Mask metrics vs. threshold ---
    mask_metrics_path = exp.path_to('test_pred/mask_metrics.csv')
    if force or not os.path.exists(mask_metrics_path):
        _, _, y_masks, _, pred_masks, _ = predictions()
        y_masks_t = tf.cast(y_masks, tf.float32)  # convert once, reused for every threshold
        thrs = np.linspace(0, 1, 101)
        ious, dices = [], []
        for thr in thrs:
            # Binarize the predicted mask at this threshold so each row actually
            # measures a different operating point.
            binary_pred = pred_masks > thr
            ious.append(iou_coef(y_masks_t, binary_pred).numpy())
            dices.append(dice_coef(y_masks_t, binary_pred).numpy())
        best_thr = max(zip(dices, thrs))[1]
        mask_metrics = pd.DataFrame({'iou': ious, 'dice': dices, 'thr': thrs})
        print(mask_metrics.max(axis=0))
        mask_metrics.to_csv(mask_metrics_path)
    else:
        mask_metrics = pd.read_csv(mask_metrics_path, index_col=0)
        best_thr = mask_metrics.loc[mask_metrics.dice.idxmax(), 'thr']

    # --- Whole-test-set ROC/PR (only on forced runs) ---
    if force:
        _, _, _, y_tags, _, pred_tags = predictions()
        _weighted_roc_pr(y_tags[:, 0], pred_tags[:, 0], 'all_eye', prediction_dir)
        _weighted_roc_pr(y_tags[:, 1], pred_tags[:, 1], 'all_blink', prediction_dir)

    # --- Per-category ROC/PR and sample images ---
    filenames = ('top_samples.png', 'bottom_samples.png', 'random_samples.png')
    if force or any(not os.path.exists(os.path.join(prediction_dir, f)) for f in filenames):
        loss_per_sample, x_masks, y_masks, y_tags, pred_masks, pred_tags = predictions()

        categories = np.asarray(test_categories)
        if len(categories) != len(loss_per_sample):
            raise ValueError(
                f"Collected {len(loss_per_sample)} predictions but there are "
                f"{len(categories)} test categories; some batches failed")

        k = 5
        best_selector = []
        worst_selector = []
        random_selector = []

        idx = np.arange(len(categories))
        for cat in np.unique(categories):
            cat_outdir = os.path.join(prediction_dir, str(cat))
            os.makedirs(cat_outdir, exist_ok=True)

            selector = categories == cat
            _weighted_roc_pr(y_tags[selector, 0], pred_tags[selector, 0], '{}_eye'.format(cat), cat_outdir)
            _weighted_roc_pr(y_tags[selector, 1], pred_tags[selector, 1], '{}_blink'.format(cat), cat_outdir)

            cat_idx = idx[selector]
            # loss_per_sample is 1-D (one scalar loss per batch of size 1).
            rank = loss_per_sample[selector].argsort()
            best_selector += cat_idx[rank[:k]].tolist()
            worst_selector += cat_idx[rank[-k:]].tolist()
            # Categories with fewer than k samples would make choice() raise.
            n_random = min(k, len(cat_idx))
            random_selector += np.random.choice(cat_idx, n_random, replace=False).tolist()

        # top-k / bottom-k / random images
        selectors = (best_selector, worst_selector, random_selector)
        for selector, outfile in zip(selectors, filenames):
            combined_m = np.concatenate((pred_masks[selector], y_masks[selector]), axis=-1)[:, :, :, ::-1]
            combined_t = np.concatenate((pred_tags[selector], y_tags[selector]), axis=-1)
            combined_y = (combined_m, combined_t)

            out = os.path.join(prediction_dir, outfile)
            visualize(x_masks[selector], combined_y, out=out, thr=best_thr, n_cols=k, width=10)

            # Individual images go into a directory named after the grid image.
            sample_dir = os.path.join(prediction_dir, outfile[:-4])
            os.makedirs(sample_dir, exist_ok=True)
            for i, (xi, yi_mask) in enumerate(zip(x_masks[selector], combined_m)):
                img = visualizable(xi, yi_mask, thr=best_thr)
                img = (img * 255).astype(np.uint8)
                Image.fromarray(img).save(os.path.join(sample_dir, f'{i}.png'))


def main(args):
    """Evaluate every experiment under `args.run` that matches `args.filter`."""
    try:
        for exp in expman.gather(args.run).filter(args.filter):
            print(exp)
            evaluate(exp, force=args.force)
    except Exception as e:
        print(f"Error in main: {str(e)}")
        raise


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate Run')
    parser.add_argument('run', help='Run(s) directory')
    parser.add_argument('-f', '--filter', default={}, type=expman.exp_filter)
    parser.add_argument('--force', default=False, action='store_true', help='Force metrics recomputation')
    args = parser.parse_args()
    main(args)