# Essay_classifier/s5/train_helpers.py: LR schedules, optimizer setup, and train/eval steps.
from functools import partial
import jax
import jax.numpy as np
from jax.nn import one_hot
from tqdm import tqdm
from flax.training import train_state
import optax
from typing import Any, Tuple
# LR schedulers
def linear_warmup(step, base_lr, end_step, lr_min=None):
return base_lr * (step + 1) / end_step
def cosine_annealing(step, base_lr, end_step, lr_min=1e-6):
# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py#L207#L240
count = np.minimum(step, end_step)
cosine_decay = 0.5 * (1 + np.cos(np.pi * count / end_step))
decayed = (base_lr - lr_min) * cosine_decay + lr_min
return decayed
def reduce_lr_on_plateau(input, factor=0.2, patience=20, lr_min=1e-6):
lr, ssm_lr, count, new_acc, opt_acc = input
if new_acc > opt_acc:
count = 0
opt_acc = new_acc
else:
count += 1
if count > patience:
lr = factor * lr
ssm_lr = factor * ssm_lr
count = 0
if lr < lr_min:
lr = lr_min
if ssm_lr < lr_min:
ssm_lr = lr_min
return lr, ssm_lr, count, opt_acc
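# Illustrative usage (not part of the original module; variable names are hypothetical):
# `reduce_lr_on_plateau` is intended to be called once per validation round, with the
# current learning rates, the plateau counter, and the new/best validation accuracies
# bundled into a single tuple:
#
#   lr, ssm_lr, count, best_acc = reduce_lr_on_plateau(
#       (lr, ssm_lr, count, val_acc, best_acc), factor=0.2, patience=20, lr_min=1e-6)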
def constant_lr(step, base_lr, end_step, lr_min=None):
return base_lr
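# All schedules above share the signature (step, base_lr, end_step, lr_min), so the caller
# can select one and pass it around as the `decay_function` entry of `lr_params`. A minimal
# dispatch sketch (an assumption; the actual selection logic lives in the training script):
#
#   schedulers = {"warmup": linear_warmup, "cosine": cosine_annealing, "constant": constant_lr}
#   decay_function = schedulers["cosine"]
#   lr_now = decay_function(step=100, base_lr=1e-3, end_step=10_000, lr_min=1e-6)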
def update_learning_rate_per_step(lr_params, state):
decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min = lr_params
# Get decayed value
lr_val = decay_function(step, lr, end_step, lr_min)
ssm_lr_val = decay_function(step, ssm_lr, end_step, lr_min)
step += 1
# Update state
state.opt_state.inner_states['regular'].inner_state.hyperparams['learning_rate'] = np.array(lr_val, dtype=np.float32)
state.opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate'] = np.array(ssm_lr_val, dtype=np.float32)
if opt_config in ["BandCdecay"]:
# In this case we are applying the ssm learning rate to B, even though
# we are also using weight decay on B
state.opt_state.inner_states['none'].inner_state.hyperparams['learning_rate'] = np.array(ssm_lr_val, dtype=np.float32)
return state, step
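# `lr_params` is the 7-tuple unpacked above. Hedged sketch (names are illustrative; `state`
# is a TrainState built by `create_train_state` below):
#
#   lr_params = (cosine_annealing, ssm_lr, lr, step, end_step, "standard", 1e-6)
#   state, step = update_learning_rate_per_step(lr_params, state)
#
# Note that this writes the new values into the `learning_rate` hyperparameters injected by
# `optax.inject_hyperparams`, so it only works with the optimizers defined below.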
def map_nested_fn(fn):
"""
    Recursively apply `fn` to the key-value pairs of a nested dict / pytree.
We use this for some of the optax definitions below.
"""
def map_fn(nested_dict):
return {
k: (map_fn(v) if hasattr(v, "keys") else fn(k, v))
for k, v in nested_dict.items()
}
return map_fn
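# Example (illustrative, not part of the original file): labelling the leaves of a nested
# parameter dict the way the optax.multi_transform configurations below do.
#
#   label_fn = map_nested_fn(lambda k, _: "ssm" if k in ["B", "Lambda_re"] else "regular")
#   label_fn({"layer_0": {"B": 1, "C": 2}})
#   # -> {"layer_0": {"B": "ssm", "C": "regular"}}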
def create_train_state(model_cls,
rng,
padded,
retrieval,
in_dim=1,
bsz=128,
seq_len=784,
weight_decay=0.01,
batchnorm=False,
opt_config="standard",
ssm_lr=1e-3,
lr=1e-3,
dt_global=False
):
"""
Initializes the training state using optax
:param model_cls:
:param rng:
:param padded:
:param retrieval:
:param in_dim:
:param bsz:
:param seq_len:
:param weight_decay:
:param batchnorm:
:param opt_config:
:param ssm_lr:
:param lr:
:param dt_global:
:return:
"""
if padded:
if retrieval:
# For retrieval tasks we have two different sets of "documents"
dummy_input = (np.ones((2*bsz, seq_len, in_dim)), np.ones(2*bsz))
integration_timesteps = np.ones((2*bsz, seq_len,))
else:
dummy_input = (np.ones((bsz, seq_len, in_dim)), np.ones(bsz))
integration_timesteps = np.ones((bsz, seq_len,))
else:
dummy_input = np.ones((bsz, seq_len, in_dim))
integration_timesteps = np.ones((bsz, seq_len, ))
model = model_cls(training=True)
init_rng, dropout_rng = jax.random.split(rng, num=2)
variables = model.init({"params": init_rng,
"dropout": dropout_rng},
dummy_input, integration_timesteps,
)
    # Note: params are unfrozen to a plain dict so that Optax can operate on them.
    params = variables["params"].unfreeze()
    if batchnorm:
        batch_stats = variables["batch_stats"]
if opt_config in ["standard"]:
"""This option applies weight decay to C, but B is kept with the
SSM parameters with no weight decay.
"""
print("configuring standard optimization setup")
if dt_global:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["B", "Lambda_re", "Lambda_im", "norm"]
else ("none" if k in [] else "regular")
)
else:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["B", "Lambda_re", "Lambda_im", "log_step", "norm"]
else ("none" if k in [] else "regular")
)
tx = optax.multi_transform(
{
"none": optax.inject_hyperparams(optax.sgd)(learning_rate=0.0),
"ssm": optax.inject_hyperparams(optax.adam)(learning_rate=ssm_lr),
"regular": optax.inject_hyperparams(optax.adamw)(learning_rate=lr,
weight_decay=weight_decay),
},
ssm_fn,
)
elif opt_config in ["BandCdecay"]:
"""This option applies weight decay to both C and B. Note we still apply the
ssm learning rate to B.
"""
print("configuring optimization with B in AdamW setup")
if dt_global:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["Lambda_re", "Lambda_im", "norm"]
else ("none" if k in ["B"] else "regular")
)
else:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["Lambda_re", "Lambda_im", "log_step", "norm"]
else ("none" if k in ["B"] else "regular")
)
tx = optax.multi_transform(
{
"none": optax.inject_hyperparams(optax.adamw)(learning_rate=ssm_lr,
weight_decay=weight_decay),
"ssm": optax.inject_hyperparams(optax.adam)(learning_rate=ssm_lr),
"regular": optax.inject_hyperparams(optax.adamw)(learning_rate=lr,
weight_decay=weight_decay),
},
ssm_fn,
)
elif opt_config in ["BfastandCdecay"]:
"""This option applies weight decay to both C and B. Note here we apply
faster global learning rate to B also.
"""
print("configuring optimization with B in AdamW setup with lr")
if dt_global:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["Lambda_re", "Lambda_im", "norm"]
else ("none" if k in [] else "regular")
)
else:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["Lambda_re", "Lambda_im", "log_step", "norm"]
else ("none" if k in [] else "regular")
)
tx = optax.multi_transform(
{
"none": optax.inject_hyperparams(optax.adamw)(learning_rate=0.0),
"ssm": optax.inject_hyperparams(optax.adam)(learning_rate=ssm_lr),
"regular": optax.inject_hyperparams(optax.adamw)(learning_rate=lr,
weight_decay=weight_decay),
},
ssm_fn,
)
elif opt_config in ["noBCdecay"]:
"""This option does not apply weight decay to B or C. C is included
with the SSM parameters and uses ssm learning rate.
"""
print("configuring optimization with C not in AdamW setup")
if dt_global:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["B", "C", "C1", "C2", "D",
"Lambda_re", "Lambda_im", "norm"]
else ("none" if k in [] else "regular")
)
else:
ssm_fn = map_nested_fn(
lambda k, _: "ssm"
if k in ["B", "C", "C1", "C2", "D",
"Lambda_re", "Lambda_im", "log_step", "norm"]
else ("none" if k in [] else "regular")
)
tx = optax.multi_transform(
{
"none": optax.inject_hyperparams(optax.sgd)(learning_rate=0.0),
"ssm": optax.inject_hyperparams(optax.adam)(learning_rate=ssm_lr),
"regular": optax.inject_hyperparams(optax.adamw)(learning_rate=lr,
weight_decay=weight_decay),
},
ssm_fn,
        )
    else:
        raise ValueError(f"Unknown opt_config: {opt_config}")
fn_is_complex = lambda x: x.dtype in [np.complex64, np.complex128]
param_sizes = map_nested_fn(lambda k, param: param.size * (2 if fn_is_complex(param) else 1))(params)
print(f"[*] Trainable Parameters: {sum(jax.tree_leaves(param_sizes))}")
if batchnorm:
class TrainState(train_state.TrainState):
batch_stats: Any
return TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)
else:
return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
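# Hedged usage sketch (assumes `model_cls` is a Flax module constructor, e.g. a partially
# applied S5 classification model; the hyperparameter values below are illustrative only):
#
#   rng = jax.random.PRNGKey(0)
#   state = create_train_state(model_cls, rng, padded=False, retrieval=False,
#                              in_dim=1, bsz=64, seq_len=784, weight_decay=0.05,
#                              batchnorm=True, opt_config="standard",
#                              ssm_lr=1e-3, lr=4e-3, dt_global=False)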
# Train and eval steps
@partial(np.vectorize, signature="(c),()->()")
def cross_entropy_loss(logits, label):
one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
return -np.sum(one_hot_label * logits)
@partial(np.vectorize, signature="(c),()->()")
def compute_accuracy(logits, label):
return np.argmax(logits) == label
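# Both functions above are vectorized over the batch: with signature "(c),()->()", logits of
# shape [bsz, n_classes] and integer labels of shape [bsz] yield a per-example result of
# shape [bsz]. Note that `cross_entropy_loss` assumes the model outputs log-probabilities
# (e.g. a final log_softmax); raw logits would need a log_softmax applied first.
#
#   logits = np.log(np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]))
#   labels = np.array([0, 1])
#   cross_entropy_loss(logits, labels)   # per-example NLL: [-log(0.7), -log(0.8)]
#   compute_accuracy(logits, labels)     # [True, True]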
def prep_batch(batch: tuple,
               seq_len: int,
               in_dim: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Take a batch and convert it to a standard x/y format.
:param batch: (x, y, aux_data) as returned from dataloader.
:param seq_len: (int) length of sequence.
:param in_dim: (int) dimension of input.
    :return: (full_inputs, targets, integration_timesteps) ready for train_step / eval_step.
"""
if len(batch) == 2:
inputs, targets = batch
aux_data = {}
elif len(batch) == 3:
inputs, targets, aux_data = batch
    else:
        raise RuntimeError(f"Unhandled batch structure: expected 2 or 3 elements, got {len(batch)}.")
# Convert to JAX.
inputs = np.asarray(inputs.numpy())
# Grab lengths from aux if it is there.
lengths = aux_data.get('lengths', None)
# Make all batches have same sequence length
num_pad = seq_len - inputs.shape[1]
if num_pad > 0:
# Assuming vocab padding value is zero
inputs = np.pad(inputs, ((0, 0), (0, num_pad)), 'constant', constant_values=(0,))
    # Inputs are either [n_batch, seq_len] or [n_batch, seq_len, in_dim].
    # If the input is 2-D and its trailing dimension does not equal in_dim, treat it as
    # token indices and one-hot encode it. This heuristic is reliable as long as seq_len != in_dim.
if (inputs.ndim < 3) and (inputs.shape[-1] != in_dim):
inputs = one_hot(np.asarray(inputs), in_dim)
# If there are lengths, bundle them up.
if lengths is not None:
lengths = np.asarray(lengths.numpy())
full_inputs = (inputs.astype(float), lengths.astype(float))
else:
full_inputs = inputs.astype(float)
# Convert and apply.
targets = np.array(targets.numpy())
# If there is an aux channel containing the integration times, then add that.
if 'timesteps' in aux_data.keys():
integration_timesteps = np.diff(np.asarray(aux_data['timesteps'].numpy()))
else:
integration_timesteps = np.ones((len(inputs), seq_len))
return full_inputs, targets.astype(float), integration_timesteps
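# Hedged usage sketch: `prep_batch` expects the tuples produced by a PyTorch-style dataloader
# (tensors exposing `.numpy()`), optionally with an aux dict carrying 'lengths' and/or
# 'timesteps'. Loader and shape values below are hypothetical.
#
#   for batch in trainloader:
#       inputs, targets, integration_timesteps = prep_batch(batch, seq_len=784, in_dim=1)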
def train_epoch(state, rng, model, trainloader, seq_len, in_dim, batchnorm, lr_params):
"""
Training function for an epoch that loops over batches.
"""
# Store Metrics
model = model(training=True)
batch_losses = []
decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min = lr_params
for batch_idx, batch in enumerate(tqdm(trainloader)):
inputs, labels, integration_times = prep_batch(batch, seq_len, in_dim)
rng, drop_rng = jax.random.split(rng)
state, loss = train_step(
state,
drop_rng,
inputs,
labels,
integration_times,
model,
batchnorm,
)
batch_losses.append(loss)
lr_params = (decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min)
state, step = update_learning_rate_per_step(lr_params, state)
# Return average loss over batches
return state, np.mean(np.array(batch_losses)), step
def validate(state, model, testloader, seq_len, in_dim, batchnorm, step_rescale=1.0):
"""Validation function that loops over batches"""
model = model(training=False, step_rescale=step_rescale)
    losses, accuracies = np.array([]), np.array([])
for batch_idx, batch in enumerate(tqdm(testloader)):
inputs, labels, integration_timesteps = prep_batch(batch, seq_len, in_dim)
        loss, acc, _ = eval_step(inputs, labels, integration_timesteps, state, model, batchnorm)
losses = np.append(losses, loss)
accuracies = np.append(accuracies, acc)
aveloss, aveaccu = np.mean(losses), np.mean(accuracies)
return aveloss, aveaccu
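# Hedged epoch-loop sketch tying the helpers together (the real loop lives in the training
# script; `model_cls`, the loaders and the hyperparameters here are assumptions):
#
#   lr_params = (cosine_annealing, ssm_lr, lr, 0, end_step, "standard", 1e-6)
#   for epoch in range(n_epochs):
#       state, train_loss, step = train_epoch(state, rng, model_cls, trainloader,
#                                             seq_len, in_dim, batchnorm, lr_params)
#       lr_params = (cosine_annealing, ssm_lr, lr, step, end_step, "standard", 1e-6)
#       val_loss, val_acc = validate(state, model_cls, valloader, seq_len, in_dim, batchnorm)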
@partial(jax.jit, static_argnums=(5, 6))
def train_step(state,
rng,
batch_inputs,
batch_labels,
batch_integration_timesteps,
model,
batchnorm,
):
"""Performs a single training step given a batch of data"""
def loss_fn(params):
if batchnorm:
logits, mod_vars = model.apply(
{"params": params, "batch_stats": state.batch_stats},
batch_inputs, batch_integration_timesteps,
rngs={"dropout": rng},
mutable=["intermediates", "batch_stats"],
)
else:
logits, mod_vars = model.apply(
{"params": params},
batch_inputs, batch_integration_timesteps,
rngs={"dropout": rng},
mutable=["intermediates"],
)
loss = np.mean(cross_entropy_loss(logits, batch_labels))
return loss, (mod_vars, logits)
(loss, (mod_vars, logits)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
if batchnorm:
state = state.apply_gradients(grads=grads, batch_stats=mod_vars["batch_stats"])
else:
state = state.apply_gradients(grads=grads)
return state, loss
@partial(jax.jit, static_argnums=(4, 5))
def eval_step(batch_inputs,
batch_labels,
batch_integration_timesteps,
state,
model,
batchnorm,
):
if batchnorm:
logits = model.apply({"params": state.params, "batch_stats": state.batch_stats},
batch_inputs, batch_integration_timesteps,
)
else:
logits = model.apply({"params": state.params},
batch_inputs, batch_integration_timesteps,
)
losses = cross_entropy_loss(logits, batch_labels)
accs = compute_accuracy(logits, batch_labels)
return losses, accs, logits