# Provenance (scraped Hugging Face file-page header): uploaded by HugoVoxx in
# "Upload 33 files", commit 2e4e57e (verified); original size 390 bytes.
from transformer import decoder_stack
from transformer import nn_components
# Give DecoderStack a final MLP factory. The `@` prefix passes the
# configurable nn_components.MLP itself (not an instance; no trailing `()`).
# The `decoder_final_mlp/` prefix makes the MLP pick up bindings scoped to
# `decoder_final_mlp` (scoped bindings take precedence over unscoped ones).
decoder_stack.DecoderStack.final_mlp_factory = @decoder_final_mlp/nn_components.MLP
# Add a final MLP for token prediction after the last transformer layer.
#
# Every binding below is scoped to `decoder_final_mlp`. It therefore applies
# only where nn_components.MLP is referenced under that scope; in this file,
# that is DecoderStack.final_mlp_factory.
#
# %MLP_DIM and %DTYPE are gin macros that are not defined in this file.
# Another config loaded together with this one must bind them, presumably
# the base model config (confirm).
# NOTE(review): the exact meaning of `num_layers` (e.g. whether it counts the
# output projection) is defined by nn_components.MLP. Confirm it there.
decoder_final_mlp/nn_components.MLP.num_hidden_units = %MLP_DIM
decoder_final_mlp/nn_components.MLP.num_layers = 2
decoder_final_mlp/nn_components.MLP.activation_function = "relu"
decoder_final_mlp/nn_components.MLP.use_bias = False
decoder_final_mlp/nn_components.MLP.dtype = %DTYPE