# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core NN components used in models."""

from typing import Any, Callable, Optional, Tuple, Union

from absl import logging
from flax import linen as nn
import gin
import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp


PRNGKey = Any
Array = jnp.ndarray
Shape = Tuple[int, ...]
Dtype = Union[jnp.dtype, str]


def scalar_initializer(x):
  """Like linen.zeros, but initializes a parameter to a scalar value."""
  def init_fun(key, shape, dtype):
    del key
    return jnp.broadcast_to(jnp.array(x, dtype=dtype), shape)
  return init_fun


def swish(x: Array) -> Array:
  """Swish function, which is very similar to gelu."""
  return x * nn.sigmoid(x)


def soft_abs(x: Array) -> Array:
  """Soft version of absolute value, that is smoothly differentiable."""
  return jnp.sqrt(jnp.square(x) + 1) - 1


def get_activation_function(fname: Optional[str]) -> Callable[[Array], Array]:
  """Get activation function from the specified string."""
  if fname is None:
    return lambda x: x
  elif fname == "relu":
    return nn.relu
  elif fname == "swish":
    return swish
  elif fname == "sigmoid":
    return nn.sigmoid
  elif fname == "tanh":
    return nn.tanh
  else:
    raise ValueError("Unknown activation function %s" % fname)
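
# The sketch below is illustrative only and not part of the original module;
# the "_example" name is hypothetical.  It shows how the activation helpers
# above are typically used: look up an activation by its configuration
# string and apply it elementwise (fname=None yields the identity).
def _example_activation_usage() -> Array:
  """Illustrative sketch: applies a configured activation to a small array."""
  act_fn = get_activation_function("swish")           # same as swish()
  x = jnp.array([-1.0, 0.0, 1.0], dtype=jnp.float32)
  return act_fn(x)                                    # x * sigmoid(x), elementwise
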
# Adapted from flax.linen.softmax.
def safe_softmax(x: Array,
                 axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                 min_x: Optional[Array] = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  This version of softmax is intended for use with causal attention masks,
  and safely covers the situation where all elements are masked out.  If
  min_x is not None, then probability will be distributed between the values
  in x, and min_x.  If x >> min_x, then the probability allocated to min_x
  will be zero, and this function will be the same as the usual softmax.
  However, if x << min_x (because all the values in x are masked out), then
  probability will be allocated to min_x instead, and the probability
  allocated to x will be 0.  I.e., attention will attend to nothing if
  everything is masked out.

  .. math ::
    \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x: input array.
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    min_x: the value of a minimum element which will be included in the
      softmax sum.  The value of min_x should be small when compared to the
      expected values in x.  If all of the values in x are smaller than
      min_x, then probability will be allocated to the minimum element
      instead, and the result of softmax will sum to less than 1.

  Returns:
    An array of the same shape as x.
  """
  # Subtract maximum value in x for numerical stability, so that the exponent
  # never exceeds numerical precision.
  x_max = lax.stop_gradient(jnp.max(x, axis, initial=min_x, keepdims=True))
  if min_x is not None:
    min_x = jnp.asarray(min_x, dtype=x.dtype)
    x_max = jnp.maximum(x_max, min_x)
  unnormalized = jnp.exp(x - x_max)
  x_sum = jnp.sum(unnormalized, axis=axis, keepdims=True)
  if min_x is not None:
    x_sum = x_sum + jnp.exp(min_x - x_max)
  return unnormalized / x_sum


def dropout_multiplier_mask(rng, dropout_rate: float, shape: Shape,
                            dtype: Dtype):
  """Returns an array which can be multiplied by an input to perform dropout.

  Args:
    rng: A random number generator.
    dropout_rate: The rate at which to drop.
    shape: The shape of the output array.
    dtype: The type of the output array.

  Returns:
    An array of the given shape, where values are { 0.0, 1.0/keep_probability }.
  """
  if dropout_rate <= 0.0:
    return jnp.ones(shape, dtype=dtype)

  logging.info("dropout mask: %s", shape)
  keep_prob = 1.0 - dropout_rate
  keep = jax.random.bernoulli(rng, keep_prob, shape)
  dropout_multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype))
  return dropout_multiplier


def tiled_dropout(x: Array, shape: Shape, dropout_rate: float,
                  rng_function: Callable[[], jax.random.KeyArray],
                  deterministic: bool) -> Array:
  """Tiles a dropout mask over a larger array.

  This will generate a smaller dropout mask of the given shape, and tile it
  over a larger array, which reduces the computational cost and memory
  associated with generating a large dropout mask.

  Args:
    x: The input array.
    shape: The shape of the dropout mask to tile.
    dropout_rate: The rate at which to drop.
    rng_function: A function which returns a random number generator, e.g.
      lambda: self.make_rng("dropout").  The function will not be called if
      dropout is not enabled.
    deterministic: If True, don't do dropout.

  Returns:
    An array of the same shape as x, with some values dropped out.
  """
  if deterministic or dropout_rate <= 0.0:
    return x

  if x.ndim != len(shape):
    raise ValueError("Shapes must have same number of dimensions %r, %r." %
                     (x.shape, shape))
  for (xd, sd) in zip(x.shape, shape):
    if (xd % sd) != 0:
      raise ValueError("Incompatible shapes %r, %r" % (x.shape, shape))

  # Get random number generator for dropout.
  rng = rng_function()

  repeats = [(1 if sd == 1 else xd // sd) for (xd, sd) in zip(x.shape, shape)]
  logging.info("tiled dropout %r, tile: %r", x.shape, shape)

  dtype = x.dtype
  keep_prob = 1.0 - dropout_rate
  keep = jax.random.bernoulli(rng, keep_prob, shape)
  keep = jnp.tile(keep, repeats)
  keep = jnp.broadcast_to(keep, x.shape)
  x_scaled = x / jnp.asarray(keep_prob, dtype=dtype)
  return lax.select(keep, x_scaled, jnp.zeros_like(x, dtype=dtype))
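
# The sketch below is illustrative only and not part of the original module;
# the "_example" name and tile size are hypothetical.  It applies tiled
# dropout to a [batch, sequence, features] activation, sharing the mask
# across the batch dimension and across tiles of 128 positions (the sequence
# length is assumed to be a multiple of 128).
def _example_tiled_dropout(x: Array, rng: PRNGKey) -> Array:
  """Illustrative sketch: tiled dropout with a (1, 128, num_features) tile."""
  (_, _, num_features) = x.shape
  return tiled_dropout(x,
                       shape=(1, 128, num_features),
                       dropout_rate=0.1,
                       rng_function=lambda: rng,
                       deterministic=False)
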
@gin.configurable
class MLP(nn.Module):
  """Implements a multi-layer perceptron, with optional resnet or gate."""

  # Arguments to module.
  num_output_features: int                    # Length of output vectors.

  # Gin configurable parameters.
  num_layers: int = gin.REQUIRED              # Number of layers in the MLP.
  num_hidden_units: int = gin.REQUIRED        # Length of hidden unit vectors.
  hidden_activation: Optional[str] = "relu"   # Hidden layer activation fn.
  final_activation: Optional[str] = None      # Final layer activation fn.
  use_bias: bool = True                       # Use a bias in each dense layer.
  gate_type: Optional[str] = None             # {"residual", "bias", "full", "lstm"}
  initializer_scale: float = 1.0              # Scale of initial values.
  dtype: Any = jnp.float32

  def setup(self):
    kernel_init = jax.nn.initializers.variance_scaling(
        scale=self.initializer_scale, mode="fan_in",
        distribution="truncated_normal")

    assert self.num_layers > 0
    hlayers = []
    for i in range(0, self.num_layers - 1):
      assert self.num_hidden_units > 0
      hlayer = nn.Dense(self.num_hidden_units,
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                        dtype=self.dtype,
                        name=f"hidden{i}")
      hlayers.append(hlayer)
    self.hidden_layers = hlayers
    self.output_layer = nn.Dense(self.num_output_features,
                                 use_bias=self.use_bias,
                                 kernel_init=kernel_init,
                                 dtype=self.dtype)

    if self.gate_type is None or self.gate_type == "residual":
      return

    # We use a low but non-zero bias so that adafactor knows how to scale it.
    gate_bias_init = jax.nn.initializers.normal(stddev=0.1)
    # Also use a lower than normal kernel.
    gate_kernel_init = jax.nn.initializers.variance_scaling(
        scale=0.1, mode="fan_in", distribution="truncated_normal")

    if self.gate_type == "bias":
      self.gate_bias = self.param("gate_bias", gate_bias_init,
                                  (self.num_output_features,), jnp.float32)
    elif self.gate_type == "full":
      self.gate_layer = nn.Dense(self.num_output_features,
                                 use_bias=True,
                                 bias_init=gate_bias_init,
                                 kernel_init=gate_kernel_init,
                                 dtype=self.dtype)
    elif self.gate_type == "lstm":
      self.input_gate = nn.Dense(self.num_output_features,
                                 use_bias=True,
                                 bias_init=gate_bias_init,
                                 kernel_init=gate_kernel_init,
                                 dtype=self.dtype)
      self.forget_gate = nn.Dense(self.num_output_features,
                                  use_bias=True,
                                  bias_init=gate_bias_init,
                                  kernel_init=gate_kernel_init,
                                  dtype=self.dtype)
    else:
      raise ValueError("Unsupported gate_type: %s" % self.gate_type)
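
  # Gate equations implemented by _gate() below, where h is the last hidden
  # layer, f(x) is the MLP output, and b/W are learned parameters:
  #   residual:  y = state + f(x)
  #   bias:      g = sigmoid(b);             y = g*state + (1-g)*f(x)
  #   full:      g = sigmoid(W*h + b + 1);   y = g*state + (1-g)*f(x)
  #   lstm:      fg = sigmoid(Wf*h + bf + 1);  ig = sigmoid(Wi*h + bi - 1);
  #              y = fg*state + ig*f(x)
  # The +1 / -1 offsets bias the gates towards remembering the prior state.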
  def _gate(self, y_hidden: Array, state: Array, y_out: Array) -> Array:
    """Compute the value to use for the gate."""
    if self.gate_type == "residual":
      # Residual connection: just add y_out to the state.
      logging.info("mlp: residual")
      return state + y_out
    elif self.gate_type == "bias":
      # Simple gate: use a GRU-style gate with a learned bias (no kernel).
      bias = jnp.asarray(self.gate_bias, dtype=self.dtype)
      bias = jnp.reshape(bias, (1,) * (y_out.ndim - 1) + (-1,))  # batch dims.
      g = jax.nn.sigmoid(bias)
      logging.info("mlp: gate bias = %r", g)
      return (state * g) + (y_out * (1 - g))
    elif self.gate_type == "full":
      # Normal GRU-style gate -- compute g using both a kernel and bias.
      g = jax.nn.sigmoid(self.gate_layer(y_hidden) + 1)  # biased to remember
      logging.info("mlp: gate full = %r", g)
      return (state * g) + (y_out * (1 - g))
    elif self.gate_type == "lstm":
      # LSTM-style gate with input and forget gates.
      fg = jax.nn.sigmoid(self.forget_gate(y_hidden) + 1)  # biased to remember
      ig = jax.nn.sigmoid(self.input_gate(y_hidden) - 1)
      logging.info("mlp: gate lstm = %r, %r", ig, fg)
      return (state * fg) + (y_out * ig)
    else:
      raise ValueError("Unsupported gate type %s" % self.gate_type)

  def __call__(self, x: Array, state: Optional[Array],
               apply_dropout: bool = False,
               dropout_rate: float = 0.0,
               drop_tile_shape: Optional[Shape] = None,
               rng_function: Optional[Callable[[], Any]] = None) -> Array:
    """Apply the multi-layer perceptron to the input x.

    For simple MLPs, returns f(x), where f is the MLP function.
    For resnets and gated architectures, it returns:
      state + f(x)            -- for resnet.
      g*state + (1-g)*f(x)    -- for gated architecture, where g is the gate.

    Args:
      x: The input to the MLP.
      state: The prior value, if this MLP is used as part of a resnet or
        gated architecture.
      apply_dropout: If true, applies dropout to the result.
      dropout_rate: The dropout rate to use.
      drop_tile_shape: The dropout tile shape.
      rng_function: Gets a random number seed for dropout.

    Returns:
      The combination of f(x) and the (optional) prior state.
    """
    x = jnp.asarray(x, self.dtype)
    hidden_act_fun = get_activation_function(self.hidden_activation)
    final_act_fun = get_activation_function(self.final_activation)
    if self.hidden_layers:
      # Apply some number of hidden layers.
      y = x
      for layer in self.hidden_layers:
        logging.info("mlp: hidden %d, %s", self.num_hidden_units,
                     self.hidden_activation)
        y = hidden_act_fun(layer(y))
    else:
      # Apply the hidden activation function to the input.
      logging.info("mlp: activation = %s", self.hidden_activation)
      y = hidden_act_fun(x)

    y_hidden = y  # The hidden layer right before the output.
    logging.info("mlp: final activation = %s", self.final_activation)
    y_out = self.output_layer(y_hidden)  # The MLP final output.
    y_out = final_act_fun(y_out)         # Apply final activation function.
    logging.info("mlp: final = %r", y_out)

    # Optionally apply dropout to the output.
    if apply_dropout:
      if drop_tile_shape is None:
        raise ValueError("drop_tile_shape must be specified for dropout.")
      if rng_function is None:
        raise ValueError("rng_function must be specified for dropout.")
      logging.info("mlp: dropout rate = %s", dropout_rate)
      y_out = tiled_dropout(
          y_out, shape=drop_tile_shape, dropout_rate=dropout_rate,
          rng_function=rng_function, deterministic=False)

    if state is None:
      # Simple MLP.  No gate to combine y_out with the state.
      assert self.gate_type is None
      logging.info("mlp: gate type = None.")
      return y_out

    # When using state, gate_type must be specified.
    assert self.gate_type is not None
    return self._gate(y_hidden, state, y_out)


# Modified slightly from the flax implementation.
@gin.configurable
class LayerNorm(nn.Module):
  """Layer normalization (https://arxiv.org/abs/1607.06450).

  Operates on the last axis of the input data.

  It normalizes the activations of the layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  I.e., it applies a transformation that maintains the mean activation within
  each example close to 0 and the activation standard deviation close to 1.

  Attributes:
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    use_bias: If True, bias (beta) is added.
    use_scale: If True, multiply by scale (gamma).
    use_mean: If True, compute and adjust for the mean.  Note that the T5X
      layernorm does not use the mean.  Empirically, ignoring the mean can
      stabilize learning in transformers.
    use_scalar_scale_bias: If True, use a single scalar for scale & bias.
    enable_layernorm: If False, does not perform layernorm.
    bias_init: Initializer for bias, by default, zero.
    scale_init: Initializer for scale, by default, one.
  """
  epsilon: float = 1e-6
  dtype: Any = jnp.float32
  use_scale: bool = True    # Apply a learned scale.
  use_bias: bool = False    # Apply a learned bias.
  use_mean: bool = False    # Calculate and adjust for the mean.
  use_scalar_scale_bias: bool = False  # Learn a single scalar scale & bias.
  enable_layernorm: bool = True        # Turn off layernorm if false.
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
  scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones

  @nn.compact
  def __call__(self, x):
    """Applies layer normalization on the input.

    Args:
      x: the inputs.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    if not self.enable_layernorm:
      return x
    x = jnp.asarray(x)

    # Calculate mean and variance at higher precision.
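    # With use_mean=True this is standard layernorm:
    #   y = (x - mean(x)) / sqrt(var(x) + epsilon)
    # With use_mean=False it reduces to RMS normalization (as used in T5X):
    #   y = x / sqrt(mean(x**2) + epsilon)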
    xf = jnp.asarray(x, jnp.float32)
    if self.use_mean:
      mean = jnp.mean(xf, axis=-1, keepdims=True)
      xf = xf - mean
    var = jnp.mean(lax.square(xf), axis=-1, keepdims=True)
    mul = lax.rsqrt(var + self.epsilon)

    # Rescale x.
    # If not use_mean, then rescale around zero instead.  (A simplification.)
    if self.use_mean:
      y = (x - mean) * mul
    else:
      y = x * mul

    if self.use_scalar_scale_bias:
      # Learn a single scalar value for bias and scale.
      # (Which mirrors the single value for mean and stddev above.)
      num_scale_bias_features = 1
    else:
      # Learn a different value per neuron/feature for bias and scale.
      num_scale_bias_features = x.shape[-1]

    # Apply learned scale and bias.
    if self.use_scale:
      y = y * jnp.asarray(
          self.param("scale", self.scale_init, (num_scale_bias_features,)),
          dtype=self.dtype)
    if self.use_bias:
      y = y + jnp.asarray(
          self.param("bias", self.bias_init, (num_scale_bias_features,)),
          dtype=self.dtype)
    return y.astype(self.dtype)
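
# The sketch below is illustrative only and not part of the original module;
# the "_example" name and all attribute values are hypothetical (in this
# codebase they would normally be supplied via gin configuration).  It shows
# how MLP is typically combined with a prior state to form a residual update.
def _example_mlp_usage() -> Array:
  """Illustrative sketch: initializes and applies a small residual MLP."""
  mlp = MLP(num_output_features=64,   # must match the feature size of state
            num_layers=2,
            num_hidden_units=128,
            gate_type="residual")
  rng = jax.random.PRNGKey(0)
  x = jax.random.normal(rng, (2, 16, 64), dtype=jnp.float32)
  params = mlp.init(rng, x, x)        # state=x selects the residual path
  return mlp.apply(params, x, x)      # returns x + f(x)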