# Gin bindings that attach a final MLP to the decoder stack.
#
# NOTE(review): %MLP_DIM and %DTYPE are macros that are not defined in this
# file; presumably the config that includes this one provides them -- confirm.

from transformer import decoder_stack
from transformer import nn_components

# Add a final MLP for token prediction after the last transformer layer.
# The reference has no trailing "()", so gin passes the scoped configurable
# itself as the factory value instead of calling it while parsing.
decoder_stack.DecoderStack:
  final_mlp_factory = @decoder_final_mlp/nn_components.MLP

# These bindings live under the "decoder_final_mlp" scope, i.e. they apply to
# the MLP referenced by final_mlp_factory above, not to unscoped uses of
# nn_components.MLP.
decoder_final_mlp/nn_components.MLP:
  num_hidden_units = %MLP_DIM
  num_layers = 2
  activation_function = "relu"
  use_bias = False
  dtype = %DTYPE