import torch
import transformer_engine.pytorch
from torch import nn
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
from transformers.modeling_utils import PreTrainedModel


class AMPLIFYConfig(PretrainedConfig):
    """AMPLIFY model configuration."""

    model_type = "AMPLIFY"

    def __init__(
        self,
        hidden_size: int = 960,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 15,
        intermediate_size: int = 3840,
        dropout_prob: float = 0.0,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        rms_norm: bool = True,
        norm_eps: float = 1e-05,
        hidden_act: str = "SwiGLU",
        layer_norm_after_embedding: bool = False,
        layer_norm_before_last_layer: bool = True,
        vocab_size: int = 27,
        padded_vocab_size: int = 32,
        ffn_bias: bool = False,
        att_bias: bool = False,
        pad_token_id: int = 0,
        max_length: int = 2048,
        **kwargs,
    ):
        """Initialize an AMPLIFYConfig.

        Args:
            hidden_size (int): The hidden size of the model.
            num_hidden_layers (int): The number of hidden layers in the model.
            num_attention_heads (int): The number of attention heads in the model.
            intermediate_size (int): The intermediate size of the model.
            dropout_prob (float): The dropout probability of the model.
            embedding_init_range (float): The range of the embedding initialization.
            decoder_init_range (float): The range of the decoder initialization.
            rms_norm (bool): Whether to use RMSNorm instead of LayerNorm.
            norm_eps (float): The epsilon used by the normalization layers.
            hidden_act (str): The activation function of the model.
            layer_norm_after_embedding (bool): Whether to apply normalization after the embedding.
            layer_norm_before_last_layer (bool): Whether to apply normalization before the last layer.
            vocab_size (int): The vocabulary size of the model.
            padded_vocab_size (int): The vocabulary size padded to support fp8.
            ffn_bias (bool): Whether to use bias in the feedforward network.
            att_bias (bool): Whether to use bias in the attention.
            pad_token_id (int): The padding token id.
            max_length (int): The maximum sequence length.
            **kwargs: Additional arguments.
        """
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout_prob = dropout_prob
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.rms_norm = rms_norm
        self.norm_eps = norm_eps
        self.hidden_act = hidden_act
        self.layer_norm_after_embedding = layer_norm_after_embedding
        self.layer_norm_before_last_layer = layer_norm_before_last_layer
        self.vocab_size = vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.ffn_bias = ffn_bias
        self.att_bias = att_bias
        self.pad_token_id = pad_token_id
        self.max_length = max_length

        assert self.padded_vocab_size >= self.vocab_size, (
            "padded_vocab_size must be greater than or equal to vocab_size"
        )

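# Hedged usage sketch (illustrative values, not recommended settings): a small
# config for quick experiments. Note that hidden_size must remain divisible by
# num_attention_heads for the attention layers below.
#
#     config = AMPLIFYConfig(num_hidden_layers=4, hidden_size=320, num_attention_heads=5)
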
class AMPLIFYPreTrainedModel(PreTrainedModel):
    """AMPLIFY pre-trained model."""

    config: AMPLIFYConfig
    config_class = AMPLIFYConfig
    base_model_prefix = "amplify"

    def _init_weights(self, module):
        # Linear layers (both torch and Transformer Engine variants) are
        # initialized uniformly within the configured decoder range; embeddings
        # use the embedding range.
        if isinstance(
            module, (nn.Linear, transformer_engine.pytorch.Linear, transformer_engine.pytorch.LayerNormLinear)
        ):
            module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)


class AMPLIFY(AMPLIFYPreTrainedModel):
    """The main model class."""

    def __init__(self, config: AMPLIFYConfig, **kwargs):
        """Initialize an AMPLIFY model.

        Args:
            config (AMPLIFYConfig): The configuration of the model.
            **kwargs: Additional arguments.
        """
        super().__init__(config)

        self.config = config

        # Token embeddings over the padded vocabulary (padded for fp8-friendly shapes).
        self.encoder = nn.Embedding(
            config.padded_vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
            dtype=config.dtype,
        )

        # Optional normalization applied directly after the embedding lookup.
        if config.layer_norm_after_embedding:
            self.layer_norm_1 = (
                transformer_engine.pytorch.RMSNorm(config.hidden_size, config.norm_eps, params_dtype=config.dtype)
                if config.rms_norm
                else transformer_engine.pytorch.LayerNorm(
                    config.hidden_size, config.norm_eps, params_dtype=config.dtype
                )
            )

        if config.hidden_act.lower() == "swiglu":
            # SwiGLU splits the FFN input projection into two halves, so the
            # nominal intermediate size is scaled by 2/3 and rounded up to a
            # multiple of 8 to keep the parameter count comparable to a
            # standard MLP.
            multiple_of = 8
            intermediate_size = int(2 * config.intermediate_size / 3)
            intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
        else:
            intermediate_size = config.intermediate_size

        self.transformer_encoder = nn.ModuleList()
        for layer_num in range(config.num_hidden_layers):
            self.transformer_encoder.append(
                transformer_engine.pytorch.TransformerLayer(
                    hidden_size=config.hidden_size,
                    ffn_hidden_size=intermediate_size,
                    num_attention_heads=config.num_attention_heads,
                    layernorm_epsilon=config.norm_eps,
                    hidden_dropout=config.dropout_prob,
                    attention_dropout=config.dropout_prob,
                    apply_residual_connection_post_layernorm=False,
                    layer_type="encoder",
                    self_attn_mask_type="padding",
                    normalization="RMSNorm" if config.rms_norm else "LayerNorm",
                    fuse_qkv_params=True,
                    qkv_weight_interleaved=True,
                    output_layernorm=False,
                    bias=False,
                    activation=config.hidden_act.lower(),
                    attn_input_format="bshd",
                    layer_number=layer_num + 1,
                    name="encoder_block",
                    window_size=(-1, -1),
                    rotary_pos_interleaved=True,
                    seq_length=config.max_length,
                    params_dtype=config.dtype,
                )
            )

        # Precompute rotary position embeddings up to the maximum sequence length.
        self.freqs_cis = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads, interleaved=True)(
            config.max_length
        )

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        output_hidden_states=False,
        output_attentions=False,
        labels=None,
    ) -> BaseModelOutput:
        """Forward pass of the AMPLIFY model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            output_hidden_states (bool): Whether to output the hidden states.
            output_attentions (bool): Whether to output the attention weights.
            labels (torch.Tensor): The labels.

        Returns:
            BaseModelOutput: The output of the model.
        """
        hidden_states = []

        # HF attention masks mark tokens to keep with 1, while TE's padding mask
        # expects True for positions to mask out, so invert the boolean mask.
        if attention_mask is not None and attention_mask.dtype is torch.int64:
            attention_mask = ~attention_mask.to(bool)

        # Move the precomputed rotary embeddings to the input device and slice
        # them to the current sequence length.
        self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
        freqs_cis = self.freqs_cis[: input_ids.shape[1]]

        # Embed tokens, optionally normalizing right after the lookup.
        x = self.encoder(input_ids)
        if self.config.layer_norm_after_embedding:
            x = self.layer_norm_1(x)

        # Run the stack of Transformer Engine encoder layers.
        for layer in self.transformer_encoder:
            x = layer(x, attention_mask, rotary_pos_emb=freqs_cis)
            if output_hidden_states:
                hidden_states.append(x)
        if output_attentions:
            raise ValueError("output_attentions is not supported for TE")

        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=tuple(hidden_states) if hidden_states else None,
            attentions=None,
        )

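# Hedged usage sketch for the bare encoder: `model` and `batch` below are
# hypothetical names, for illustration only.
#
#     out = model(batch["input_ids"], batch["attention_mask"])
#     embeddings = out.last_hidden_state  # (batch, seq_len, hidden_size)
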
class AMPLIFYForMaskedLM(AMPLIFYPreTrainedModel):
    """AMPLIFY for masked language modeling."""

    def __init__(self, config: AMPLIFYConfig, **kwargs):
        """Initialize an AMPLIFYForMaskedLM model.

        Args:
            config (AMPLIFYConfig): The configuration of the model.
            **kwargs: Additional arguments.
        """
        super().__init__(config)
        self.amplify = AMPLIFY(config, **kwargs)

        # Vocabulary projection head. With layer_norm_before_last_layer, the
        # final normalization is fused into the projection (LayerNormLinear) and
        # targets the padded vocabulary; otherwise a plain Linear projects to
        # the unpadded vocabulary.
        if config.layer_norm_before_last_layer:
            self.decoder = transformer_engine.pytorch.LayerNormLinear(
                config.hidden_size,
                config.padded_vocab_size,
                config.norm_eps,
                params_dtype=config.dtype,
                normalization="RMSNorm" if config.rms_norm else "LayerNorm",
                init_method=lambda x: torch.nn.init.uniform_(
                    x, -self.config.decoder_init_range, self.config.decoder_init_range
                ),
            )
        else:
            self.decoder = transformer_engine.pytorch.Linear(
                config.hidden_size, config.vocab_size, params_dtype=config.dtype
            )

    def forward(
        self,
        input_ids,
        attention_mask=None,
        output_hidden_states=False,
        output_attentions=False,
        labels=None,
    ) -> MaskedLMOutput:
        """Forward pass of the AMPLIFYForMaskedLM model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            output_hidden_states (bool): Whether to output the hidden states.
            output_attentions (bool): Whether to output the attention weights.
            labels (torch.Tensor): The labels.

        Returns:
            MaskedLMOutput: The output of the model.
        """
        outputs = self.amplify(
            input_ids,
            attention_mask,
            output_hidden_states,
            output_attentions,
            labels,
        )

        # Project hidden states to vocabulary logits, dropping the fp8 padding
        # columns when the vocabulary was padded.
        logits = self.decoder(outputs.last_hidden_state)
        if self.config.padded_vocab_size != self.config.vocab_size:
            logits = logits[:, :, : self.config.vocab_size]

        # Standard masked-LM cross-entropy over the flattened token dimension.
        if labels is not None:
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        else:
            loss = None

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
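

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch, not part of the model definition.
    # Assumes a CUDA device with Transformer Engine support, and a transformers
    # version where `config.dtype` is populated (the modules above read it for
    # their parameter dtype).
    config = AMPLIFYConfig(num_hidden_layers=2, dtype=torch.bfloat16)
    model = AMPLIFYForMaskedLM(config).cuda().eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
    # TE expects padding masks shaped [batch, 1, 1, seq]; HF-style int64 masks
    # (1 = keep) are inverted to boolean "mask out" semantics in AMPLIFY.forward.
    attention_mask = torch.ones(1, 1, 1, 16, dtype=torch.int64, device="cuda")

    with torch.no_grad():
        out = model(input_ids, attention_mask)
    print(out.logits.shape)  # expected: torch.Size([1, 16, 27]) after de-padding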