# ABSA_APT/models.py
import copy
import json
import os
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm, trange
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    ElectraModel,
    get_linear_schedule_with_warmup,
)
from datasets import Dataset, load_metric
from sklearn.metrics import f1_score

# from utils import evaluation, evaluation_f1
from config import entity_property_pair

class Classifier(nn.Module):
    """ELECTRA-based classifier that combines the [CLS] representation
    with an averaged entity-span representation for classification."""

    def __init__(self, base_model, num_labels, device, tokenizer):
        super(Classifier, self).__init__()
        self.num_labels = num_labels
        self.device = device
        # Load the ELECTRA encoder from the given checkpoint; `num_labels`
        # belongs to the classification head below, not the encoder.
        self.electra = ElectraModel.from_pretrained(base_model)
        # Account for any special tokens added to the tokenizer.
        self.electra.resize_token_embeddings(len(tokenizer))
        self.fc1 = nn.Linear(self.electra.config.hidden_size, 256)  # projects the [CLS] token
        self.fc2 = nn.Linear(self.electra.config.hidden_size, 512)  # projects the entity average
        self.fc3 = nn.Linear(256 + 512, num_labels)                 # final classification head
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, entity_mask):
        outputs = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state        # [b, seq_len, hidden]
        # Average the hidden states over the entity span, then project.
        masked_last_hidden = self.entity_average(last_hidden_state, entity_mask)
        masked_last_hidden = self.fc2(masked_last_hidden)     # [b, 512]
        cls_hidden = self.fc1(last_hidden_state[:, 0, :])     # [CLS] token -> [b, 256]
        # Concatenate the [CLS] and entity-span representations.
        entity_outputs = torch.cat([cls_hidden, masked_last_hidden], dim=-1)
        outputs = torch.tanh(entity_outputs)
        outputs = self.dropout(outputs)
        outputs = self.fc3(outputs)                           # [b, num_labels]
        return outputs

    @staticmethod
    def entity_average(hidden_output, e_mask):
        """Average the hidden states over the positions flagged by e_mask."""
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # [b, 1, seq_len]
        # Span length per example; clamp to 1 to avoid division by zero
        # when a batch item has an empty entity mask.
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1).clamp(min=1)  # [b, 1]
        # [b, 1, seq_len] x [b, seq_len, dim] = [b, 1, dim] -> [b, dim]
        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
        avg_vector = sum_vector.float() / length_tensor.float()  # broadcasts over dim
        return avg_vector
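

# A minimal usage sketch, not part of the original module: it assumes the
# `beomi/KcELECTRA-base` checkpoint referenced above and a toy entity mask
# marking which token positions belong to the aspect term. The sentence,
# mask positions, and sequence length are illustrative assumptions only.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")
    model = Classifier("beomi/KcELECTRA-base", num_labels=2,
                       device=device, tokenizer=tokenizer).to(device)

    # "배송이 정말 빨라요" ("The delivery is really fast").
    enc = tokenizer("배송이 정말 빨라요", return_tensors="pt",
                    padding="max_length", max_length=16, truncation=True)
    # Flag tokens 1-2 as the (hypothetical) entity span.
    entity_mask = torch.zeros_like(enc["input_ids"])
    entity_mask[:, 1:3] = 1

    with torch.no_grad():
        logits = model(enc["input_ids"].to(device),
                       enc["attention_mask"].to(device),
                       entity_mask.to(device))
    print(logits.shape)  # torch.Size([1, 2])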