import argparse
import math
import os
import sys

# Make the local `expman` package importable. The path is relative to the
# current working directory, so presumably the script is run from the
# directory that contains `expman/` — TODO confirm.
sys.path += ['expman']

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only saved to files
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.image import imread
from matplotlib.backends.backend_pdf import PdfPages

import numpy as np
import pandas as pd
import seaborn as sns

from glob import glob

import expman


def ee(args):
    """Plot mean Dice coefficient against FLOPs for the selected experiments.

    Takes the per-experiment maximum of the 'dice' and 'iou' columns of
    ``test_pred/mask_metrics.csv`` and merges them with ``flops_nparams.csv``
    on ``exp_id``. It then draws one line-plot facet per number of stages,
    with hue = conv type and style = number of filters, and saves the figure
    to ``args.output``.
    """
    sns.set_theme(context='notebook', style='whitegrid')

    exps = expman.gather(args.run).filter(args.filter)
    # Best value reached by each experiment (max over the rows of its CSV).
    mask_metrics = exps.collect('test_pred/mask_metrics.csv').groupby('exp_id')[['dice', 'iou']].max()
    flops_nparams = exps.collect('flops_nparams.csv')
    data = pd.merge(mask_metrics, flops_nparams, on='exp_id')
    data['dice'] *= 100  # scale to match the '(%)' axis label

    # Human-readable column names become the axis/legend titles; the conv
    # type values are also renamed for display.
    named_data = data.rename({
        'nparams': '# Params',
        'dice': 'mean Dice Coeff. (%)',
        'conv_type': '$t$ (Conv. Type)',
        'grow_factor': r'$\gamma$',
        'num_filters': '$k$ (# Filters)',
        'flops': 'FLOPs',
        'num_stages': '$s$ (# Stages)',
    }, axis=1).replace({
        'bn-conv': 'conv-bn',
        'sep-bn-conv': 'sep-conv-bn',
    })

    g = sns.relplot(
        data=named_data,
        x='FLOPs',
        y='mean Dice Coeff. (%)',
        hue='$t$ (Conv. Type)',
        hue_order=['conv', 'conv-bn', 'sep-conv', 'sep-conv-bn'],
        col='$s$ (# Stages)',
        style='$k$ (# Filters)',
        markers=True,
        markersize=9,
        kind='line',
        dashes=True,
        facet_kws=dict(despine=False, legend_out=False),
        legend=True,
        height=3.8,
        aspect=1.3,
        markeredgecolor='white',
    )

    # Tick formatter that shows FLOP counts in billions, e.g. "0.50B".
    b_formatter = ticker.FuncFormatter(lambda x, pos: '{:.2f}'.format(x / 10 ** 9) + 'B')

    # Replace the legend drawn on the first facet with a single figure-level
    # legend that reuses its handles and labels.
    first_ax = g.axes.flatten()[0]
    handles, labels = first_ax.get_legend_handles_labels()
    for handle in handles:
        handle.set_markeredgecolor('white')
    first_ax.legend_.remove()
    g.fig.legend(handles, labels, ncol=2, bbox_to_anchor=(0.53, 0.53), fancybox=False,
                 columnspacing=0, framealpha=1, handlelength=1.2)

    for ax in g.axes.flatten():
        ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
        ax.set_ylim(bottom=40, top=90)
        ax.set_xscale('symlog')
        ax.set_xlim(left=0.04 * 10 ** 9, right=2 * 10 ** 9)
        # Add labelled minor ticks, because the x range covers less than two
        # decades and the major ticks alone would be too sparse.
        ax.xaxis.set_minor_locator(ticker.SymmetricalLogLocator(base=10, linthresh=2, subs=[1.5, 2, 3, 4, 5, 6, 8]))
        ax.xaxis.set_minor_formatter(b_formatter)
        ax.grid(which='minor', linestyle='--', color='#eeeeee')
        ax.xaxis.set_major_formatter(b_formatter)
        ax.tick_params(axis="x", which="both", rotation=90)

    plt.savefig(args.output, bbox_inches='tight')


def bd(args):
    """Print the mean and standard deviation of blink AUC across the selected runs."""
    exps = expman.gather(args.run).filter(args.filter)
    blink_metrics = exps.collect('test_pred/all_blink_roc_metrics.csv')
    # Every 4th row starting at row 3 is kept and its column '0' is read as
    # the AUC. Presumably each run's CSV contributes 4 rows with the AUC in
    # the last one — TODO confirm.
    blink_metrics = blink_metrics.iloc[3::4].rename({'0': 'auc'}, axis=1)
    aucs = blink_metrics.auc.values
    print(f'{aucs.mean()} +- {aucs.std()}')


