|
import argparse, os, shutil, json |
|
from netdissect.easydict import EasyDict |
|
from xml.etree import ElementTree as et |
|
from collections import defaultdict |
|
|
|
def parseargs(argv=None):
    """Parse the command-line options for the multilayer chart script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``model``, ``iteration``,
        ``dataset``, ``seg``, ``layers``, ``quantile`` and ``miniou``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser()
    aa = parser.add_argument

    aa('--model',
       choices=['resnet18-bn', 'resnet18-hn', 'resnet18-ha'],
       default='resnet18-bn')
    aa('--iteration', type=int, default=0)
    aa('--dataset', choices=['places'], default='places')
    aa('--seg',
       choices=['net', 'netp', 'netq', 'netpq',
                'netpqc', 'netpqxc', 'human'],
       default='net')
    # main() iterates over args.layers unconditionally, so the option must
    # be supplied; without required=True a missing --layers leaves it None.
    aa('--layers', nargs='+', required=True,
       help='one or more layer names; each selects a results directory')
    aa('--quantile', type=float, default=0.01,
       help='values other than 0.01 add a suffix to the results directory')
    aa('--miniou', type=float, default=0.025,
       help='units with an IoU below this value are not counted')

    return parser.parse_args(argv)
|
|
|
def _load_layer_reports(args, qdir):
    """Read ``report.json`` for every layer in ``args.layers``, keyed by layer."""
    layer_report = {}
    for layer in args.layers:
        input_filename = 'results/%s-%d-%s-%s-%s%s/report.json' % (
            args.model, args.iteration, args.dataset, args.seg, layer, qdir)
        with open(input_filename) as f:
            layer_report[layer] = EasyDict(json.load(f))
    return layer_report


def _count_labels_by_category(units, threshold_iou, cat_order):
    """Return ``[(cat, counts)]`` for the categories of ``cat_order`` present.

    ``counts`` holds, for each distinct label in the category, the number of
    unit records carrying that label, sorted from largest to smallest.
    Records whose ``iou`` is ``None`` or below ``threshold_iou`` are skipped.
    """
    catmap = defaultdict(lambda: defaultdict(int))
    for unitrec in units:
        if unitrec.iou is None or unitrec.iou < threshold_iou:
            continue
        catmap[unitrec.cat][unitrec.label] += 1
    # Membership test on a defaultdict does not create the key, so absent
    # categories stay absent and are simply left out of the result.
    return [(cat, sorted(catmap[cat].values(), reverse=True))
            for cat in cat_order if cat in catmap]


def _render_svg(graph_data):
    """Build the SVG element: one row of per-label bars for each layer."""
    layer_height = 14
    layer_gap = 2
    barwidth = 3
    bargap = 0
    leftmargin = 48
    margin = 8
    textsize = 10

    # The widest row (most bars) determines the overall image width.
    largest_layer = max(sum(len(cat_data) for _, cat_data in layer_data)
                        for _, layer_data in graph_data)
    svgwidth = largest_layer * (barwidth + bargap) + margin + leftmargin
    svgheight = ((layer_height + layer_gap) * len(graph_data) - layer_gap +
                 2 * margin)

    svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
                     version='1.1', xmlns='http://www.w3.org/2000/svg')

    y = margin
    for layer, layer_data in graph_data:
        # Right-aligned layer name in the left margin.
        label = et.SubElement(
            svg, 'text', x='0', y='0',
            style=('font-family:sans-serif;font-size:%dpx;' +
                   'text-anchor:end;alignment-baseline:hanging;' +
                   'transform:translate(%dpx, %dpx);') %
            (textsize, leftmargin - 4, y + (layer_height - textsize) / 2))
        label.text = str(layer)

        # Bars are scaled per row so the tallest bar fills the row height.
        barmax = max((max(cat_data) if cat_data else 1
                      for _, cat_data in layer_data), default=1)
        barscale = float(layer_height) / barmax

        x = leftmargin
        for cat, cat_data in layer_data:
            dark, light = cat_palette[cat][0], cat_palette[cat][1]
            # Light background strip spanning all bars of this category.
            catwidth = len(cat_data) * (barwidth + bargap)
            et.SubElement(svg, 'rect',
                          x=str(x), y=str(y),
                          width=str(catwidth),
                          height=str(layer_height),
                          fill=light)
            for bar in cat_data:
                barheight = barscale * bar
                et.SubElement(svg, 'rect',
                              x=str(x), y=str(y + layer_height - barheight),
                              width=str(barwidth),
                              height=str(barheight),
                              fill=dark)
                x += barwidth + bargap
        y += layer_height + layer_gap
    return svg


def main():
    """Read per-layer reports and write a multi-layer bar chart as SVG.

    For each layer named on the command line, loads that layer's
    ``report.json``, counts how many unit records with IoU at or above
    ``--miniou`` carry each label, groups the counts by category, and draws
    one row of bars per layer. The result is written to
    ``results/<model>-<iteration>-<dataset>-<seg><qdir>/multilayer-<N>.svg``.

    Raises:
        SystemExit: if no layers were given.
        ValueError: if a report contains neither ``units`` nor ``images``.
    """
    args = parseargs()
    if not args.layers:
        # Fail with a readable message instead of a TypeError on iteration.
        raise SystemExit('no layers given; pass at least one via --layers')

    threshold_iou = args.miniou
    # Directory suffix is only present for a non-default quantile.
    # NOTE(review): presumably this must match the naming used by whatever
    # produced the report directories, so the %d truncation is kept as is.
    qdir = '-%d' % (args.quantile * 1000) if args.quantile != 0.01 else ''

    layer_report = _load_layer_reports(args, qdir)

    cat_order = ['object', 'part', 'material', 'color']
    graph_data = []
    for layer in args.layers:
        report = layer_report[layer]
        # Prefer the 'units' list; fall back to 'images' when it is absent.
        units = report.get('units', report.get('images', None))
        if units is None:
            raise ValueError(
                'report for layer %s has neither "units" nor "images"'
                % layer)
        graph_data.append(
            (layer,
             _count_labels_by_category(units, threshold_iou, cat_order)))

    svg = _render_svg(graph_data)
    result = et.tostring(svg).decode('utf-8')

    # Prepend the XML declaration and SVG 1.1 doctype.
    result = ''.join([
        '<?xml version="1.0" standalone="no"?>\n',
        '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n',
        '"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n',
        result])

    output_filename = 'results/%s-%s-%s-%s%s/multilayer-%d.svg' % (
        args.model, args.iteration, args.dataset, args.seg, qdir,
        args.miniou * 1000)
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    print('writing to %s' % output_filename)
    with open(output_filename, 'w') as f:
        f.write(result)
|
|
|
# Fill colours per category, as (bar colour, background colour) pairs:
# main() uses element 0 for the individual bars and element 1 for the
# strip drawn behind all bars of that category.
cat_palette = dict(
    object=('#4B4CBF', '#B6B6F2'),
    part=('#55B05B', '#B6F2BA'),
    material=('#50BDAC', '#A5E5DB'),
    texture=('#81C679', '#C0FF9B'),
    color=('#F0883B', '#F2CFB6'),
    other1=('#D4CF24', '#F2F1B6'),
    other2=('#D92E2B', '#F2B6B6'),
    other3=('#AB6BC6', '#CFAAFF'),
)
|
|
|
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|