# coding=utf-8
# noqa: license-check
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TransformerEngine-optimized ESM model.
Adapted from `modeling_esm.py` in huggingface/transformers.
"""
from typing import Literal, Optional
# TODO: put import guard around transformer_engine here, with an informative error message around
# installation and the nvidia docker container.
import torch
import transformer_engine.pytorch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import EsmEmbeddings, EsmPooler
from transformers.utils import logging
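
# A possible shape for the import guard described in the TODO above (a sketch, not this module's
# current behavior; the install command and container wording are assumptions):
#
#     try:
#         import transformer_engine.pytorch
#     except ImportError as e:
#         raise ImportError(
#             "NVEsm models require TransformerEngine. Install it (e.g. "
#             "`pip install transformer_engine[pytorch]`) or run inside an NVIDIA PyTorch "
#             "container that ships with TransformerEngine preinstalled."
#         ) from e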
logger = logging.get_logger(__name__)
class NVEsmConfig(EsmConfig):
"""NVEsmConfig is a configuration for the NVEsm model."""
model_type: str = "nv_esm"
def __init__(
self,
qkv_weight_interleaved: bool = True,
encoder_activation: str = "gelu",
attn_input_format: Literal["bshd", "thd"] = "bshd",
fuse_qkv_params: bool = True,
micro_batch_size: Optional[int] = None,
max_seq_length: Optional[int] = None,
padded_vocab_size: Optional[int] = 64,
**kwargs,
):
"""Initialize the NVEsmConfig with additional TE-related config options.
Args:
qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the
QKV weight is interpreted as a concatenation of query, key, and value weights along
the `0th` dimension. The default interpretation is that the individual `q`, `k`, and
`v` weights for each attention head are interleaved. This parameter is set to `False`
when using :attr:`fuse_qkv_params=False`.
encoder_activation: The activation function to use in the encoder.
            attn_input_format: The input format to use for the attention. This controls
                whether the intermediate hidden states are laid out as padded batches
                ('bshd') or as packed, variable-length sequences ('thd'). `b` stands for
                the batch size, `s` the sequence length, `h` the number of heads, `d` the
                head size, and `t` the total number of tokens in the packed batch. Note
                that these formats correspond closely to the `qkv_format` argument of the
                `MultiHeadAttention` and `DotProductAttention` modules.
            fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, the
                `TransformerLayer` module exposes a single fused parameter for query-key-value.
                This enables optimizations such as QKV fusion without concatenations/splits and
                also enables the argument `fuse_wgrad_accumulation`.
            micro_batch_size: The micro batch size to use for the attention. This is needed for
                JIT warmup, a technique where JIT-fused functions are warmed up before training
                to ensure the same kernels are used in the forward propagation and activation
                recompute phases.
            max_seq_length: The maximum sequence length to use for the attention. This is needed
                for JIT warmup, a technique where JIT-fused functions are warmed up before
                training to ensure the same kernels are used in the forward propagation and
                activation recompute phases.
padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
to vocab_size. Must be greater than or equal to vocab_size.
**kwargs: Additional config options to pass to EsmConfig.
"""
super().__init__(**kwargs)
# Additional TE-related config options.
self.qkv_weight_interleaved = qkv_weight_interleaved
self.encoder_activation = encoder_activation
self.attn_input_format = attn_input_format
self.fuse_qkv_params = fuse_qkv_params
self.micro_batch_size = micro_batch_size
self.max_seq_length = max_seq_length
# Set padded_vocab_size with default fallback to vocab_size
self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size
# Ensure padded_vocab_size is at least as large as vocab_size
if self.padded_vocab_size is not None and self.vocab_size is not None:
assert self.padded_vocab_size >= self.vocab_size, (
f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})"
)
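
# Hedged usage sketch (illustrative, not executed by this module): building an NVEsmConfig from
# an existing ESM-2 checkpoint. The checkpoint id and option values below are assumptions chosen
# for illustration.
#
#     config = NVEsmConfig.from_pretrained(
#         "facebook/esm2_t36_3B_UR50D",   # any ESM-2 config works; id is illustrative
#         attn_input_format="bshd",       # padded batches; "thd" expects packed sequences
#         fuse_qkv_params=True,           # expose a single fused QKV parameter per layer
#         padded_vocab_size=64,           # pad the 33-token ESM vocab for FP8-friendly GEMM shapes
#     )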
class NVEsmEncoder(nn.Module):
"""NVEsmEncoder is a TransformerEngine-optimized ESM encoder."""
def __init__(self, config: NVEsmConfig):
"""Initialize a NVEsmEncoder.
Args:
config (NVEsmConfig): The configuration of the model.
"""
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
transformer_engine.pytorch.TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
layernorm_epsilon=config.layer_norm_eps,
hidden_dropout=config.hidden_dropout_prob,
attention_dropout=config.attention_probs_dropout_prob,
qkv_weight_interleaved=config.qkv_weight_interleaved,
layer_number=i + 1,
layer_type="encoder",
self_attn_mask_type="padding",
activation=config.encoder_activation,
attn_input_format=config.attn_input_format,
seq_length=config.max_seq_length,
micro_batch_size=config.micro_batch_size,
num_gqa_groups=config.num_attention_heads,
fuse_qkv_params=config.fuse_qkv_params,
params_dtype=config.torch_dtype,
window_size=(-1, -1),
)
for i in range(config.num_hidden_layers)
]
)
self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if config.position_embedding_type == "rotary":
self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
# Keep on CPU, pin for faster non_blocking H2D; don't persist in state_dict.
if config.attn_input_format == "bshd":
self.register_buffer(
"te_rope_emb",
self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cpu().pin_memory(),
persistent=False,
)
else:
self.te_rope_emb = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_hidden_states: bool = False,
cu_seq_lens_q: torch.IntTensor | None = None,
cu_seq_lens_k: torch.IntTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
):
"""Forward pass of the NVEsmEncoder.
Args:
hidden_states (torch.Tensor): The hidden states.
attention_mask (torch.Tensor): The attention mask.
output_hidden_states (bool): Whether to output the hidden states.
cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
max_length_q (int): The maximum length for the query state, if using THD inputs.
max_length_k (int): The maximum length for the key state, if using THD inputs.
"""
all_hidden_states: tuple[torch.Tensor, ...] = ()
if self.config.attn_input_format == "thd":
if any(x is None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
raise ValueError(
"cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
)
assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
"THD expects embeddings shaped [1, total_tokens, hidden_size]."
)
hidden_states = hidden_states.squeeze(0)
elif self.config.attn_input_format == "bshd":
if any(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
raise ValueError(
"cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
)
te_rope_emb = None
if self.config.position_embedding_type == "rotary":
if self.config.attn_input_format == "bshd":
te_rope_emb = self.te_rope_emb.to(
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
)
seq_len = hidden_states.shape[1]
if te_rope_emb.size(0) < seq_len:
raise RuntimeError(
f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
f"Increase max_position_embeddings."
)
te_rope_emb = te_rope_emb[:seq_len]
elif self.config.attn_input_format == "thd":
assert cu_seq_lens_q is not None
te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1]).to(
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
)
else:
raise ValueError(f"Unsupported attention input format: {self.config.attn_input_format}")
for layer_module in self.layers:
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
hidden_states = layer_module(
hidden_states,
attention_mask,
rotary_pos_emb=te_rope_emb,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_kv=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_kv=max_length_k,
)
hidden_states = self.emb_layer_norm_after(hidden_states)
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states if all_hidden_states else None,
)
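
# Hedged sketch of the THD ("packed") calling convention that NVEsmEncoder.forward expects when
# config.attn_input_format == "thd". The sequence lengths and the `encoder`/`hidden_states`
# names are illustrative; the only requirements are a [1, total_tokens, hidden_size] embedding
# tensor and int32 prefix sums of the per-sequence lengths.
#
#     seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)        # three packed sequences
#     cu_seq_lens = torch.nn.functional.pad(torch.cumsum(seq_lens, 0), (1, 0)).to(torch.int32)
#     # cu_seq_lens == tensor([0, 5, 8, 15], dtype=torch.int32)
#     max_len = int(seq_lens.max())
#
#     out = encoder(
#         hidden_states,                  # shape [1, 15, hidden_size] for this example
#         cu_seq_lens_q=cu_seq_lens.cuda(),
#         cu_seq_lens_k=cu_seq_lens.cuda(),
#         max_length_q=max_len,
#         max_length_k=max_len,
#     )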
class NVEsmPreTrainedModel(PreTrainedModel):
"""An abstract class to handle weights initialization and pretrained model loading."""
config_class = NVEsmConfig
base_model_prefix = "esm"
supports_gradient_checkpointing = False
_no_split_modules = (
"TransformerLayer",
"EsmEmbeddings",
)
    # Adapted from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module: nn.Module):
"""Initialize the weights.
Args:
module (nn.Module): The module to initialize the weights for.
"""
if isinstance(
module, (nn.Linear, transformer_engine.pytorch.Linear, transformer_engine.pytorch.LayerNormLinear)
):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, (nn.LayerNorm, transformer_engine.pytorch.LayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
module.layer_norm_weight.data.fill_(1.0)
if module.layer_norm_bias is not None:
module.layer_norm_bias.data.zero_()
class NVEsmModel(NVEsmPreTrainedModel):
"""The ESM Encoder-only protein language model.
    This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference.
"""
def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True):
"""Initialize a NVEsmModel.
Args:
config (NVEsmConfig): The configuration of the model.
add_pooling_layer (bool): Whether to add a pooling layer.
"""
super().__init__(config)
self.config = config
# Create EsmEmbeddings with temporarily modified config to use padded vocab size
# This ensures the word embeddings layer uses the padded vocabulary size for FP8 support
original_vocab_size = config.vocab_size
config.vocab_size = config.padded_vocab_size
# Ensure pad_token_id is set properly, defaulting to 0 if not specified
if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
config.pad_token_id = 0
self.embeddings = EsmEmbeddings(config)
config.vocab_size = original_vocab_size # Restore original vocab_size
self.encoder = NVEsmEncoder(config)
self.pooler = EsmPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
"""Get the input embeddings of the model."""
return self.embeddings.word_embeddings
def set_input_embeddings(self, value: torch.Tensor):
"""Set the input embeddings of the model.
Args:
value (torch.Tensor): The input embeddings.
"""
self.embeddings.word_embeddings = value
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
cu_seq_lens_q: torch.IntTensor | None = None,
cu_seq_lens_k: torch.IntTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
) -> BaseModelOutputWithPooling:
"""Forward pass of the NVEsmModel.
Args:
input_ids (torch.Tensor): The input ids.
attention_mask (torch.Tensor): The attention mask.
position_ids (torch.Tensor): The position ids.
head_mask (torch.Tensor): The head mask.
inputs_embeds (torch.Tensor): The input embeddings.
output_hidden_states (bool): Whether to output the hidden states.
cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
max_length_q (int): The maximum length for the query state, if using THD inputs.
max_length_k (int): The maximum length for the key state, if using THD inputs.
Returns:
BaseModelOutputWithPooling: The output of the model.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        # TE expects a boolean attention mask where True marks positions to be masked out
        # and False marks positions that are attended to.
        extended_attention_mask = extended_attention_mask < -1
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
output_hidden_states=output_hidden_states,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
)
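
# Hedged usage sketch for NVEsmModel (assumes a compatible ESM tokenizer and that weights are
# loaded in bf16 on a CUDA device, which the TransformerEngine layers expect; `checkpoint_dir`
# is a placeholder for a path or hub id containing these NV-optimized weights):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
#     model = NVEsmModel.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16).cuda().eval()
#
#     batch = tokenizer(["MKTAYIAKQR", "GDSLAYGLR"], padding=True, return_tensors="pt").to("cuda")
#     with torch.no_grad():
#         out = model(**batch)
#     residue_embeddings = out.last_hidden_state      # [batch, seq_len, hidden_size]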
class NVEsmForMaskedLM(NVEsmPreTrainedModel):
"""NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""
_tied_weights_keys = ("lm_head.decoder.weight",)
def __init__(self, config: NVEsmConfig):
"""Initialize a NVEsmForMaskedLM.
Args:
config (NVEsmConfig): The configuration of the model.
"""
super().__init__(config)
        if config.is_decoder:
            logger.warning(
                "If you want to use `NVEsmForMaskedLM`, make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
self.esm = NVEsmModel(config, add_pooling_layer=False)
self.lm_head = NVEsmLMHead(config)
self.init_weights()
self.post_init()
def get_output_embeddings(self):
"""Get the output embeddings of the model."""
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
"""Set the output embeddings of the model."""
self.lm_head.decoder = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
cu_seq_lens_q: torch.IntTensor | None = None,
cu_seq_lens_k: torch.IntTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
) -> MaskedLMOutput:
"""Forward pass of the NVEsmForMaskedLM.
Args:
input_ids (torch.LongTensor): The input ids.
attention_mask (torch.Tensor): The attention mask.
position_ids (torch.LongTensor): The position ids.
inputs_embeds (torch.FloatTensor): The input embeddings.
labels (torch.LongTensor): The labels.
output_hidden_states (bool): Whether to output the hidden states.
cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
max_length_q (int): The maximum length for the query state, if using THD inputs.
max_length_k (int): The maximum length for the key state, if using THD inputs.
Returns:
MaskedLMOutput: The output of the model.
"""
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
# Truncate logits back to original vocab_size if padding was used
if self.config.padded_vocab_size != self.config.vocab_size:
prediction_scores = prediction_scores[..., : self.config.vocab_size]
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size),
labels.to(prediction_scores.device).view(-1),
)
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
)
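
# Hedged sketch of masked-token prediction with NVEsmForMaskedLM, reusing the tokenizer from the
# NVEsmModel sketch above; `mlm_model` stands for an NVEsmForMaskedLM loaded the same way. The
# logits returned above are already truncated back to vocab_size, so argmax indices map directly
# onto tokenizer ids.
#
#     inputs = tokenizer("MKTAYIAK<mask>R", return_tensors="pt").to("cuda")
#     with torch.no_grad():
#         logits = mlm_model(**inputs).logits                      # [1, seq_len, vocab_size]
#     mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
#     predicted_residue = tokenizer.decode(logits[0, mask_pos].argmax())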
class NVEsmLMHead(nn.Module):
"""ESM Head for masked language modeling using TransformerEngine."""
def __init__(self, config: NVEsmConfig):
"""Initialize a NVEsmLMHead.
Args:
config (NVEsmConfig): The configuration of the model.
"""
super().__init__()
self.dense = transformer_engine.pytorch.Linear(config.hidden_size, config.hidden_size)
self.decoder = transformer_engine.pytorch.LayerNormLinear(
config.hidden_size,
config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size,
bias=True,
eps=config.layer_norm_eps,
)
def forward(self, features, **kwargs):
"""Forward pass of the NVEsmLMHead.
Args:
features (torch.Tensor): The features.
**kwargs: Additional arguments.
"""
x = self.dense(features)
x = torch.nn.functional.gelu(x)
x = self.decoder(x)
return x
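
# Hedged sketch: downstream code that does not rely on `trust_remote_code` auto-mapping could
# register these classes with the transformers auto factories so that the "nv_esm" model_type
# declared in NVEsmConfig resolves automatically. This is an assumption about downstream wiring,
# not something this module performs itself.
#
#     from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
#
#     AutoConfig.register("nv_esm", NVEsmConfig)
#     AutoModel.register(NVEsmConfig, NVEsmModel)
#     AutoModelForMaskedLM.register(NVEsmConfig, NVEsmForMaskedLM)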