# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| """Class to handle summarizing of metrics over multiple training steps.""" | |
| import abc | |
| from typing import Any, Dict, Mapping, Optional, Tuple, Union | |
| from absl import logging | |
| from clu import metric_writers | |
| import gin | |
| import jax | |
| from jax import numpy as jnp | |
| import numpy as np | |
| Array = Union[jnp.ndarray, np.ndarray] | |
class Aggregator(abc.ABC):  # Superclass for type checks.

  def add(self, value: Any):
    pass

  def is_valid(self) -> bool:
    pass

  def to_value(self):
    pass

class _MeanAggregator(Aggregator):
  """Maintains the mean of incoming values."""

  mean: float = 0.0
  weight: float = 0.0

  def add(self, new_value: Any):
    """Aggregates a new value into the mean."""
    if np.ndim(new_value) == 0:  # A scalar; works with int, float, and Array.
      val, weight = new_value, 1.0  # Weight defaults to 1.
    else:
      val, weight = new_value  # A (value, weight) tuple.
    if weight < 0.0:
      raise ValueError("Adding value with negative weight.")
    total_weight = self.weight + weight
    if total_weight != 0.0 and weight > 0.0:
      # Incremental update of the weighted cumulative moving average.
      delta = (val - self.mean) * weight / total_weight
      self.mean += delta
      self.weight = total_weight

  def is_valid(self) -> bool:
    return self.weight > 0.0

  def to_value(self):
    assert self.weight > 0.0
    return self.mean

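# A worked example of the incremental mean update (illustrative only, not
# part of the original module):
#   agg = _MeanAggregator()
#   agg.add(2.0)          # mean = 2.0, weight = 1.0
#   agg.add((4.0, 3.0))   # mean = (2.0*1 + 4.0*3) / 4 = 3.5, weight = 4.0
#   agg.to_value()        # -> 3.5
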
class _SumAggregator(_MeanAggregator):
  # Sums are aggregated the same way as means, as a tuple of the form
  # (weighted mean, total weight).  The "sum" is then recovered by
  # multiplying the two values together.

  def is_valid(self) -> bool:
    return True

  def to_value(self):
    return self.mean * self.weight

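# Continuing the example above with a _SumAggregator (illustrative only):
#   agg = _SumAggregator()
#   agg.add(2.0)          # mean = 2.0, weight = 1.0
#   agg.add((4.0, 3.0))   # mean = 3.5, weight = 4.0
#   agg.to_value()        # -> 3.5 * 4.0 = 14.0, i.e. 2.0*1 + 4.0*3
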
class _LastAggregator(Aggregator):
  """Remembers the last value given."""

  last_value: Optional[float] = None

  def add(self, new_value: Any):
    self.last_value = new_value

  def is_valid(self) -> bool:
    return self.last_value is not None

  def to_value(self):
    assert self.last_value is not None
    return self.last_value

@gin.configurable  # Presumed decorator; the module imports gin but the
# extraction dropped any decorator lines.
class MetricsSummary:
  """Summarizes a set of metrics over multiple training steps."""

  def __init__(self,
               metric_types: Mapping[str, str],
               upscale_images: bool = True,
               remove_outliers: bool = False):
    """Creates a MetricsSummary.

    Args:
      metric_types: Map from metric name to the type of summary. Types are:
        "mean" = Compute the cumulative moving average.
        "sum" = Compute the sum.
        "last" = No summary; just return the last value.
      upscale_images: Upscale small images for easier viewing.
      remove_outliers: Remove outliers from histograms.
    """
    self.metric_dict: Dict[str, Aggregator] = {}
    self.text_dict: Dict[str, str] = {}
    self.metric_types = metric_types
    self.upscale_images = upscale_images
    self.remove_outliers = remove_outliers
    self.constructor_map = {
        "mean": _MeanAggregator,
        "sum": _SumAggregator,
        "last": _LastAggregator,
    }
    logging.debug("Registered metrics: %r", metric_types)

  def current_metric_dict(self) -> Mapping[str, Aggregator]:
    return self.metric_dict

  def _is_image(self, image: Array) -> bool:
    if image.ndim != 4:
      return False
    # Greyscale or RGB image.
    return image.shape[-1] == 1 or image.shape[-1] == 3

  def _upscale_image(self, image: Array) -> Array:
    """Upscales small images to more pixels, for easier viewing."""
    if not self.upscale_images:
      return image
    assert image.ndim == 4  # (num_images, ysize, xsize, num_channels)
    ys = image.shape[1]
    xs = image.shape[2]
    if xs > 512 or ys > 512:
      return image  # No scaling.
    elif xs > 256 or ys > 256:
      scale = 2
    else:
      scale = 4
    # Nearest-neighbour upscaling via repeated integer indices.
    yidx = np.arange(ys * scale) // scale
    xidx = np.arange(xs * scale) // scale
    scaled_image = image[:, yidx, :, :][:, :, xidx, :]
    return scaled_image
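
  # How the nearest-neighbour indexing above works (illustrative only):
  #   np.arange(3 * 2) // 2  ->  [0, 0, 1, 1, 2, 2]
  # Indexing rows and then columns with such arrays repeats each pixel
  # `scale` times along each axis, e.g. 64x64 becomes 256x256 at scale=4.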

  def _remove_outliers(self, v, std_range: float = 4):
    if not self.remove_outliers:
      return v
    v_mean = np.mean(v)
    v_std = np.std(v)
    # Replace values whose magnitude exceeds std_range standard deviations
    # with the mean.
    return np.where(np.abs(v) > (v_std * std_range), v_mean, v)
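
  # Worked example (illustrative only): with v = [0, 1, -1, 100], the std is
  # about 43, so at std_range=4 nothing is clipped; at std_range=2 the 100
  # would be replaced by the mean (25.0) before histogramming.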

  # The extraction lost indentation; this function takes no `self`, so it is
  # reconstructed here as a static method.
  @staticmethod
  def merge_replicated_metrics(device_metrics: Mapping[str, Any],
                               metric_types: Mapping[str, str]):
    """Merges metrics across devices by psum over the "batch" axis.

    Args:
      device_metrics: Dictionary of device metrics.
      metric_types: Map from the metric name to "mean" or "sum".

    Returns:
      A dictionary of merged metrics.
    """
    logging.info("Merging metrics across devices: %r",
                 [(k, metric_types.get(k)) for k in device_metrics.keys()])

    def aggregate_sum(value: Array) -> Array:
      assert not isinstance(value, tuple), (
          "Weighted sums are not supported when aggregating over devices.")
      return jax.lax.psum(value, axis_name="batch")

    def aggregate_mean(value: Array, weight: Array) -> Tuple[Array, Array]:
      weighted_value = value * weight
      weighted_value = jax.lax.psum(weighted_value, axis_name="batch")
      weight = jax.lax.psum(weight, axis_name="batch")
      return weighted_value / (weight + 1.0e-6), weight

    aggregated_metrics = dict(device_metrics)
    for k, value in aggregated_metrics.items():
      if k not in metric_types:
        # If no metric type is given, the metric remains untouched.
        continue
      if metric_types[k] == "sum":
        aggregated_metrics[k] = aggregate_sum(value)
      elif metric_types[k] == "mean":
        if not isinstance(value, tuple):
          logging.info("Metric '%s' has no weight; assuming 1.0.", k)
          value = (value, jnp.array(1.0))
        aggregated_metrics[k] = aggregate_mean(*value)
      else:
        raise ValueError("Can only aggregate 'sum' and 'mean' over devices. "
                         f"Got {metric_types[k]}.")
    return aggregated_metrics
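
  # Sketch of intended use inside a pmapped step function (hypothetical
  # names; assumes the pmap was given axis_name="batch" to match the psums):
  #   @functools.partial(jax.pmap, axis_name="batch")
  #   def train_step(state, batch):
  #     ...
  #     metrics = {"loss": (loss, weight), "num_tokens": num_tokens}
  #     return MetricsSummary.merge_replicated_metrics(
  #         metrics, {"loss": "mean", "num_tokens": "sum"})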

  def _new_aggregator(self, key) -> Aggregator:
    if key in self.metric_types:
      return self.constructor_map[self.metric_types[key]]()
    else:
      # TODO(mrabe): The default to last_value is not obvious. Force all
      # metric types to be given explicitly.
      logging.debug("No metric type for accumulator: %s", key)
      return _LastAggregator()

  def add(self, metrics: Mapping[str, Any]):
    """Adds metrics from the current training step to the summary.

    Args:
      metrics: Dictionary of metrics.
    """
    for k, new_value in metrics.items():
      if k not in self.metric_dict:
        self.metric_dict[k] = self._new_aggregator(k)
      self.metric_dict[k].add(new_value)

  def add_text(self, text_metrics: Mapping[str, str]):
    """Adds text metrics from the current step to the summary."""
    for (k, v) in text_metrics.items():
      self.text_dict[k] = str(v)

  def empty(self):
    """Returns true if there are no summaries to write."""
    return not (self.metric_dict or self.text_dict)

  def clear(self):
    """Clears accumulated summaries."""
    self.metric_dict = {}
    self.text_dict = {}

  def write(self, writer: metric_writers.MetricWriter, step: int, prefix: str):
    """Writes metrics using summary_writer, and clears all summaries."""
    if self.empty():
      return

    # Special logic for organizing metrics under tensorboard.
    # Tensorboard has top-level groups, but doesn't have subgroups.
    # Scalars are put into separate top-level groups for easier viewing,
    # e.g. all scalars in "train", "test", etc.
    # For images, each set of images should be a different top-level group,
    # otherwise all images will get tossed into a single group under,
    # e.g. "generate".
    if prefix:
      s_prefix = prefix + "/"
      i_prefix = prefix + "_"
    else:
      # Each prefix is stored in a separate subdirectory already.
      s_prefix = ""
      i_prefix = ""

    # Split metrics into different types.
    scalars = {}
    images = {}
    histograms = {}
    text_dict = {}

    # Sort metrics into scalars, images, text, and histograms.
    for k, aggregator in self.metric_dict.items():
      if not isinstance(aggregator, Aggregator):
        raise ValueError("Internal error: metric_dict should contain only "
                         "Aggregator objects; contained %s" % aggregator)
      if not aggregator.is_valid():
        raise ValueError(f"No valid value for metric {k}.")
      v = aggregator.to_value()
      s_key = s_prefix + k
      i_key = i_prefix + k

      # Check for None before np.isfinite, which would fail on None.
      if v is None:
        logging.warning("Invalid value for %s", k)
        continue
      finite_mask = np.isfinite(v)
      if not np.all(finite_mask):
        logging.warning("Item %s contains non-finite elements.", k)
        v = np.where(finite_mask, v, np.zeros_like(v))
      if np.ndim(v) == 0:
        scalars[s_key] = v
      elif self._is_image(v):
        images[i_key] = self._upscale_image(v)
      else:
        histograms[s_key] = self._remove_outliers(v)

    # Handle text data.
    for (k, v) in self.text_dict.items():
      s_key = s_prefix + k
      text_dict[s_key] = v

    # Write metrics.
    if scalars:
      writer.write_scalars(step, scalars)
    if images:
      writer.write_images(step, images)
    if histograms:
      writer.write_histograms(step, histograms)
    if text_dict:
      writer.write_texts(step, text_dict)

    # Clear accumulated summaries.
    self.clear()
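

# Minimal usage sketch (illustrative only; assumes clu's LoggingWriter, and
# the metric names, values, and run_train_step helper are hypothetical):
#
#   writer = metric_writers.LoggingWriter()
#   summary = MetricsSummary({"loss": "mean", "num_tokens": "sum"})
#   for step in range(1, 1001):
#     loss_value, batch_weight = run_train_step()  # hypothetical helper
#     summary.add({"loss": (loss_value, batch_weight), "num_tokens": 1024})
#     if step % 100 == 0:
#       summary.write(writer, step, prefix="train")  # writes, then clears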