"""Miscellaneous helper utils for Tensorflow.""" |
import os |
import numpy as np |
import tensorflow as tf |
import logging |
logging.getLogger('tensorflow').setLevel(logging.ERROR) |
import tensorflow.contrib |
tf.contrib = tensorflow.contrib |
from typing import Any, Iterable, List, Union |
# Alias covering the three TensorFlow object types that this module accepts as an "expression".
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""

# Wider alias: everything in TfExpression plus Python ints/floats and NumPy arrays.
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""
def run(*args, **kwargs) -> Any:
    """Evaluate the given fetches in the current default session.

    All positional and keyword arguments are forwarded unchanged to the
    default session's ``run()`` method, and its result is returned as-is.

    Raises:
        RuntimeError: If no default session exists (raised by assert_tf_initialized()).
    """
    assert_tf_initialized()
    default_session = tf.get_default_session()
    return default_session.run(*args, **kwargs)
def is_tf_expression(x: Any) -> bool:
    """Return True when `x` is an instance of tf.Tensor, tf.Variable, or tf.Operation."""
    expression_types = (tf.Tensor, tf.Variable, tf.Operation)
    return isinstance(x, expression_types)
def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
    """Collect the `.value` of every dimension in `shape` into a plain list.

    Kept only for backwards compatibility; new code should call
    TensorShape.as_list() instead.
    """
    values = []
    for dim in shape:
        values.append(dim.value)
    return values
def flatten(x: TfExpressionEx) -> TfExpression:
    """Reshape `x` into a rank-1 tensor (built inside the "Flatten" name scope)."""
    with tf.name_scope("Flatten"):
        flat_shape = [-1]  # single inferred dimension
        flattened = tf.reshape(x, flat_shape)
    return flattened
def log2(x: TfExpressionEx) -> TfExpression:
    """Compute the base-2 logarithm of `x` as ln(x) * (1 / ln 2)."""
    with tf.name_scope("Log2"):
        natural_log = tf.log(x)
        # Change-of-base factor, held as float32.
        inv_ln2 = np.float32(1.0 / np.log(2.0))
        return natural_log * inv_ln2
def exp2(x: TfExpressionEx) -> TfExpression:
    """Compute 2 raised to the power `x` as exp(x * ln 2)."""
    with tf.name_scope("Exp2"):
        # Rescale the exponent so the natural exponential yields a power of two.
        ln2 = np.float32(np.log(2.0))
        scaled_exponent = x * ln2
        return tf.exp(scaled_exponent)
def erfinv(y: TfExpressionEx) -> TfExpression:
    """Evaluate the inverse error function at `y`.

    Delegates to `erfinv` from TensorFlow's internal
    `tensorflow.python.ops.distributions.special_math` module, which is
    imported lazily on each call.
    """
    from tensorflow.python.ops.distributions.special_math import erfinv as _inverse_erf
    return _inverse_erf(y)
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Linearly interpolate from `a` to `b` by `t`, i.e. a + (b - a) * t.

    `t` is used as given; no clamping is applied.
    """
    with tf.name_scope("Lerp"):
        span = b - a
        return a + span * t
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Linearly interpolate from `a` to `b`, with `t` clipped to [0, 1] first."""
    with tf.name_scope("LerpClip"):
        span = b - a
        clipped_t = tf.clip_by_value(t, 0.0, 1.0)
        return a + span * clipped_t
def absolute_name_scope(scope: str) -> tf.name_scope:
    """Return a name scope for `scope` that disregards any enclosing name scopes.

    The scope is requested with a trailing "/" appended to its name.
    """
    absolute_scope = scope + "/"
    return tf.name_scope(absolute_scope)
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
    """Return a variable scope for `scope` that disregards any enclosing scopes.

    A fresh tf.VariableScope named `scope` is built (extra keyword arguments
    go to its constructor) and wrapped in tf.variable_scope with the
    auxiliary name scope disabled.
    """
    explicit_scope = tf.VariableScope(name=scope, **kwargs)
    return tf.variable_scope(explicit_scope, auxiliary_name_scope=False)