def dice_fps(args):
    """Print a table of best Dice, FPS, latency, FLOPs and #params per experiment."""
    exps = expman.gather(args.run).filter(args.filter)

    mask_metrics = exps.collect('test_pred/mask_metrics.csv')
    mask_metrics = mask_metrics.groupby('exp_name').dice.max()

    # timings.csv is read as (metric name, value) rows and pivoted to one
    # column per metric.
    time_metrics = exps.collect('timings.csv')
    time_metrics = time_metrics.rename({'Unnamed: 0': 'metrics', '0': 'value'}, axis=1)
    time_metrics = time_metrics.pivot_table(index='exp_name', columns='metrics', values='value')

    flops_nparams = exps.collect('flops_nparams.csv')
    flops_nparams = flops_nparams.set_index('exp_name')[['flops', 'nparams']]

    table = pd.concat((time_metrics, mask_metrics, flops_nparams), axis=1)[['dice', 'fps', 'throughput', 'flops', 'nparams']]
    table['dice'] = table.dice.map('{:.1%}'.format)
    table['fps'] = table.fps.map('{:.1f}'.format)
    # 'throughput' is multiplied by 1000 and shown in ms, so it is presumably
    # stored in seconds — TODO confirm.
    table['throughput'] = (table.throughput * 1000).map('{:.1f}ms'.format)
    table['flops'] = (table.flops / 10 ** 9).map('{:.1f}G'.format)
    table['nparams'] = (table.nparams / 10 ** 6).map('{:.2f}M'.format)
    print(table)


def metrics(args):
    """Plot Dice against threshold for every selected experiment and save it to ``args.output``."""
    exps = expman.gather(args.run).filter(args.filter)
    mask_metrics = exps.collect('test_pred/mask_metrics.csv')
    sns.lineplot(data=mask_metrics, x='thr', y='dice', hue='conv_type', size='grow_factor', style='num_filters')
    plt.savefig(args.output)


def log(args):
    """Render training curves, plus test sample images if any, for each experiment.

    For every selected experiment (in sorted order), the non-validation and
    validation columns of ``log.csv`` are plotted on two stacked axes. If
    ``test_pred/*_samples.png`` images exist, they are stacked vertically in
    a third column. Each figure is saved to the experiment's ``log_plot.pdf``
    and also appended as a page to the multi-page PDF ``args.output``.
    """
    exps = expman.gather(args.run).filter(args.filter)

    with PdfPages(args.output) as pdf:
        for exp_name, exp in sorted(exps.items()):
            print(exp_name)
            # Named log_df so it does not shadow this function's own name.
            log_df = pd.read_csv(exp.path_to('log.csv'), index_col='epoch')

            train_cols = [c for c in log_df.columns if 'val' not in c]
            val_cols = [c for c in log_df.columns if 'val' in c]

            test_images = glob(os.path.join(exp.path_to('test_pred'), '*_samples.png'))

            fig = plt.figure(figsize=(14, 10))
            # An extra column on the right holds the sample images, when present.
            fig_shape = (2, 2) if test_images else (2, 1)
            ax1 = plt.subplot2grid(fig_shape, (0, 0))
            ax2 = plt.subplot2grid(fig_shape, (1, 0))

            log_df[train_cols].plot(ax=ax1)
            log_df[val_cols].plot(ax=ax2)
            # Put the legends outside the axes, on their left.
            ax1.legend(loc='center right', bbox_to_anchor=(-0.05, 0.5))
            ax2.legend(loc='center right', bbox_to_anchor=(-0.05, 0.5))
            ax2.set_ylim((0, 1))

            if test_images:
                test_images = [imread(p) for p in sorted(test_images)]
                # Pad each image on the right to the widest width so they can
                # be stacked vertically.
                # NOTE(review): assumes every image is H x W x C with the same
                # channel count — grayscale or mixed RGB/RGBA PNGs would fail.
                max_w = max(i.shape[1] for i in test_images)
                pads = [((0, 0), (0, max_w - i.shape[1]), (0, 0)) for i in test_images]
                test_images = np.concatenate([np.pad(i, pad) for i, pad in zip(test_images, pads)], axis=0)
                ax3 = plt.subplot2grid(fig_shape, (0, 1), rowspan=2)
                ax3.imshow(test_images)
                ax3.set_axis_off()

            log_plot_file = exp.path_to('log_plot.pdf')
            plt.suptitle(exp_name)
            plt.savefig(log_plot_file, bbox_inches='tight')
            pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)


def _add_command(subparsers, name, func, default_output=None):
    """Register sub-command *name*, dispatching to *func* via ``args.func``.

    Every sub-command takes an optional ``run`` directory (default 'runs/').
    When *default_output* is given, it also takes ``-o/--output`` with that
    default.
    """
    sub = subparsers.add_parser(name)
    # nargs='?' makes the positional optional. Without it argparse treats
    # 'run' as required and never uses the default.
    sub.add_argument('run', nargs='?', default='runs/')
    if default_output is not None:
        sub.add_argument('-o', '--output', default=default_output)
    sub.set_defaults(func=func)
    return sub


def main():
    """Parse the command line and run the selected sub-command."""
    parser = argparse.ArgumentParser(description='Show stuff')
    parser.add_argument('-f', '--filter', default={}, type=expman.exp_filter)

    # Sub-parsers are optional by default, which would leave `args.func`
    # unset and crash with AttributeError. Requiring them (with a dest, so
    # the error message can name the argument) gives a proper usage error.
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    _add_command(subparsers, 'log', log, 'log_summary.pdf')
    _add_command(subparsers, 'metrics', metrics, 'mask_metrics_summary.pdf')
    _add_command(subparsers, 'ee', ee, 'ee_summary.pdf')
    _add_command(subparsers, 'bd', bd)
    _add_command(subparsers, 'dice-fps', dice_fps)

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()