IlyasMoutawwakil (HF Staff) committed
Commit 177a59e · verified · 1 Parent(s): 626f152

Update modeling_gpt2.py

Files changed (1)
  1. modeling_gpt2.py +101 -2
modeling_gpt2.py CHANGED
@@ -34,7 +34,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from transformers.modeling_utils import PreTrainedModel, SequenceSummary
+from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from transformers.utils import (
     ModelOutput,
@@ -1106,6 +1106,105 @@ class GPT2CustomLMHeadModel(GPT2PreTrainedModel, GenerationMixin):
         )
 
 
+class GPT2SequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`GPT2Config`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: GPT2Config):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 @add_start_docstrings(
     """
     The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
@@ -1123,7 +1222,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         config.num_labels = 1
         self.transformer = GPT2Model(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.multiple_choice_head = SequenceSummary(config)
+        self.multiple_choice_head = GPT2SequenceSummary(config)
 
         # Model parallel
         self.model_parallel = False
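
For reference, the change drops the `SequenceSummary` import from `transformers.modeling_utils` and instead defines an equivalent `GPT2SequenceSummary` module directly in `modeling_gpt2.py`, which `GPT2DoubleHeadsModel` now uses as its multiple-choice head. Below is a minimal, illustrative sketch of how the new module could be exercised on its own; the standalone import path, config values, and tensor shapes are assumptions for demonstration and are not part of the commit.

import torch
from transformers import GPT2Config

from modeling_gpt2 import GPT2SequenceSummary  # module added by this commit (import path assumed)

# GPT2Config already exposes the summary_* attributes that GPT2SequenceSummary reads.
config = GPT2Config(
    summary_type="cls_index",      # gather the hidden state at an explicit token position
    summary_use_proj=True,         # add a linear projection after extraction
    summary_proj_to_labels=True,   # project to config.num_labels classes
    summary_first_dropout=0.1,     # dropout applied before the projection
)
config.num_labels = 1              # one score per sequence, as GPT2DoubleHeadsModel sets in __init__

head = GPT2SequenceSummary(config)

batch_size, seq_len = 2, 16
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)  # e.g. last-layer hidden states
cls_index = torch.tensor([5, 9])   # position of the classification token in each sequence

summary = head(hidden_states, cls_index=cls_index)
print(summary.shape)               # torch.Size([2, 1]): one projected score per sequence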