File size: 390 Bytes
2e4e57e
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

from transformer import decoder_stack
from transformer import nn_components

# Bind DecoderStack's `final_mlp_factory` parameter to nn_components.MLP as
# configured under the "decoder_final_mlp" scope (bindings for that scope
# appear further down in this file). The reference has no trailing "()", so
# Gin supplies the scoped configurable itself rather than the result of
# calling it.
decoder_stack.DecoderStack:
  final_mlp_factory = @decoder_final_mlp/nn_components.MLP

# Add a final MLP for token prediction after the last transformer layer.
#
# These bindings are scoped: they apply to nn_components.MLP only when it is
# referenced through the "decoder_final_mlp" scope, which is how
# DecoderStack.final_mlp_factory refers to it in this file.
#
# %MLP_DIM and %DTYPE are Gin macros; neither is defined in the visible part
# of this file, so they presumably must be supplied by whichever config
# includes this one -- TODO confirm where they are set.
decoder_final_mlp/nn_components.MLP:
  num_hidden_units = %MLP_DIM
  num_layers = 2
  activation_function = "relu"
  use_bias = False
  dtype = %DTYPE