import json
import os
import re
import copy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from tqdm import tqdm, trange
from datasets import Dataset, load_metric
from sklearn.metrics import f1_score

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    ElectraModel,
    get_linear_schedule_with_warmup,
)

from config import entity_property_pair


class Classifier(nn.Module):
    def __init__(self, base_model, num_labels, device, tokenizer):
        super(Classifier, self).__init__()
        self.num_labels = num_labels
        self.device = device

        # KcELECTRA encoder; the checkpoint is fixed here, so base_model is
        # kept only for interface compatibility. Embeddings are resized
        # because extra special tokens were added to the tokenizer.
        self.electra = ElectraModel.from_pretrained('beomi/KcELECTRA-base')
        self.electra.resize_token_embeddings(len(tokenizer))

        # fc1 projects the [CLS] representation, fc2 the averaged entity
        # representation; fc3 maps their concatenation to the label logits.
        self.fc1 = nn.Linear(self.electra.config.hidden_size, 256)
        self.fc2 = nn.Linear(self.electra.config.hidden_size, 512)
        self.fc3 = nn.Linear(256 + 512, self.num_labels)

        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, entity_mask):
        outputs = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state

        # Average the token representations covered by the entity mask,
        # then project them to 512 dimensions.
        masked_last_hidden = self.entity_average(last_hidden_state, entity_mask)
        masked_last_hidden = self.fc2(masked_last_hidden)

        # Project all token representations, take the [CLS] vector, and
        # concatenate it with the entity representation.
        last_hidden_state = self.fc1(last_hidden_state)
        entity_outputs = torch.cat([last_hidden_state[:, 0, :], masked_last_hidden], dim=-1)

        outputs = torch.tanh(entity_outputs)
        outputs = self.dropout(outputs)
        outputs = self.fc3(outputs)

        return outputs

    @staticmethod
    def entity_average(hidden_output, e_mask):
        # hidden_output: (batch, seq_len, hidden); e_mask: (batch, seq_len)
        # with 1s over the entity tokens. Returns the mean of the entity
        # token vectors, shape (batch, hidden).
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # (batch, 1, seq_len)
        # Clamp to at least 1 to avoid division by zero when a mask is empty.
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1).clamp(min=1)

        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
        avg_vector = sum_vector / length_tensor.float()
        return avg_vector
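

# Minimal usage sketch: it assumes the tokenizer comes from the same
# 'beomi/KcELECTRA-base' checkpoint and feeds random dummy tensors only to
# illustrate the expected input shapes. entity_mask marks the entity's token
# span with 1s; the span chosen below is arbitrary.
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('beomi/KcELECTRA-base')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Classifier(base_model='beomi/KcELECTRA-base', num_labels=2,
                       device=device, tokenizer=tokenizer).to(device)

    batch_size, seq_len = 4, 32
    input_ids = torch.randint(0, len(tokenizer), (batch_size, seq_len)).to(device)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long).to(device)
    entity_mask = torch.zeros(batch_size, seq_len, dtype=torch.long).to(device)
    entity_mask[:, 1:4] = 1  # pretend tokens 1-3 are the entity span

    with torch.no_grad():
        logits = model(input_ids, attention_mask, entity_mask)
    print(logits.shape)  # expected: torch.Size([4, 2])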