feat: scan layers + gradient checkpointing (#161)
* scan layers for faster compilation
* support gradient checkpointing
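
For background on the second bullet: gradient checkpointing (rematerialization) trades compute for memory by recomputing a module's activations during the backward pass instead of storing them. A minimal, self-contained Flax sketch of the idea (generic modules, not the DalleBart code changed below):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    features: int = 128

    @nn.compact
    def __call__(self, x):
        return nn.gelu(nn.Dense(self.features)(x))

class Model(nn.Module):
    n_layers: int = 4

    @nn.compact
    def __call__(self, x):
        # nn.remat wraps a module so its intermediates are recomputed on the
        # backward pass instead of being stored (gradient checkpointing)
        RematBlock = nn.remat(Block)
        for _ in range(self.n_layers):
            x = RematBlock(features=x.shape[-1])(x)
        return x

x = jnp.ones((8, 128))
params = Model().init(jax.random.PRNGKey(0), x)
grads = jax.grad(lambda p: Model().apply(p, x).sum())(params)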
src/dalle_mini/model/configuration.py
CHANGED

@@ -51,7 +51,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
        activation_dropout=0.0,
        init_std=0.02,
        scale_embedding=False,
-        gradient_checkpointing=
+        gradient_checkpointing=True,
+        use_scan=None,
        use_cache=True,
        is_encoder_decoder=True,
        forced_eos_token_id=None,
@@ -59,7 +60,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
        do_sample=True,
        # transformer variants
        use_bias=False, # use bias in attention and dense layers (except for lm_head)
-        ln_type="
+        ln_type="rmsnorm", # layer normalization type, "rmsnorm", "layernorm"
        ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
        use_head_scale=False, # used in NormFormer
        use_cosine_attention=False, # used in Swin v2
@@ -67,7 +68,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
        use_absolute_position_embeddings=True, # default
        use_swin_position_embeddings=False, # used in Swin v1/v2
        use_deepnet_scaling=False, # used in Deepnet
-        use_glu=
+        use_glu=True, # "GLU Variants Improve Transformer"
        use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
        sinkhorn_iters=1, # used in SinkFormers
        use_final_ln_encoder=True, # final layer normalization in encoder
@@ -136,6 +137,11 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
        self.init_std = init_std
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing
+        # all layers are the same in most configurations
+        self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
+        assert not (
+            self.use_scan and ln_positions == "swinv2"
+        ), "scan cannot be used with 'swinv2'"
        self.scale_embedding = (
            scale_embedding # scale factor will be sqrt(d_model) if True
        )
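
A hedged usage sketch based only on the parameters visible in this diff (the import path is assumed from this repo's layout): leaving use_scan at None enables scanning unless ln_positions is "swinv2", whose layers differ every 6 blocks and therefore cannot share one scanned body.

from dalle_mini.model import DalleBartConfig  # import path assumed, not shown in this diff

config = DalleBartConfig(
    gradient_checkpointing=True,  # rematerialize layer activations during the backward pass
    use_scan=None,                # None resolves to (ln_positions != "swinv2")
    ln_positions="normformer",
)
assert config.use_scan  # enabled, since the layout is not "swinv2"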
src/dalle_mini/model/modeling.py
CHANGED

@@ -619,6 +619,9 @@ class FlaxBartEncoderLayer(nn.Module):
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:

+        if self.config.use_scan:
+            hidden_states = hidden_states[0]
+
        res_gain = (
            deepnet_gain["encoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
@@ -679,12 +682,8 @@
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
-        if self.add_norm:
-            use_scale = (
-                self.use_scale
-                or self.config.ln_positions == "postln"
-                or self.config.force_ln_scale
-            )
+        if self.add_norm:
+            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
@@ -697,6 +696,9 @@
        if output_attentions:
            outputs += (attn_weights,)

+        if self.config.use_scan:
+            outputs = (outputs, None)
+
        return outputs


@@ -710,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    add_norm: bool = False
-    use_scale: bool =
+    use_scale: bool = True

    @nn.compact
    def __call__(
@@ -724,6 +726,9 @@
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:

+        if self.config.use_scan:
+            hidden_states = hidden_states[0]
+
        res_gain = (
            deepnet_gain["decoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
@@ -831,12 +836,8 @@
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
-        if self.add_norm:
-            use_scale = (
-                self.use_scale
-                or self.config.ln_positions == "postln"
-                or self.config.force_ln_scale
-            )
+        if self.add_norm:
+            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
@@ -849,6 +850,9 @@
        if output_attentions:
            outputs += (attn_weights, cross_attn_weights)

+        if self.config.use_scan:
+            outputs = (outputs, None)
+
        return outputs


@@ -876,35 +880,80 @@ class FlaxBartEncoderLayerCollection(nn.Module):

        n_layers = self.config.encoder_layers
        layer = (
-            remat(
+            remat(
+                FlaxBartEncoderLayer,
+                static_argnums=(2, 3),
+                prevent_cse=not self.config.use_scan,
+            )
            if self.config.gradient_checkpointing
            else FlaxBartEncoderLayer
        )
+
+        if self.config.use_scan:
+            # all blocks are the same so we use nn.scan
+            assert not output_attentions, "cannot scan with output_attentions"
+            assert not output_hidden_states, "cannot scan with output_hidden_states"
+            hidden_states = (hidden_states,)
+            # we use a scale on all norms (even last layer) to allow scanning
+            hidden_states, _ = nn.scan(
+                layer,
+                variable_axes={"params": 0},
+                split_rngs={"params": True, "dropout": True},
+                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
+                length=n_layers,
+            )(
+                self.config,
+                dtype=self.dtype,
+                add_norm=self.config.ln_positions == "postln",
+                name="FlaxBartEncoderLayers",
            )(
                hidden_states,
                attention_mask,
                output_attentions,
                deterministic,
            )
+            hidden_states = hidden_states[0]
+        else:
+            for i in range(n_layers):
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+                # final layernorm on the output of the last layer
+                # or every 6 layers for Swin v2
+                add_norm = self.config.ln_positions == "postln" or (
+                    self.config.ln_positions == "swinv2"
+                    and ((i + 1) % 6 == 0)
+                    and (i != n_layers - 1)
+                )
+                # we don't need to scale the norm for the last layer
+                use_scale = i != n_layers - 1
+                layer_outputs = layer(
+                    self.config,
+                    dtype=self.dtype,
+                    add_norm=add_norm,
+                    use_scale=use_scale,
+                    name=f"FlaxBartEncoderLayer_{i}",
+                )(
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                    deterministic,
+                )
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)

+        # add hidden states from the last layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        # postln is already applied in every layer
+        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )(hidden_states)

        outputs = [
            hidden_states,
@@ -953,22 +1002,39 @@ class FlaxBartDecoderLayerCollection(nn.Module):

        n_layers = self.config.decoder_layers
        layer = (
-            remat(
+            remat(
+                FlaxBartDecoderLayer,
+                static_argnums=(4, 5, 6),
+                prevent_cse=not self.config.use_scan,
+            )
            if self.config.gradient_checkpointing
            else FlaxBartDecoderLayer
        )
+
+        if self.config.use_scan:
+            # all blocks are the same so we use nn.scan
+            assert not output_attentions, "cannot scan with output_attentions"
+            assert not output_hidden_states, "cannot scan with output_hidden_states"
+            hidden_states = (hidden_states,)
+            # we use a scale on all norms (even last layer) to allow scanning
+            hidden_states, _ = nn.scan(
+                layer,
+                variable_axes={"params": 0},
+                split_rngs={"params": True, "dropout": True},
+                in_axes=(
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                ),
+                length=n_layers,
+            )(
+                self.config,
+                dtype=self.dtype,
+                add_norm=self.config.ln_positions == "postln",
+                name="FlaxBartEncoderLayers",
            )(
                hidden_states,
                attention_mask,
@@ -978,17 +1044,56 @@
                output_attentions,
                deterministic,
            )
+            hidden_states = hidden_states[0]

+        else:
+            for i in range(n_layers):
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+                # final layernorm on the output of the last layer
+                # or every 6 layers for Swin v2
+                add_norm = self.config.ln_positions == "postln" or (
+                    self.config.ln_positions == "swinv2"
+                    and ((i + 1) % 6 == 0)
+                    and (i != n_layers - 1)
+                )
+                # we don't need to scale the norm for the last layer
+                use_scale = i != n_layers - 1
+                layer_outputs = layer(
+                    self.config,
+                    dtype=self.dtype,
+                    add_norm=add_norm,
+                    use_scale=use_scale,
+                    name=f"FlaxBartDecoderLayer_{i}",
+                )(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    init_cache,
+                    output_attentions,
+                    deterministic,
+                )
+
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)
+
+                    if encoder_hidden_states is not None:
+                        all_cross_attentions += (layer_outputs[2],)

+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)

+        # postln is already applied in every layer
+        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )(hidden_states)

        outputs = [
            hidden_states,
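
The core pattern in the layer collections above is nn.scan over identical layers: the parameters of all layers are stacked along a new leading axis and the layer body is traced and compiled once instead of once per layer, which is what makes compilation faster. Below is a minimal, generic Flax sketch of scan combined with remat (plain modules, not the DalleBart classes); prevent_cse=False is the recommended setting for remat inside a scan, which is why the diff passes prevent_cse=not self.config.use_scan.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Layer(nn.Module):
    @nn.compact
    def __call__(self, carry, _):
        # scanned modules take (carry, x) and return (carry, y)
        x = carry
        x = x + nn.Dense(x.shape[-1])(x)
        return x, None

class Stack(nn.Module):
    n_layers: int = 4
    gradient_checkpointing: bool = True

    @nn.compact
    def __call__(self, x):
        # rematerialize each layer; prevent_cse=False is safe inside scan
        layer = (
            nn.remat(Layer, prevent_cse=False)
            if self.gradient_checkpointing
            else Layer
        )
        ScannedLayers = nn.scan(
            layer,
            variable_axes={"params": 0},   # stack params along a leading layer axis
            split_rngs={"params": True},   # independent init per layer
            length=self.n_layers,
        )
        x, _ = ScannedLayers()(x, None)
        return x

x = jnp.ones((2, 16))
params = Stack().init(jax.random.PRNGKey(0), x)
# every kernel now carries the layer axis: (n_layers, 16, 16)
print(jax.tree_util.tree_map(jnp.shape, params))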
src/dalle_mini/model/partitions.py
CHANGED

@@ -55,7 +55,7 @@ def _get_partition_rules():
    ]


-def set_partitions(in_dict):
+def set_partitions(in_dict, use_scan):
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
@@ -63,5 +63,14 @@ def set_partitions(in_dict):
    for k, v in result.items():
        if v == _unmatched:
            print(f"Unmatched -> {k}")
+    l = list(result.keys())
+    if use_scan:
+        # add None dimension to scanned layers
+        result = {
+            k: (P(*(None,) + v) if v is not None else None)
+            if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
+            else v
+            for k, v in result.items()
+        }
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
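
The comprehension added above just prepends a None dimension to the PartitionSpec of every scanned parameter, since nn.scan gives those parameters a leading layer axis that should stay unsharded. A small illustration (the "mp" axis name here is only an example of a model-parallel mesh axis):

from jax.experimental import PartitionSpec as P  # same import used by tools/train/train.py

kernel_spec = P("mp", None)               # spec of a single, unscanned kernel
scanned_spec = P(*(None,) + kernel_spec)  # PartitionSpec subclasses tuple
print(scanned_spec)                       # PartitionSpec(None, 'mp', None)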
tools/train/config/mega/config.json
CHANGED

@@ -7,14 +7,14 @@
  "decoder_attention_heads": 32,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
-  "decoder_layers":
+  "decoder_layers": 26,
  "decoder_start_token_id": 16384,
  "do_sample": true,
  "dropout": 0.0,
  "encoder_attention_heads": 32,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
-  "encoder_layers":
+  "encoder_layers": 26,
  "encoder_vocab_size": 50272,
  "eos_token_id": 16385,
  "force_ln_scale": false,
tools/train/train.py
CHANGED

@@ -42,6 +42,7 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import onehot
+from jax import ShapeDtypeStruct
from jax.experimental import PartitionSpec, maps
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -531,6 +532,54 @@ class TrainState(train_state.TrainState):
    train_time: float = 0.0 # total time the model trained
    train_samples: int = 0 # number of samples seen

+    def apply_gradients(self, *, grads, **kwargs):
+        params = self.unscan(self.params)
+        updates, new_opt_state = self.tx.update(
+            self.unscan(grads), self.opt_state, params
+        )
+        params = optax.apply_updates(params, updates)
+        return self.replace(
+            step=self.step + 1,
+            params=self.rescan(params),
+            opt_state=new_opt_state,
+            **kwargs,
+        )
+
+    @classmethod
+    def create(cls, *, apply_fn, params, tx, **kwargs):
+        opt_state = tx.init(cls.unscan(params))
+        return cls(
+            step=0,
+            apply_fn=apply_fn,
+            params=params,
+            tx=tx,
+            opt_state=opt_state,
+            **kwargs,
+        )
+
+    @staticmethod
+    def unscan(params):
+        params = unfreeze(params)
+        for l in ["encoder", "decoder"]:
+            params["model"][l]["layers"] = jax.tree_map(
+                lambda x: {f"{i}": x[i] for i in range(len(x))},
+                params["model"][l]["layers"],
+            )
+        params = freeze(params)
+        return params
+
+    @staticmethod
+    def rescan(params):
+        params = unfreeze(params)
+        for l in ["encoder", "decoder"]:
+            params["model"][l]["layers"] = jax.tree_map(
+                lambda x: jnp.stack([x[f"{i}"] for i in range(len(x))]),
+                params["model"][l]["layers"],
+                is_leaf=lambda x: "0" in x,
+            )
+        params = freeze(params)
+        return params
+

def main():
    # See all possible arguments by passing the --help flag to this script.
@@ -618,7 +667,7 @@ def main():
    model_metadata = model_args.get_metadata()

    # get PartitionSpec for model params (required to be a dict)
-    param_spec = set_partitions(model.params)
+    param_spec = set_partitions(model.params, model.config.use_scan)

    # convert params to frozen dict
    model._params = freeze(model.params)
@@ -743,6 +792,23 @@ def main():

    learning_rate_fn = create_learning_rate_fn()

+    # reshape params to split scanned layers for optimizers
+    if model.config.use_scan:
+        params_struct = unfreeze(model.params)
+        for l in ["encoder", "decoder"]:
+            params_struct["model"][l]["layers"] = jax.tree_map(
+                lambda x: {
+                    f"{i}": ShapeDtypeStruct(shape=x.shape[1:], dtype=x.dtype)
+                    for i in range(len(x))
+                },
+                params_struct["model"][l]["layers"],
+            )
+        params_struct = freeze(params_struct)
+
+    else:
+        params_struct = model.params
+    opt_param_spec = set_partitions(params_struct, False)
+
    # create adam optimizer
    if training_args.optim == "distributed_shampoo":
        # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
@@ -795,10 +861,12 @@ def main():
        )
        # get the real optimizer and helper functions
        update_fn = optimizer.update
+
+        optimizer = optimizer.init(params_struct)
        opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
            optimizer.pspec_fn, optimizer.shape_and_dtype_fn
        )
+
        optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)

    elif training_args.optim == "adam":
@@ -819,7 +887,7 @@ def main():
    # get PartitionSpec for optimizer state
    def get_opt_state_spec_and_shape(param_spec):
        # get opt_state shape without actual init
-        opt_state_shape = jax.eval_shape(optimizer.init,
+        opt_state_shape = jax.eval_shape(optimizer.init, params_struct)

        if training_args.optim == "adam":

@@ -844,7 +912,7 @@ def main():

        elif training_args.optim == "distributed_shampoo":
            opt_state_spec = opt_fn.pspec_fn(
-                params=
+                params=params_struct,
                params_partition_spec=param_spec,
                partition_spec_for_statistics=PartitionSpec(None, "dp", None),
            )
@@ -852,7 +920,7 @@ def main():
            raise NotImplementedError
        return opt_state_spec, opt_state_shape

-    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(
+    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(opt_param_spec)

    # create a mesh
    mesh_shape = (training_args.dp_devices, training_args.mp_devices)
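
The unscan/rescan helpers on TrainState exist because scanned parameters are stored stacked along a leading layer axis, while the optimizer and its partition spec are built against one entry per layer. A standalone sketch of the round trip on a toy tree (not the real DalleBart parameter structure):

import jax
import jax.numpy as jnp

# toy stand-in for params["model"]["encoder"]["layers"]: 4 scanned layers
stacked = {"kernel": jnp.zeros((4, 16, 16)), "bias": jnp.zeros((4, 16))}

# unscan: split the leading layer axis into a dict of per-layer arrays
unscanned = jax.tree_map(
    lambda x: {f"{i}": x[i] for i in range(len(x))}, stacked
)
assert unscanned["kernel"]["0"].shape == (16, 16)

# rescan: stack the per-layer arrays back into one array with a layer axis
rescanned = jax.tree_map(
    lambda x: jnp.stack([x[f"{i}"] for i in range(len(x))]),
    unscanned,
    is_leaf=lambda x: isinstance(x, dict) and "0" in x,
)
assert rescanned["kernel"].shape == (4, 16, 16)

In the training script the same per-layer view is built from jax.ShapeDtypeStruct leaves, so the optimizer shapes can be derived with jax.eval_shape without materializing a second copy of the parameters.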