yangheng committed (verified)
Commit 4c39112 · 1 Parent(s): 90722cd

Update modeling_omnigenome.py

Files changed (1): modeling_omnigenome.py (+4 -22)
modeling_omnigenome.py CHANGED
@@ -291,10 +291,8 @@ class OmniGenomeEmbeddings(nn.Module):
     def create_position_ids_from_inputs_embeds(self, inputs_embeds):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
-
         Args:
             inputs_embeds: torch.Tensor
-
         Returns: torch.Tensor
         """
         input_shape = inputs_embeds.size()[:-1]
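For reference, the method whose docstring is trimmed in this hunk generates sequential position ids when only embeddings are supplied. A minimal standalone sketch of that idea (not the committed implementation; `padding_idx` is passed explicitly here):

    import torch

    def sequential_position_ids(inputs_embeds: torch.Tensor, padding_idx: int) -> torch.Tensor:
        # Positions start at padding_idx + 1 because padding_idx itself is reserved for padding tokens.
        input_shape = inputs_embeds.size()[:-1]   # (batch_size, seq_len)
        seq_len = input_shape[1]
        position_ids = torch.arange(
            padding_idx + 1, seq_len + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)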
@@ -578,9 +576,9 @@ class OmniGenomeSelfAttention(nn.Module):
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

             # Reorder dimensions to [batch_size, seq_len, num_heads, head_dim]
-            q = query_layer.transpose(1, 2).half()
-            k = key_layer.transpose(1, 2).half()
-            v = value_layer.transpose(1, 2).half()
+            q = query_layer.transpose(1, 2).to(torch.float16)
+            k = key_layer.transpose(1, 2).to(torch.float16)
+            v = value_layer.transpose(1, 2).to(torch.float16)

             # Compute attention with FlashAttention
             context_layer = self.flash_attn_func(
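The functional change in this hunk swaps `.half()` for the equivalent but more explicit `.to(torch.float16)`; either way the cast is needed because FlashAttention kernels only accept fp16/bf16 tensors on GPU. A rough usage sketch, assuming the `flash_attn` package (v2.x) and a CUDA device are available:

    import torch
    from flash_attn import flash_attn_func  # assumes flash-attn 2.x is installed

    batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
    # FlashAttention expects [batch_size, seq_len, num_heads, head_dim] inputs in fp16 or bf16.
    q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda").to(torch.float16)
    k = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda").to(torch.float16)
    v = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda").to(torch.float16)

    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)  # same shape as q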
@@ -989,15 +987,12 @@ class OmniGenomePreTrainedModel(PreTrainedModel):


 OmniGenome_START_DOCSTRING = r"""
-
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
     and behavior.
-
     Parameters:
         config ([`OmniGenomeConfig`]): Model configuration class with all the parameters of the
             model. Initializing with a config file does not load the weights associated with the model, only the
@@ -1008,29 +1003,22 @@ OmniGenome_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `({0})`):
             Indices of input sequence tokens in the vocabulary.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
         position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
             Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
             config.max_position_embeddings - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
-
         inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
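The arguments documented in the hunk above (`input_ids`, `attention_mask`, and so on) are normally produced by the tokenizer rather than built by hand. A hedged sketch, where the checkpoint id is a placeholder for whichever OmniGenome checkpoint is actually in use:

    from transformers import AutoTokenizer

    # Placeholder checkpoint id, shown for illustration only.
    tokenizer = AutoTokenizer.from_pretrained("yangheng/OmniGenome-52M", trust_remote_code=True)
    enc = tokenizer("AUGGCUACGUAG", return_tensors="pt", padding=True)
    print(enc["input_ids"].shape, enc["attention_mask"].shape)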
@@ -1053,12 +1041,10 @@ OmniGenome_INPUTS_DOCSTRING = r"""
 # Copied from transformers.models.esm.modeling_esm.EsmModel with Esm->OmniGenome
 class OmniGenomeModel(OmniGenomePreTrainedModel):
     """
-
     The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
     cross-attention is added between the self-attention layers, following the architecture described in [Attention is
     all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
     Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
-
     To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
     to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
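As the docstring above notes, decoder behaviour is controlled entirely by configuration flags. A minimal sketch of turning them on (checkpoint id again a placeholder, and `trust_remote_code=True` assumed because this is a remote-code model):

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("yangheng/OmniGenome-52M", trust_remote_code=True)
    config.is_decoder = True
    config.add_cross_attention = True  # for Seq2Seq use; forward() then expects encoder_hidden_states
    model = AutoModel.from_pretrained("yangheng/OmniGenome-52M", config=config, trust_remote_code=True)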
@@ -1124,12 +1110,10 @@ class OmniGenomeModel(OmniGenomePreTrainedModel):
         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -1218,7 +1202,7 @@ class OmniGenomeModel(OmniGenomePreTrainedModel):
             inputs_embeds=inputs_embeds,
             past_key_values_length=past_key_values_length,
         )
-        embedding_output = embedding_output.half()
+        embedding_output = embedding_output.to(torch.float16)
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
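This hunk is the same dtype clean-up as in the attention module: `.to(torch.float16)` replaces `.half()`, which performs an identical cast. A two-line check of that equivalence:

    import torch

    x = torch.randn(2, 4)
    assert torch.equal(x.half(), x.to(torch.float16))  # .half() is shorthand for the same fp16 cast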
@@ -1893,10 +1877,8 @@ def create_position_ids_from_input_ids(
     """
     Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
     are ignored. This is modified from fairseq's `utils.make_positions`.
-
     Args:
         x: torch.Tensor
-
     Returns: torch.Tensor
     """
     # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
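For context, the fairseq-style `make_positions` logic referenced in this docstring is usually implemented with a padding mask and a cumulative sum, so padding tokens keep `padding_idx` while real tokens count up from `padding_idx + 1`. A sketch of that standard pattern (the committed function may differ in details such as the `past_key_values_length` offset):

    import torch

    def make_position_ids(input_ids: torch.Tensor, padding_idx: int, past_key_values_length: int = 0) -> torch.Tensor:
        mask = input_ids.ne(padding_idx).int()                                           # 1 for real tokens, 0 for padding
        incremental = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
        return incremental.long() + padding_idx                                          # padding stays at padding_idx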
 