def _sanitize_tf_config(config_dict: dict = None) -> dict:
    """Build the effective config dict: built-in defaults overlaid with user settings.

    Keys are dotted paths. An ``env.NAME`` default is omitted when ``NAME`` is
    already set in ``os.environ``. Entries from `config_dict`, if given, are
    applied last and therefore always win.

    Args:
        config_dict: Optional user overrides; it is not modified.

    Returns:
        A new dict with the merged settings.
    """
    defaults = {
        "rnd.np_random_seed": None,
        "rnd.tf_random_seed": "auto",
        "env.TF_CPP_MIN_LOG_LEVEL": "1",
        "env.HDF5_USE_FILE_LOCKING": "FALSE",
        "graph_options.place_pruned_graph": True,
        "gpu_options.allow_growth": True,
    }

    merged = {}
    for name, default_value in defaults.items():
        parts = name.split(".")
        if parts[0] == "env":
            assert len(parts) == 2
            # An environment variable that is already set keeps its value.
            if parts[1] in os.environ:
                continue
        merged[name] = default_value

    # User-supplied settings take precedence over every default.
    if config_dict is not None:
        merged.update(config_dict)
    return merged
def init_tf(config_dict: dict = None) -> None:
    """Set up a default TensorFlow session with sensible defaults.

    Does nothing if a default session already exists. Otherwise, in order:
    seeds NumPy (when ``rnd.np_random_seed`` is not None), seeds TensorFlow
    (``rnd.tf_random_seed``; "auto" draws a seed from NumPy's RNG, None skips
    seeding), exports every ``env.*`` setting to ``os.environ``, and finally
    creates a session and installs it as the default.

    Args:
        config_dict: Optional overrides merged on top of the built-in defaults.
    """
    # Skip if already initialized.
    if tf.get_default_session() is not None:
        return

    settings = _sanitize_tf_config(config_dict)

    # NumPy random seed.
    numpy_seed = settings["rnd.np_random_seed"]
    if numpy_seed is not None:
        np.random.seed(numpy_seed)

    # TensorFlow random seed; "auto" derives it from NumPy's generator.
    graph_seed = settings["rnd.tf_random_seed"]
    if graph_seed == "auto":
        graph_seed = np.random.randint(1 << 31)
    if graph_seed is not None:
        tf.set_random_seed(graph_seed)

    # Export environment variables before the session is created.
    for name, value in settings.items():
        parts = name.split(".")
        if parts[0] != "env":
            continue
        assert len(parts) == 2
        os.environ[parts[1]] = str(value)

    create_session(settings, force_as_default=True)
def assert_tf_initialized():
    """Verify that a default TensorFlow session exists.

    Raises:
        RuntimeError: If tf.get_default_session() returns None.
    """
    current_session = tf.get_default_session()
    if current_session is not None:
        return
    raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Construct a tf.Session configured from a dotted-key config dict.

    Every setting whose first path component is neither "rnd" nor "env" is
    written into a tf.ConfigProto: the leading components are followed with
    getattr and the final component is assigned with setattr.

    Args:
        config_dict: Optional overrides merged on top of the built-in defaults.
        force_as_default: When True, the session's as_default() context is
            entered immediately and a reference to it is kept on the session.

    Returns:
        The newly created session.
    """
    settings = _sanitize_tf_config(config_dict)
    proto = tf.ConfigProto()

    for name, value in settings.items():
        parts = name.split(".")
        if parts[0] in ("rnd", "env"):
            continue
        # Walk down to the sub-message that owns the final field.
        target = proto
        for attr in parts[:-1]:
            target = getattr(target, attr)
        setattr(target, parts[-1], value)

    session = tf.Session(config=proto)
    if not force_as_default:
        return session

    # Enter the default-session context by hand and keep it on the session.
    # NOTE(review): enforce_nesting is switched off presumably so the context
    # never has to be exited in strict nesting order -- confirm against TF.
    default_context = session.as_default()
    session._default_session = default_context
    default_context.enforce_nesting = False
    default_context.__enter__()
    return session
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Run the initializer of every variable that is not yet initialized.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()

    Args:
        target_vars: Variables to consider; defaults to tf.global_variables().
    """
    assert_tf_initialized()
    if target_vars is None:
        target_vars = tf.global_variables()

    pending_vars = []
    pending_checks = []

    # Ignore any surrounding control dependencies while building the check ops.
    with tf.control_dependencies(None):
        for var in target_vars:
            assert is_tf_expression(var)

            try:
                # If a check op already exists under this variable's name, the
                # variable is skipped entirely.
                tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
            except KeyError:
                # No check op yet, so create one and query it below.
                pending_vars.append(var)
                scope_name = var.name.split(":")[0]
                with absolute_name_scope(scope_name):
                    check_op = tf.is_variable_initialized(var)
                    pending_checks.append(check_op)

    # Evaluate all checks in one call, then initialize only what needs it.
    init_flags = run(pending_checks)
    initializers = []
    for var, already_inited in zip(pending_vars, init_flags):
        if not already_inited:
            initializers.append(var.initializer)
    run(initializers)
