from functools import partial |
import jax |
import jax.numpy as np |
from jax.nn import one_hot |
from tqdm import tqdm |
from flax.training import train_state |
import optax |
from typing import Any, Tuple |
def linear_warmup(step, base_lr, end_step, lr_min=None):
    """Linearly ramp the learning rate up towards ``base_lr``.

    :param step: current step; ``step + 1`` is used, so step 0 already
        yields a non-zero rate.
    :param base_lr: rate reached when ``step + 1 == end_step``.
    :param end_step: number of steps the ramp is spread over.
    :param lr_min: unused; present so every schedule in this file can be
        called with the same four arguments.
    :return: ``base_lr * (step + 1) / end_step`` (not clamped past
        ``end_step``).
    """
    scaled = base_lr * (step + 1)
    return scaled / end_step
def cosine_annealing(step, base_lr, end_step, lr_min=1e-6):
    """Cosine decay from ``base_lr`` down to ``lr_min``.

    :param step: current step; values above ``end_step`` are clipped to it,
        so the rate stays at ``lr_min`` afterwards.
    :param base_lr: rate returned at ``step == 0``.
    :param end_step: step at which the rate reaches ``lr_min``.
    :param lr_min: floor of the schedule.
    :return: the decayed learning rate.
    """
    clipped_step = np.minimum(step, end_step)
    # Weight goes from 1 (step 0) to 0 (end_step) along half a cosine period.
    weight = 0.5 * (1 + np.cos(np.pi * clipped_step / end_step))
    return lr_min + weight * (base_lr - lr_min)
def reduce_lr_on_plateau(input, factor=0.2, patience=20, lr_min=1e-6):
    """Shrink both learning rates when accuracy has stopped improving.

    :param input: tuple ``(lr, ssm_lr, count, new_acc, opt_acc)`` where
        ``count`` is the number of consecutive calls without improvement,
        ``new_acc`` the latest accuracy and ``opt_acc`` the best one so far.
    :param factor: multiplier applied to both rates once ``count`` exceeds
        ``patience``.
    :param patience: number of non-improving calls tolerated before decaying.
    :param lr_min: lower bound applied to both returned rates.
    :return: ``(lr, ssm_lr, count, opt_acc)`` after the update.
    """
    lr, ssm_lr, stale_count, new_acc, best_acc = input

    if new_acc > best_acc:
        # New best: remember it and restart the patience counter.
        best_acc, stale_count = new_acc, 0
    else:
        stale_count += 1

    if stale_count > patience:
        lr, ssm_lr, stale_count = factor * lr, factor * ssm_lr, 0

    # Neither rate is allowed to drop below the floor.
    lr = lr_min if lr < lr_min else lr
    ssm_lr = lr_min if ssm_lr < lr_min else ssm_lr

    return lr, ssm_lr, stale_count, best_acc
def constant_lr(step, base_lr, end_step, lr_min=None):
    """Schedule that ignores the step and always yields ``base_lr``.

    ``step``, ``end_step`` and ``lr_min`` are unused; they are accepted so
    this function can be called with the same four arguments as the other
    schedules in this file.
    """
    learning_rate = base_lr
    return learning_rate
def update_learning_rate_per_step(lr_params, state):
    """Evaluate the schedule for this step and write the rates into ``state``.

    The ``learning_rate`` hyperparameter of the ``'regular'`` and ``'ssm'``
    optimizer groups is overwritten in place; for ``opt_config ==
    "BandCdecay"`` the ``'none'`` group also receives the SSM rate.

    :param lr_params: tuple ``(decay_function, ssm_lr, lr, step, end_step,
        opt_config, lr_min)``; ``decay_function`` is called as
        ``decay_function(step, base_lr, end_step, lr_min)``.
    :param state: object whose
        ``opt_state.inner_states[group].inner_state.hyperparams`` mappings
        are mutated.
    :return: ``(state, step + 1)``.
    """
    schedule, ssm_base_lr, base_lr, step, end_step, opt_config, lr_min = lr_params

    regular_rate = schedule(step, base_lr, end_step, lr_min)
    ssm_rate = schedule(step, ssm_base_lr, end_step, lr_min)

    # Group name -> rate to install, in the order the groups are updated.
    new_rates = {'regular': regular_rate, 'ssm': ssm_rate}
    if opt_config in ("BandCdecay",):
        new_rates['none'] = ssm_rate

    groups = state.opt_state.inner_states
    for group_name, rate in new_rates.items():
        groups[group_name].inner_state.hyperparams['learning_rate'] = np.array(rate, dtype=np.float32)

    return state, step + 1
