Upload modeling_eurobert.py
#12 by hgissbkh · opened

modeling_eurobert.py (CHANGED, +86 -7)
@@ -30,7 +30,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
@@ -708,7 +708,7 @@ class EuroBertModel(EuroBertPreTrainedModel):
 
 
 @add_start_docstrings(
-    "The EuroBert Model with a
+    "The EuroBert Model with a decoder head on top that is used for masked language modeling.",
     EUROBERT_START_DOCSTRING,
 )
 class EuroBertForMaskedLM(EuroBertPreTrainedModel):
@@ -766,7 +766,7 @@ class EuroBertForMaskedLM(EuroBertPreTrainedModel):
 
 
 @add_start_docstrings(
-    "The EuroBert Model with a
+    "The EuroBert Model with a sequence classification head on top that performs pooling.",
     EUROBERT_START_DOCSTRING,
 )
 class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
@@ -778,7 +778,7 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
         self.model = EuroBertModel(config)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.GELU()
-        self.
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
         self.post_init()
 
     @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
@@ -830,12 +830,12 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
 
             pooled_output = self.dense(pooled_output)
             pooled_output = self.activation(pooled_output)
-            logits = self.
+            logits = self.classifier(pooled_output)
 
         elif self.clf_pooling == "late":
             x = self.dense(last_hidden_state)
             x = self.activation(x)
-            logits = self.
+            logits = self.classifier(x)
             if attention_mask is None:
                 logits = logits.mean(dim=1)
             else:
@@ -878,4 +878,83 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
         )
 
 
-
+@add_start_docstrings(
+    """
+    The EuroBert Model with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks."
+    """,
+    EUROBERT_START_DOCSTRING,
+)
+class EuroBertForTokenClassification(EuroBertPreTrainedModel):
+    def __init__(self, config: EuroBertConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = EuroBertModel(config)
+
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "EuroBertPreTrainedModel",
+    "EuroBertModel",
+    "EuroBertForMaskedLM",
+    "EuroBertForSequenceClassification",
+    "EuroBertForTokenClassification",
+]
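For reference, a minimal usage sketch of the EuroBertForTokenClassification class this diff adds. It is a sketch under assumptions, not part of the PR: it presumes the repo's config auto_map is updated to route AutoModelForTokenClassification to the new class, and the checkpoint name and num_labels=9 below are illustrative placeholders.

# Sketch only: assumes auto_map exposes EuroBertForTokenClassification via
# AutoModelForTokenClassification; checkpoint name and num_labels are placeholders.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

repo = "EuroBERT/EuroBERT-210m"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(
    repo, num_labels=9, trust_remote_code=True
)

inputs = tokenizer("EuroBERT is a multilingual encoder.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One logit vector per (sub)token; for training, labels are per-token ids of
# shape (batch_size, sequence_length), consumed by the CrossEntropyLoss in forward.
predicted_label_ids = outputs.logits.argmax(dim=-1)
print(predicted_label_ids.shape)  # (1, sequence_length)

The new head follows the standard transformers token-classification pattern: a single nn.Linear over the last hidden states, with a per-token CrossEntropyLoss computed on the flattened logits.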