def set_vars(var_to_value_dict: dict) -> None:
    """Assign new values to the given tf.Variables in a single run() call.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])

    For each variable, an existing "<var>/setter" op is reused when the graph
    already contains one; otherwise an assign op fed from a placeholder is
    created under the variable's own name scope.

    Args:
        var_to_value_dict: Mapping from variable to the value to assign.
    """
    assert_tf_initialized()
    setter_ops = []
    feeds = {}

    for var, new_value in var_to_value_dict.items():
        assert is_tf_expression(var)

        try:
            # Look for a setter created by an earlier call.
            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0"))
        except KeyError:
            # None yet: build it, ignoring surrounding control dependencies.
            scope_name = var.name.split(":")[0]
            with absolute_name_scope(scope_name), tf.control_dependencies(None):
                value_input = tf.placeholder(var.dtype, var.shape, "new_value")
                setter = tf.assign(var, value_input, name="setter")

        setter_ops.append(setter)
        # The assign op's second input is fed with the new value.
        feeds[setter.op.inputs[1]] = new_value

    run(setter_ops, feeds)
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create a tf.Variable holding a large array without bloating the tf graph.

    The variable is constructed from a zeros tensor with the array's shape and
    dtype, and the actual contents are then written with set_vars().

    Args:
        initial_value: NumPy array with the desired contents.
        *args, **kwargs: Forwarded to the tf.Variable constructor.

    Returns:
        The created variable, already set to `initial_value`.
    """
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    zero_init = tf.zeros(initial_value.shape, initial_value.dtype)
    variable = tf.Variable(zero_init, *args, **kwargs)
    set_vars({variable: initial_value})
    return variable
def convert_images_from_uint8(images, drange=(-1, 1), nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.

    Can be used as an input transformation for Network.run().

    Args:
        images: Image batch; it is cast to tf.float32 first.
        drange: Two-element sequence (lo, hi). Input value 0 maps to `lo` and
            255 maps to `hi`. The default is an immutable tuple so that it
            cannot be shared and mutated across calls.
        nhwc_to_nchw: When True, the axes are permuted with [0, 3, 1, 2]
            before scaling.

    Returns:
        The float32 tensor ``images * ((hi - lo) / 255) + lo``.
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    lo, hi = drange[0], drange[1]
    return images * ((hi - lo) / 255) + lo
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False, shrink=1, uint8_cast=True):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.

    Can be used as an output transformation for Network.run().

    Args:
        images: Image batch; it is cast to tf.float32 first.
        drange: Two-element sequence (lo, hi) giving the input range that is
            mapped onto [0, 255]. The default is an immutable tuple so that it
            cannot be shared and mutated across calls.
        nchw_to_nhwc: When True, the axes are permuted with [0, 2, 3, 1]
            (applied after the optional shrink).
        shrink: When greater than 1, the images are average-pooled with a
            shrink x shrink window and equal strides, using the "NCHW" data
            format, before any transpose.
        uint8_cast: When True, the result is passed through tf.saturate_cast
            to tf.uint8; otherwise the scaled float32 tensor is returned.

    Returns:
        The converted image tensor.
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        ksize = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    lo, hi = drange[0], drange[1]
    scale = 255 / (hi - lo)
    # The extra 0.5 is presumably there so that the uint8 cast rounds to
    # nearest instead of truncating -- confirm against saturate_cast semantics.
    images = images * scale + (0.5 - lo * scale)
    if uint8_cast:
        images = tf.saturate_cast(images, tf.uint8)
    return images