from torch import nn
from transformers import GPT2LMHeadModel as GPT2LMHeadModelBase
from transformers.models.gpt2.modeling_gpt2 import GPT2Block as GPT2BlockBase
class GPT2Block(GPT2BlockBase):
    def forward(self, x, layer_past=None,
                attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None,
                output_attentions=None):
        # Self-attention over the layer-normalized input (pre-LN), as in the stock GPT-2 block.
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache)
        a = output_attn[0]  # attention output; output_attn[1:] carries the key/value cache
        # The residual adds to the un-normalized input: ln_1 feeds only the attention,
        # matching the pretrained GPT-2 block.
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        outputs = (x,) + output_attn[1:]
        return outputs  # hidden states, present key/values
class GPT2LMHeadModel(GPT2LMHeadModelBase):
    def __init__(self, config):
        super().__init__(config)
        # Swap every transformer block for the subclass above. Parameter names are
        # unchanged, so pretrained checkpoints still load into these blocks.
        self.transformer.h = nn.ModuleList(
            [GPT2Block(config, layer_idx) for layer_idx in range(config.n_layer)])
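
A minimal usage sketch, not part of the original listing: it assumes the public "gpt2" checkpoint, the standard GPT2Tokenizer, and a transformers version whose GPT2Block interface matches the override above; the checkpoint name and prompt are placeholders.

# Usage sketch (assumption: public "gpt2" checkpoint; adjust the name and prompt as needed).
import torch
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")  # weights load into the custom blocks
model.eval()

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # (batch, sequence_length, vocab_size)

Because from_pretrained is called on the subclass, the overridden GPT2Block.forward is exercised on every layer while the checkpoint weights load exactly as they would for the base model.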