Update modeling_omnigenome.py
modeling_omnigenome.py (+4, -22)
@@ -291,10 +291,8 @@ class OmniGenomeEmbeddings(nn.Module):
     def create_position_ids_from_inputs_embeds(self, inputs_embeds):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
-
         Args:
             inputs_embeds: torch.Tensor
-
         Returns: torch.Tensor
         """
         input_shape = inputs_embeds.size()[:-1]
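Note on the helper above: a minimal standalone sketch of what `create_position_ids_from_inputs_embeds` computes, assuming the RoBERTa/ESM convention of numbering positions from `padding_idx + 1`. The method body beyond `input_shape` is not shown in this hunk, so this is an illustration, not the file's exact code:

import torch

def create_position_ids_from_inputs_embeds(inputs_embeds: torch.Tensor, padding_idx: int = 1) -> torch.Tensor:
    # Without input_ids there is no way to spot padding, so every position is
    # treated as a real token and numbered sequentially past the padding index.
    input_shape = inputs_embeds.size()[:-1]  # (batch_size, seq_len)
    seq_len = input_shape[1]
    position_ids = torch.arange(
        padding_idx + 1, seq_len + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
    )
    return position_ids.unsqueeze(0).expand(input_shape)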
@@ -578,9 +576,9 @@ class OmniGenomeSelfAttention(nn.Module):
         query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
         # Reorder dimensions to [batch_size, seq_len, num_heads, head_dim]
-        q = query_layer.transpose(1, 2).
-        k = key_layer.transpose(1, 2).
-        v = value_layer.transpose(1, 2).
+        q = query_layer.transpose(1, 2).to(torch.float16)
+        k = key_layer.transpose(1, 2).to(torch.float16)
+        v = value_layer.transpose(1, 2).to(torch.float16)
 
         # Compute attention with FlashAttention
         context_layer = self.flash_attn_func(
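This hunk carries the functional change of the commit: `flash_attn_func` only accepts half-precision tensors laid out as `[batch_size, seq_len, num_heads, head_dim]`, so q/k/v are transposed and now explicitly cast to `torch.float16` before the call. A hedged sketch of the surrounding pattern, assuming the flash-attn 2.x API; the cast back to the caller's dtype at the end is an assumption, not something shown in this diff:

import torch
from flash_attn import flash_attn_func  # assumes the flash-attn 2.x package is installed

def flash_self_attention(query_layer, key_layer, value_layer, dropout_p=0.0):
    # Inputs arrive as [batch, num_heads, seq_len, head_dim]; FlashAttention
    # expects [batch, seq_len, num_heads, head_dim] in fp16 or bf16.
    orig_dtype = query_layer.dtype
    q = query_layer.transpose(1, 2).to(torch.float16)
    k = key_layer.transpose(1, 2).to(torch.float16)
    v = value_layer.transpose(1, 2).to(torch.float16)

    # Output comes back as [batch, seq_len, num_heads, head_dim].
    context = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=False)

    # Restore the layout and dtype the rest of the encoder expects.
    return context.transpose(1, 2).to(orig_dtype)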
@@ -989,15 +987,12 @@ class OmniGenomePreTrainedModel(PreTrainedModel):
 
 
 OmniGenome_START_DOCSTRING = r"""
-
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`OmniGenomeConfig`]): Model configuration class with all the parameters of the
             model. Initializing with a config file does not load the weights associated with the model, only the
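As this docstring says, building the model from a config alone gives randomly initialized weights, while `from_pretrained` downloads and loads them. A brief hedged usage sketch; the checkpoint id is a placeholder and `trust_remote_code=True` is assumed because the class ships in a repository-hosted modeling file:

from transformers import AutoConfig, AutoModel

repo_id = "org/omnigenome-checkpoint"  # placeholder, not a real checkpoint name

# Architecture only, randomly initialized weights.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model_from_config = AutoModel.from_config(config, trust_remote_code=True)

# Architecture plus pretrained weights.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)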
@@ -1008,29 +1003,22 @@ OmniGenome_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `({0})`):
             Indices of input sequence tokens in the vocabulary.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
         position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.max_position_embeddings - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
-
         inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
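A short hedged example of supplying the inputs documented above (the checkpoint id is a placeholder; padding the batch is what produces the 0/1 `attention_mask`):

import torch
from transformers import AutoTokenizer, AutoModel

repo_id = "org/omnigenome-checkpoint"  # placeholder, not a real checkpoint name
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

batch = tokenizer(["AUGGCUACGUA", "AUGGC"], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, hidden_size)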
@@ -1053,12 +1041,10 @@ OmniGenome_INPUTS_DOCSTRING = r"""
 # Copied from transformers.models.esm.modeling_esm.EsmModel with Esm->OmniGenome
 class OmniGenomeModel(OmniGenomePreTrainedModel):
     """
-
     The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
     cross-attention is added between the self-attention layers, following the architecture described in [Attention is
     all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
     Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
-
     To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
     to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
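A hedged sketch of the decoder configuration described in this docstring; `is_decoder` and `add_cross_attention` are the standard transformers config fields the text refers to, the checkpoint id is a placeholder, and whether a released OmniGenome checkpoint is actually meant to run this way is not established by this diff:

from transformers import AutoConfig, AutoModel

repo_id = "org/omnigenome-checkpoint"  # placeholder
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.is_decoder = True            # causal self-attention
config.add_cross_attention = True   # adds cross-attention layers

decoder = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)
# forward() then expects encoder_hidden_states (and optionally encoder_attention_mask)
# from a separate encoder pass.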
@@ -1124,12 +1110,10 @@ class OmniGenomeModel(OmniGenomePreTrainedModel):
         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
@@ -1218,7 +1202,7 @@ class OmniGenomeModel(OmniGenomePreTrainedModel):
|
|
1218 |
inputs_embeds=inputs_embeds,
|
1219 |
past_key_values_length=past_key_values_length,
|
1220 |
)
|
1221 |
-
embedding_output = embedding_output.
|
1222 |
encoder_outputs = self.encoder(
|
1223 |
embedding_output,
|
1224 |
attention_mask=extended_attention_mask,
|
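This is the second hard-coded `float16` cast, keeping the embedding output in the same dtype as the FlashAttention path above. As a design note, a hedged alternative would cast to whatever dtype the encoder's parameters already use instead of forcing fp16; the helper below is an illustration, not part of this commit:

import torch

def cast_to_encoder_dtype(embedding_output: torch.Tensor, encoder: torch.nn.Module) -> torch.Tensor:
    # Match the encoder's parameter dtype (fp16, bf16, or fp32) so the model
    # also works when it is not running in half precision.
    target_dtype = next(encoder.parameters()).dtype
    return embedding_output if embedding_output.dtype == target_dtype else embedding_output.to(target_dtype)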
@@ -1893,10 +1877,8 @@ def create_position_ids_from_input_ids(
     """
     Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
     are ignored. This is modified from fairseq's `utils.make_positions`.
-
     Args:
         x: torch.Tensor x:
-
     Returns: torch.Tensor
     """
     # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
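For reference, a hedged sketch of the fairseq-style computation this docstring describes: padding slots keep `padding_idx`, real tokens are numbered cumulatively from `padding_idx + 1`. The cumsum-times-mask trick is the standard RoBERTa/ESM formulation; the exact body is not shown in this hunk:

import torch

def create_position_ids_from_input_ids(input_ids: torch.Tensor, padding_idx: int, past_key_values_length: int = 0) -> torch.Tensor:
    # 1 for real tokens, 0 for padding.
    mask = input_ids.ne(padding_idx).int()
    # Cumulative count of real tokens gives 1-based positions; multiplying by the
    # mask zeroes padding slots out again.
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    # Shift so the first real token sits at padding_idx + 1 and padding stays at padding_idx.
    return incremental_indices.long() + padding_idx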