tristan-deep commited on
Commit
c1857cb
·
1 Parent(s): 1c76709

fix labels

Browse files
Files changed (1) hide show
  1. plots.py +22 -2
plots.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  from pathlib import Path
3
  from typing import Any, Dict, List
4
 
@@ -9,9 +10,9 @@ import tyro
9
  from keras import ops
10
  from matplotlib.patches import PathPatch
11
  from matplotlib.path import Path as pltPath
 
12
  from skimage import measure
13
  from zea import log
14
- from zea.io_lib import matplotlib_figure_to_numpy
15
  from zea.utils import save_to_gif
16
  from zea.visualize import plot_image_grid
17
 
@@ -46,6 +47,25 @@ def add_shape_from_mask(ax, mask, **kwargs):
46
  return patches
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def plot_batch_with_named_masks(
50
  images, masks_dict, mask_colors=None, titles=None, **kwargs
51
  ):
@@ -359,7 +379,7 @@ def create_animation_frame(hazy_images, tissue_frame, haze_frame):
359
  vmin=0,
360
  vmax=255,
361
  )
362
- labels = ["Hazy", "Tissue"] if haze_frame is None else ["Hazy", "Tissue", "Haze"]
363
  for i, ax in enumerate(fig_frame.axes):
364
  label = labels[i % len(labels)]
365
  ax.set_ylabel(label, fontsize=12)
 
1
  import json
2
+ from io import BytesIO
3
  from pathlib import Path
4
  from typing import Any, Dict, List
5
 
 
10
  from keras import ops
11
  from matplotlib.patches import PathPatch
12
  from matplotlib.path import Path as pltPath
13
+ from PIL import Image
14
  from skimage import measure
15
  from zea import log
 
16
  from zea.utils import save_to_gif
17
  from zea.visualize import plot_image_grid
18
 
 
47
  return patches
48
 
49
 
50
+ def matplotlib_figure_to_numpy(fig):
51
+ """Convert matplotlib figure to numpy array.
52
+
53
+ Args:
54
+ fig (matplotlib.figure.Figure): figure to convert.
55
+
56
+ Returns:
57
+ np.ndarray: numpy array of figure.
58
+
59
+ """
60
+ buf = BytesIO()
61
+ fig.savefig(buf, format="png", bbox_inches="tight")
62
+ buf.seek(0)
63
+ image = Image.open(buf).convert("RGB")
64
+ image = np.array(image)[..., :3]
65
+ buf.close()
66
+ return image
67
+
68
+
69
  def plot_batch_with_named_masks(
70
  images, masks_dict, mask_colors=None, titles=None, **kwargs
71
  ):
 
379
  vmin=0,
380
  vmax=255,
381
  )
382
+ labels = ["Hazy", "Haze", "Tissue"]
383
  for i, ax in enumerate(fig_frame.axes):
384
  label = labels[i % len(labels)]
385
  ax.set_ylabel(label, fontsize=12)