pstjohn committed
Commit a6a87b9 · verified · 1 Parent(s): 978c16e

Upload folder using huggingface_hub

Files changed (2):
  1. config.json +1 -1
  2. esm_nv.py +9 -2
config.json CHANGED
@@ -33,7 +33,7 @@
   "qkv_weight_interleaved": true,
   "token_dropout": true,
   "torch_dtype": "float32",
-  "transformers_version": "4.55.0.dev0",
+  "transformers_version": "4.55.4",
   "use_cache": true,
   "vocab_list": null,
   "vocab_size": 33
esm_nv.py CHANGED
@@ -1,4 +1,5 @@
 # coding=utf-8
+# noqa: license-check
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: LicenseRef-Apache2
 # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
@@ -137,7 +138,7 @@ class NVEsmEncoder(nn.Module):
         self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         if config.position_embedding_type == "rotary":
             self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
-            self.te_rope_emb = self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cuda()
+            self.te_rope_emb = self.rotary_embeddings(max_seq_len=config.max_position_embeddings)
         else:
             self.te_rope_emb = None

@@ -156,6 +157,12 @@
         """
         all_hidden_states = () if output_hidden_states else None

+        if self.te_rope_emb is not None:
+            te_rope_emb = self.te_rope_emb.to(hidden_states.device, non_blocking=True)
+            te_rope_emb = te_rope_emb[: hidden_states.shape[1]]
+        else:
+            te_rope_emb = None
+
         for layer_module in self.layers:
             if output_hidden_states:
                 all_hidden_states = (*all_hidden_states, hidden_states)
@@ -163,7 +170,7 @@
             hidden_states = layer_module(
                 hidden_states,
                 attention_mask,
-                rotary_pos_emb=self.te_rope_emb,
+                rotary_pos_emb=te_rope_emb,
             )

         hidden_states = self.emb_layer_norm_after(hidden_states)
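Two things change in the Python file besides the # noqa header line: the precomputed RoPE table is no longer moved to CUDA at construction time, and the forward pass now moves it to the input's device and slices it to the current sequence length before handing it to each layer. A minimal sketch of that pattern (LazyRopeEncoder is a hypothetical stand-in for NVEsmEncoder, and the Transformer Engine import path may differ by release):

    import torch
    from torch import nn

    # Recent Transformer Engine versions expose RotaryPositionEmbedding under
    # transformer_engine.pytorch.attention; the diff only assumes the name is in scope.
    from transformer_engine.pytorch.attention import RotaryPositionEmbedding


    class LazyRopeEncoder(nn.Module):
        """Hypothetical stand-in for NVEsmEncoder's rotary-embedding handling."""

        def __init__(self, hidden_size, num_attention_heads, max_position_embeddings):
            super().__init__()
            rotary = RotaryPositionEmbedding(hidden_size // num_attention_heads)
            # Build the frequency table once and leave it on whatever device TE
            # created it (typically CPU). No .cuda() here, so constructing the
            # module does not require a GPU.
            self.te_rope_emb = rotary(max_seq_len=max_position_embeddings)

        def forward(self, hidden_states):
            # Move the cached table next to the inputs and trim it to the actual
            # sequence length (dim 1 of a [batch, seq, hidden] input).
            rope = self.te_rope_emb.to(hidden_states.device, non_blocking=True)
            rope = rope[: hidden_states.shape[1]]
            # A real encoder would now pass rotary_pos_emb=rope to each TE layer.
            return rope

Deferring the transfer means the model can be instantiated on a machine without CUDA, and slicing to hidden_states.shape[1] avoids handing each layer a table longer than the batch's sequence length; the non_blocking=True transfer matches what the patched forward does.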