def map_nested_fn(fn):
    """
    Recursively apply `fn` to the key-value pairs of a nested dict / pytree.
    We use this for some of the optax definitions below.

    :param fn: callable invoked as ``fn(key, value)`` on every leaf entry.
    :return: a function that maps a nested dict to a dict of the same
        structure holding the results of ``fn``.
    """
    def map_fn(nested_dict):
        mapped = {}
        for key, value in nested_dict.items():
            # Anything dict-like (has ``keys``) is descended into; everything
            # else is treated as a leaf.
            if hasattr(value, "keys"):
                mapped[key] = map_fn(value)
            else:
                mapped[key] = fn(key, value)
        return mapped
    return map_fn
def _build_optimizer(opt_config, dt_global, ssm_lr, lr, weight_decay):
    """Build the ``optax.multi_transform`` optimizer for one ``opt_config``.

    Every parameter is routed, by its name, to one of three groups:
    ``"ssm"`` (Adam at ``ssm_lr``), ``"none"`` (configuration dependent) or
    ``"regular"`` (AdamW at ``lr`` with ``weight_decay``).

    :raises ValueError: if ``opt_config`` is not a supported name.
    """
    # opt_config -> (names in the "ssm" group, names in the "none" group,
    #                message printed when the configuration is selected)
    configs = {
        # Weight decay is applied to C, but B is kept with the SSM
        # parameters with no weight decay.
        "standard": (
            ["B", "Lambda_re", "Lambda_im", "norm"],
            [],
            "configuring standard optimization setup",
        ),
        # Weight decay is applied to both C and B; B still uses the SSM
        # learning rate (through the "none" group).
        "BandCdecay": (
            ["Lambda_re", "Lambda_im", "norm"],
            ["B"],
            "configuring optimization with B in AdamW setup",
        ),
        # Weight decay is applied to both C and B; B also uses the faster
        # global learning rate (it falls into the "regular" group).
        "BfastandCdecay": (
            ["Lambda_re", "Lambda_im", "norm"],
            [],
            "configuring optimization with B in AdamW setup with lr",
        ),
        # No weight decay on B or C; C is included with the SSM parameters
        # and uses the SSM learning rate.
        "noBCdecay": (
            ["B", "C", "C1", "C2", "D", "Lambda_re", "Lambda_im", "norm"],
            [],
            "configuring optimization with C not in AdamW setup",
        ),
    }
    if opt_config not in configs:
        raise ValueError(
            f"Unknown opt_config {opt_config!r}; expected one of {sorted(configs)}"
        )

    ssm_keys, none_keys, message = configs[opt_config]
    print(message)

    # "log_step" is only trained with the SSM group when dt_global is False.
    if not dt_global:
        ssm_keys = ssm_keys + ["log_step"]

    def label(key, _value):
        if key in ssm_keys:
            return "ssm"
        if key in none_keys:
            return "none"
        return "regular"

    # The optimizer type of the "none" group is kept per configuration so
    # the resulting optimizer-state structure matches the one this code has
    # always produced.
    if opt_config == "BandCdecay":
        none_tx = optax.inject_hyperparams(optax.adamw)(learning_rate=ssm_lr,
                                                       weight_decay=weight_decay)
    elif opt_config == "BfastandCdecay":
        none_tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)
    else:
        none_tx = optax.inject_hyperparams(optax.sgd)(learning_rate=0.0)

    return optax.multi_transform(
        {
            "none": none_tx,
            "ssm": optax.inject_hyperparams(optax.adam)(learning_rate=ssm_lr),
            "regular": optax.inject_hyperparams(optax.adamw)(learning_rate=lr,
                                                             weight_decay=weight_decay),
        },
        map_nested_fn(label),
    )


