Update modeling_gpt2.py
modeling_gpt2.py  CHANGED  (+101, -2)
@@ -34,7 +34,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from transformers.utils import (
     ModelOutput,
@@ -1106,6 +1106,105 @@ class GPT2CustomLMHeadModel(GPT2PreTrainedModel, GenerationMixin):
     )


+class GPT2SequenceSummary(nn.Module):
+    r"""
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`GPT2Config`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented now, use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+              another string or `None` will add no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+    """
+
+    def __init__(self, config: GPT2Config):
+        super().__init__()
+
+        self.summary_type = getattr(config, "summary_type", "last")
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+        self.first_dropout = nn.Identity()
+        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+            self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+        self.last_dropout = nn.Identity()
+        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+            self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+    ) -> torch.FloatTensor:
+        """
+        Compute a single vector summary of a sequence hidden states.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+                The hidden states of the last layer.
+            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+        Returns:
+            `torch.FloatTensor`: The summary of the sequence hidden states.
+        """
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == "cls_index":
+            if cls_index is None:
+                cls_index = torch.full_like(
+                    hidden_states[..., :1, :],
+                    hidden_states.shape[-2] - 1,
+                    dtype=torch.long,
+                )
+            else:
+                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        output = self.first_dropout(output)
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.last_dropout(output)
+
+        return output
+
+
 @add_start_docstrings(
     """
     The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
@@ -1123,7 +1222,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         config.num_labels = 1
         self.transformer = GPT2Model(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.multiple_choice_head = SequenceSummary(config)
+        self.multiple_choice_head = GPT2SequenceSummary(config)

         # Model parallel
         self.model_parallel = False
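For context, a minimal sketch of how the GPT2SequenceSummary head added above can be exercised on its own. It assumes this modeling file is importable as modeling_gpt2 (a hypothetical module name) and uses the summary_* fields that GPT2Config already defines; the shapes mirror the multiple-choice path of GPT2DoubleHeadsModel, where the head's output is squeezed down to one logit per choice.

# Minimal sketch (not part of the commit): exercise GPT2SequenceSummary in isolation.
import torch
from transformers import GPT2Config

from modeling_gpt2 import GPT2SequenceSummary  # hypothetical import path for this modeling file

config = GPT2Config(
    summary_type="cls_index",     # take the hidden state at a supplied token index
    summary_use_proj=True,        # add a linear projection after extraction
    summary_proj_to_labels=True,  # project to config.num_labels outputs
    summary_first_dropout=0.1,
)
config.num_labels = 1             # GPT2DoubleHeadsModel forces a single multiple-choice logit

head = GPT2SequenceSummary(config).eval()

batch_size, num_choices, seq_len = 2, 4, 16
hidden_states = torch.randn(batch_size, num_choices, seq_len, config.hidden_size)
mc_token_ids = torch.full((batch_size, num_choices), seq_len - 1)  # position of the classification token per choice

with torch.no_grad():
    mc_logits = head(hidden_states, cls_index=mc_token_ids).squeeze(-1)
print(mc_logits.shape)  # torch.Size([2, 4]): one logit per choice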