import copy
import os
import torch
import torch.nn as nn

from contextlib import nullcontext
from torch import Tensor
from torch.distributions import Categorical
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoModelForMaskedLM, ElectraModel
from transformers.modeling_outputs import MaskedLMOutput, ModelOutput
from transformers.models.bert import BertForMaskedLM

from logger_config import logger
from config import Arguments
from utils import slice_batch_dict


@dataclass
class ReplaceLMOutput(ModelOutput):
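    """Outputs of ReplaceLM.forward().

    loss: combined training loss (decoder MLM + encoder MLM + weighted generator MLM).
    encoder_mlm_loss / decoder_mlm_loss: detached MLM losses of the encoder and the shallow decoder.
    g_mlm_loss: detached MLM loss of the replacement-token generator (zero when it is frozen).
    replace_ratio: fraction of non-padding tokens replaced by the generator, for logging.
    """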
    loss: Optional[Tensor] = None
    encoder_mlm_loss: Optional[Tensor] = None
    decoder_mlm_loss: Optional[Tensor] = None
    g_mlm_loss: Optional[Tensor] = None
    replace_ratio: Optional[Tensor] = None


class ReplaceLM(nn.Module):
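    """Replaced-token masked language model with a deep encoder and a shallow decoder.

    A separate generator MLM samples replacement tokens at the masked positions of the
    encoder- and decoder-side inputs. The BERT encoder is trained to recover the original
    tokens, and a shallow decoder (copies of the encoder's last layers, conditioned on the
    encoder's final [CLS] hidden state) is trained on the same objective. See forward()
    for the combined loss.
    """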
    def __init__(self, args: Arguments,
                 bert: BertForMaskedLM):
        super(ReplaceLM, self).__init__()
        self.encoder = bert
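        # Shallow decoder: deep copies of the last `rlm_decoder_layers` transformer layers
        # of the BERT encoder.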
        self.decoder = copy.deepcopy(self.encoder.bert.encoder.layer[-args.rlm_decoder_layers:])
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

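        # Generator: a separate masked LM that proposes replacement tokens at the masked
        # positions (ELECTRA-style). The ElectraModel annotation is only a hint; the actual
        # class depends on the rlm_generator_model_name checkpoint.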
        self.generator: ElectraModel = AutoModelForMaskedLM.from_pretrained(args.rlm_generator_model_name)

        if args.rlm_freeze_generator:
            self.generator.eval()
            self.generator.requires_grad_(False)

        self.args = args

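        # Imported here rather than at module level, presumably to avoid a circular import
        # between the model and its trainer.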
        from trainers.rlm_trainer import ReplaceLMTrainer
        self.trainer: Optional[ReplaceLMTrainer] = None

    def forward(self, model_input: Dict[str, torch.Tensor]) -> ReplaceLMOutput:
        enc_prefix, dec_prefix = 'enc_', 'dec_'
        encoder_inputs = slice_batch_dict(model_input, enc_prefix)
        decoder_inputs = slice_batch_dict(model_input, dec_prefix)
        labels = model_input['labels']

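        # Let the generator sample replacement tokens at the masked positions of the encoder
        # and decoder inputs. The encoder-side call keeps the generator's MLM loss so the
        # generator can be trained jointly; the decoder-side call is gradient-free (no_grad=True).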
        enc_sampled_input_ids, g_mlm_loss = self._replace_tokens(encoder_inputs)
        if self.args.rlm_freeze_generator:
            g_mlm_loss = torch.tensor(0, dtype=torch.float, device=g_mlm_loss.device)
        dec_sampled_input_ids, _ = self._replace_tokens(decoder_inputs, no_grad=True)

        encoder_inputs['input_ids'] = enc_sampled_input_ids
        decoder_inputs['input_ids'] = dec_sampled_input_ids

        encoder_inputs['labels'] = labels
        decoder_inputs['labels'] = labels

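        # Logging statistic: the fraction of non-padding tokens that the generator actually
        # changed (masked positions where the sampled id differs from the original label).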
        is_replaced = (encoder_inputs['input_ids'] != labels) & (labels >= 0)
        replace_cnt = is_replaced.long().sum().item()
        total_cnt = (encoder_inputs['attention_mask'] == 1).long().sum().item()
        replace_ratio = torch.tensor(replace_cnt / total_cnt, device=g_mlm_loss.device)

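        # Encoder pass over the corrupted inputs; because labels are included,
        # encoder_out.loss is the encoder's MLM loss for recovering the original tokens.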
        encoder_out: MaskedLMOutput = self.encoder(
            **encoder_inputs,
            output_hidden_states=True,
            return_dict=True)

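        # Keep only the final-layer [CLS] hidden state as the representation handed to the decoder.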
        cls_hidden = encoder_out.hidden_states[-1][:, :1]

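        # Decoder input: the encoder's [CLS] vector followed by fresh embeddings of the
        # decoder-side (corrupted) token ids; the decoder-side [CLS] embedding is dropped,
        # presumably to push sequence-level information into the encoder's [CLS] vector.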
        dec_inputs_embeds = self.encoder.bert.embeddings(decoder_inputs['input_ids'])
        hiddens = torch.cat([cls_hidden, dec_inputs_embeds[:, 1:]], dim=1)

        attention_mask = self.encoder.get_extended_attention_mask(
            encoder_inputs['attention_mask'],
            encoder_inputs['attention_mask'].shape,
            encoder_inputs['attention_mask'].device
        )

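        # Run the shallow decoder; each BertLayer returns a tuple whose first element is the
        # updated hidden states.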
        for layer in self.decoder:
            layer_out = layer(hiddens, attention_mask)
            hiddens = layer_out[0]

        decoder_mlm_loss = self.mlm_loss(hiddens, labels)

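        # Total loss: decoder MLM + encoder MLM + generator MLM scaled by
        # rlm_generator_mlm_weight (g_mlm_loss is zero when the generator is frozen).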
        loss = decoder_mlm_loss + encoder_out.loss + g_mlm_loss * self.args.rlm_generator_mlm_weight

        return ReplaceLMOutput(loss=loss,
                               encoder_mlm_loss=encoder_out.loss.detach(),
                               decoder_mlm_loss=decoder_mlm_loss.detach(),
                               g_mlm_loss=g_mlm_loss.detach(),
                               replace_ratio=replace_ratio)

    def _replace_tokens(self, batch_dict: Dict[str, torch.Tensor],
                        no_grad: bool = False) -> Tuple:
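        """Corrupt the masked positions of a batch with tokens sampled from the generator.

        The generator is run on the batch (whose labels mark the masked positions); a token
        is then sampled from its output distribution at every position whose label is >= 0,
        which assumes the collator uses the usual -100 ignore index for unmasked positions.
        All other positions keep their original input ids. Returns the corrupted input ids
        and the generator's MLM loss.
        """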
        with torch.no_grad() if self.args.rlm_freeze_generator or no_grad else nullcontext():
            outputs: MaskedLMOutput = self.generator(
                **batch_dict,
                return_dict=True)

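        # Sampling is non-differentiable and wrapped in no_grad; gradients can reach the
        # generator only through its MLM loss computed above (when it is not frozen).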
        with torch.no_grad():
            sampled_input_ids = Categorical(logits=outputs.logits).sample()
            is_mask = (batch_dict['labels'] >= 0).long()
            sampled_input_ids = batch_dict['input_ids'] * (1 - is_mask) + sampled_input_ids * is_mask

        return sampled_input_ids.long(), outputs.loss

    def mlm_loss(self, hiddens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
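        # Score the decoder hidden states with the encoder's MLM prediction head; with
        # CrossEntropyLoss's default ignore_index of -100, unlabeled positions do not
        # contribute to the loss.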
        pred_scores = self.encoder.cls(hiddens)
        mlm_loss = self.cross_entropy(
            pred_scores.view(-1, self.encoder.config.vocab_size),
            labels.view(-1))
        return mlm_loss

    @classmethod
    def from_pretrained(cls, all_args: Arguments,
                        model_name_or_path: str, *args, **kwargs):
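        """Load the Hugging Face MLM backbone from model_name_or_path, wrap it in a
        ReplaceLM, and restore the extra decoder-layer weights from decoder.pt if the
        checkpoint directory contains one (as written by save_pretrained())."""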
        hf_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, *args, **kwargs)
        model = cls(all_args, hf_model)
        decoder_save_path = os.path.join(model_name_or_path, 'decoder.pt')
        if os.path.exists(decoder_save_path):
            logger.info('loading extra weights from local files')
            state_dict = torch.load(decoder_save_path, map_location="cpu")
            model.decoder.load_state_dict(state_dict)
        return model

    def save_pretrained(self, output_dir: str):
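        # The backbone is saved in the standard Hugging Face format; the extra decoder
        # layers are stored separately as decoder.pt so that from_pretrained() can restore them.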
        self.encoder.save_pretrained(output_dir)
        torch.save(self.decoder.state_dict(), os.path.join(output_dir, 'decoder.pt'))