def create_train_state(model_cls,
                       rng,
                       padded,
                       retrieval,
                       in_dim=1,
                       bsz=128,
                       seq_len=784,
                       weight_decay=0.01,
                       batchnorm=False,
                       opt_config="standard",
                       ssm_lr=1e-3,
                       lr=1e-3,
                       dt_global=False
                       ):
    """
    Initializes the training state using optax

    :param model_cls: callable; ``model_cls(training=True)`` must return the
        model whose ``init`` / ``apply`` are used.
    :param rng: PRNG key, split into a parameter-init key and a dropout key.
    :param padded: if True the dummy input is a tuple ``(sequences, ones)``
        with one scalar per batch element; otherwise just ``sequences``.
    :param retrieval: if True (and ``padded``), the dummy batch size is
        ``2 * bsz``.
    :param in_dim: last dimension of the dummy input.
    :param bsz: batch size of the dummy input.
    :param seq_len: sequence length of the dummy input.
    :param weight_decay: weight decay passed to the AdamW groups.
    :param batchnorm: if True, ``batch_stats`` from the initial variables
        are stored on the returned state.
    :param opt_config: one of ``"standard"``, ``"BandCdecay"``,
        ``"BfastandCdecay"``, ``"noBCdecay"``.
    :param ssm_lr: learning rate of the ``"ssm"`` parameter group.
    :param lr: learning rate of the ``"regular"`` parameter group.
    :param dt_global: if True, ``"log_step"`` is left out of the ``"ssm"``
        group (it then falls into ``"regular"``).
    :return: a ``TrainState`` (with an extra ``batch_stats`` field when
        ``batchnorm`` is True).
    :raises ValueError: if ``opt_config`` is not supported.
    """
    # Only the padded retrieval setup doubles the batch dimension.
    batch = 2 * bsz if (padded and retrieval) else bsz
    dummy_sequences = np.ones((batch, seq_len, in_dim))
    dummy_input = (dummy_sequences, np.ones(batch)) if padded else dummy_sequences
    integration_timesteps = np.ones((batch, seq_len))

    model = model_cls(training=True)
    init_rng, dropout_rng = jax.random.split(rng, num=2)
    variables = model.init({"params": init_rng,
                            "dropout": dropout_rng},
                           dummy_input, integration_timesteps,
                           )

    # ``init`` may hand back either a FrozenDict or a plain dict depending
    # on the Flax version; only the former has (and needs) ``unfreeze``.
    params = variables["params"]
    if hasattr(params, "unfreeze"):
        params = params.unfreeze()

    tx = _build_optimizer(opt_config, dt_global, ssm_lr, lr, weight_decay)

    def scalar_count(_key, param):
        # A complex entry is counted as two real parameters.
        is_complex = param.dtype in [np.complex64, np.complex128]
        return param.size * (2 if is_complex else 1)

    param_sizes = map_nested_fn(scalar_count)(params)
    # jax.tree_util.tree_leaves is the stable spelling; the top-level
    # jax.tree_leaves alias no longer exists in current JAX releases.
    print(f"[*] Trainable Parameters: {sum(jax.tree_util.tree_leaves(param_sizes))}")

    if batchnorm:
        class TrainState(train_state.TrainState):
            batch_stats: Any
        return TrainState.create(apply_fn=model.apply, params=params, tx=tx,
                                 batch_stats=variables["batch_stats"])
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
@partial(np.vectorize, signature="(c),()->()")
def cross_entropy_loss(logits, label):
    """Negated ``logits`` entry selected by ``label``, vectorized over batches.

    The core function sees a single ``(c,)`` vector and a scalar label.
    NOTE(review): this equals the cross-entropy only if ``logits`` already
    holds log-probabilities — confirm against the model output.
    """
    num_classes = logits.shape[0]
    target = jax.nn.one_hot(label, num_classes=num_classes)
    return -np.sum(logits * target)
