Upload modeling_eurobert.py

#12
by hgissbkh - opened
Files changed (1)
  1. modeling_eurobert.py +86 -7
modeling_eurobert.py CHANGED
@@ -30,7 +30,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
@@ -708,7 +708,7 @@ class EuroBertModel(EuroBertPreTrainedModel):
 
 
 @add_start_docstrings(
-    "The EuroBert Model with a sequence classification head on top that performs pooling.",
+    "The EuroBert Model with a decoder head on top that is used for masked language modeling.",
     EUROBERT_START_DOCSTRING,
 )
 class EuroBertForMaskedLM(EuroBertPreTrainedModel):
@@ -766,7 +766,7 @@ class EuroBertForMaskedLM(EuroBertPreTrainedModel):
 
 
 @add_start_docstrings(
-    "The EuroBert Model with a decoder head on top that is used for masked language modeling.",
+    "The EuroBert Model with a sequence classification head on top that performs pooling.",
     EUROBERT_START_DOCSTRING,
 )
 class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
@@ -778,7 +778,7 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
         self.model = EuroBertModel(config)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.GELU()
-        self.out_proj = nn.Linear(config.hidden_size, self.num_labels)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
         self.post_init()
 
     @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
@@ -830,12 +830,12 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
 
             pooled_output = self.dense(pooled_output)
             pooled_output = self.activation(pooled_output)
-            logits = self.out_proj(pooled_output)
+            logits = self.classifier(pooled_output)
 
         elif self.clf_pooling == "late":
             x = self.dense(last_hidden_state)
             x = self.activation(x)
-            logits = self.out_proj(x)
+            logits = self.classifier(x)
             if attention_mask is None:
                 logits = logits.mean(dim=1)
             else:
@@ -878,4 +878,83 @@ class EuroBertForSequenceClassification(EuroBertPreTrainedModel):
         )
 
 
-__all__ = ["EuroBertPreTrainedModel", "EuroBertModel", "EuroBertForMaskedLM", "EuroBertForSequenceClassification"]
+@add_start_docstrings(
+    """
+    The EuroBert Model with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    EUROBERT_START_DOCSTRING,
+)
+class EuroBertForTokenClassification(EuroBertPreTrainedModel):
+    def __init__(self, config: EuroBertConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = EuroBertModel(config)
+
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. Tokens with the label `-100` are ignored by the default
+            `CrossEntropyLoss` when computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "EuroBertPreTrainedModel",
+    "EuroBertModel",
+    "EuroBertForMaskedLM",
+    "EuroBertForSequenceClassification",
+    "EuroBertForTokenClassification",
+]
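
Not part of the diff above, but for reviewers who want to exercise the new head, here is a minimal usage sketch. It assumes the repository's `auto_map` in `config.json` is also extended so that `AutoModelForTokenClassification` resolves to `EuroBertForTokenClassification`; `EuroBERT/EuroBERT-210m` is used purely as a placeholder checkpoint id and `num_labels=9` is an arbitrary example label count (the classifier weights are freshly initialized).

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

model_id = "EuroBERT/EuroBERT-210m"  # placeholder: any checkpoint shipping this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(
    model_id,
    trust_remote_code=True,  # loads EuroBertForTokenClassification from the repo, assuming auto_map points to it
    num_labels=9,            # example label count; the token-classification head starts untrained
)

inputs = tokenizer("EuroBERT now supports token classification.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# logits has shape (batch_size, sequence_length, num_labels); argmax gives one label id per token
predicted_ids = outputs.logits.argmax(dim=-1)
print(predicted_ids)

Since the loss in the new class is a plain `CrossEntropyLoss`, padding and special tokens can be excluded during fine-tuning by setting their labels to `-100`.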