@partial(np.vectorize, signature="(c),()->()")
def compute_accuracy(logits, label):
    """Whether the arg-max of ``logits`` equals ``label``, vectorized over batches.

    The core function sees a single ``(c,)`` vector and a scalar label and
    yields one boolean per example.
    """
    predicted_class = np.argmax(logits)
    return predicted_class == label
def prep_batch(batch: tuple,
               seq_len: int,
               in_dim: int) -> Tuple[Any, np.ndarray, np.ndarray]:
    """
    Take a batch and convert it to a standard x/y format.

    :param batch: ``(x, y)`` or ``(x, y, aux_data)`` as returned from the
        dataloader. ``x`` and ``y`` (and the ``'lengths'`` / ``'timesteps'``
        entries of ``aux_data``, when present) must expose ``.numpy()``.
    :param seq_len: (int) length of sequence; shorter inputs are zero-padded
        at the end of axis 1 up to this length.
    :param in_dim: (int) dimension of input; inputs with fewer than three
        dimensions whose last axis differs from ``in_dim`` are one-hot
        encoded to this depth.
    :return: ``(full_inputs, targets, integration_timesteps)``.
        ``full_inputs`` is the float input array, or the tuple
        ``(inputs, lengths)`` when ``aux_data`` carries ``'lengths'``.
        ``targets`` is cast to float. ``integration_timesteps`` is the
        first difference of ``aux_data['timesteps']`` when given, otherwise
        an array of ones with shape ``(batch, seq_len)``.
    :raises RuntimeError: if ``batch`` has neither two nor three elements.
    """
    if len(batch) == 2:
        inputs, targets = batch
        aux_data = {}
    elif len(batch) == 3:
        inputs, targets, aux_data = batch
    else:
        raise RuntimeError("Err... not sure what I should do... Unhandled data type. ")

    inputs = np.asarray(inputs.numpy())
    lengths = aux_data.get('lengths', None)

    num_pad = seq_len - inputs.shape[1]
    if num_pad > 0:
        # Pad only the sequence axis (axis 1). Trailing axes, if any, get a
        # zero pad so inputs that already carry a feature axis also work.
        pad_width = ((0, 0), (0, num_pad)) + ((0, 0),) * (inputs.ndim - 2)
        inputs = np.pad(inputs, pad_width, 'constant', constant_values=0)

    # Inputs without a feature axis are treated as integer tokens.
    if (inputs.ndim < 3) and (inputs.shape[-1] != in_dim):
        inputs = one_hot(np.asarray(inputs), in_dim)

    if lengths is not None:
        lengths = np.asarray(lengths.numpy())
        full_inputs = (inputs.astype(float), lengths.astype(float))
    else:
        full_inputs = inputs.astype(float)

    targets = np.array(targets.numpy())

    if 'timesteps' in aux_data:
        integration_timesteps = np.diff(np.asarray(aux_data['timesteps'].numpy()))
    else:
        integration_timesteps = np.ones((len(inputs), seq_len))

    return full_inputs, targets.astype(float), integration_timesteps
def train_epoch(state, rng, model, trainloader, seq_len, in_dim, batchnorm, lr_params):
    """
    Run one epoch: a training step followed by a learning-rate update for
    every batch produced by ``trainloader``.

    :param state: training state passed through ``train_step``.
    :param rng: PRNG key; a fresh dropout key is split off for every batch.
    :param model: callable; ``model(training=True)`` builds the model used
        for this epoch.
    :param trainloader: iterable of batches accepted by ``prep_batch``.
    :param seq_len: sequence length forwarded to ``prep_batch``.
    :param in_dim: input dimension forwarded to ``prep_batch``.
    :param batchnorm: forwarded to ``train_step``.
    :param lr_params: tuple ``(decay_function, ssm_lr, lr, step, end_step,
        opt_config, lr_min)``; only ``step`` changes during the epoch.
    :return: ``(state, mean batch loss, step)``.
    """
    train_model = model(training=True)
    decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min = lr_params

    losses = []
    for batch in tqdm(trainloader):
        inputs, labels, integration_times = prep_batch(batch, seq_len, in_dim)
        rng, drop_rng = jax.random.split(rng)
        state, loss = train_step(state, drop_rng, inputs, labels,
                                 integration_times, train_model, batchnorm)
        losses.append(loss)

        # Re-pack with the current step so the schedule advances per batch.
        state, step = update_learning_rate_per_step(
            (decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min),
            state,
        )

    mean_loss = np.mean(np.array(losses))
    return state, mean_loss, step
def validate(state, model, testloader, seq_len, in_dim, batchnorm, step_rescale=1.0):
    """Validation function that loops over batches

    :param state: training state forwarded to ``eval_step``.
    :param model: callable; ``model(training=False,
        step_rescale=step_rescale)`` builds the evaluation model.
    :param testloader: iterable of batches accepted by ``prep_batch``.
    :param seq_len: sequence length forwarded to ``prep_batch``.
    :param in_dim: input dimension forwarded to ``prep_batch``.
    :param batchnorm: forwarded to ``eval_step``.
    :param step_rescale: forwarded to the model constructor.
    :return: ``(mean loss, mean accuracy)`` over every example seen.
    """
    model = model(training=False, step_rescale=step_rescale)

    # Each list starts with an empty float array so that (a) an empty loader
    # still yields the mean of an empty array and (b) boolean accuracies are
    # promoted to float exactly as repeated appends onto an empty array would.
    loss_chunks = [np.array([])]
    accuracy_chunks = [np.array([])]
    for batch in tqdm(testloader):
        inputs, labels, integration_timesteps = prep_batch(batch, seq_len, in_dim)
        loss, acc, _ = eval_step(inputs, labels, integration_timesteps, state, model, batchnorm)
        loss_chunks.append(np.ravel(loss))
        accuracy_chunks.append(np.ravel(acc))

    # Concatenate once instead of re-copying the running array every batch.
    aveloss = np.mean(np.concatenate(loss_chunks))
    aveaccu = np.mean(np.concatenate(accuracy_chunks))
    return aveloss, aveaccu
@partial(jax.jit, static_argnums=(5, 6))
def train_step(state,
               rng,
               batch_inputs,
               batch_labels,
               batch_integration_timesteps,
               model,
               batchnorm,
               ):
    """Performs a single training step given a batch of data

    ``model`` and ``batchnorm`` are static under ``jax.jit``.

    :return: ``(updated state, mean loss over the batch)``.
    """
    def loss_fn(params):
        # Batch statistics are passed in and marked mutable only when
        # batchnorm is on.
        variables = {"params": params}
        mutable = ["intermediates"]
        if batchnorm:
            variables["batch_stats"] = state.batch_stats
            mutable.append("batch_stats")

        logits, mod_vars = model.apply(
            variables,
            batch_inputs, batch_integration_timesteps,
            rngs={"dropout": rng},
            mutable=mutable,
        )
        loss = np.mean(cross_entropy_loss(logits, batch_labels))
        return loss, (mod_vars, logits)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (mod_vars, _)), grads = grad_fn(state.params)

    # Carry the refreshed batch statistics along with the gradient update.
    extra_updates = {"batch_stats": mod_vars["batch_stats"]} if batchnorm else {}
    new_state = state.apply_gradients(grads=grads, **extra_updates)
    return new_state, loss
@partial(jax.jit, static_argnums=(4, 5))
def eval_step(batch_inputs,
              batch_labels,
              batch_integration_timesteps,
              state,
              model,
              batchnorm,
              ):
    """Evaluate one batch with the parameters held in ``state``.

    ``model`` and ``batchnorm`` are static under ``jax.jit``.

    :return: ``(per-example losses, per-example accuracies, logits)``.
    """
    # Batch statistics are only supplied when batchnorm is on.
    variables = {"params": state.params}
    if batchnorm:
        variables["batch_stats"] = state.batch_stats

    logits = model.apply(variables, batch_inputs, batch_integration_timesteps)

    per_example_loss = cross_entropy_loss(logits, batch_labels)
    per_example_acc = compute_accuracy(logits, batch_labels)
    return per_example_loss, per_example_acc, logits