import copy
import json
import os
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from datasets import Dataset, load_metric
from sklearn.metrics import f1_score
from tqdm import tqdm, trange
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    ElectraModel,
    get_linear_schedule_with_warmup,
)

# from utils import evaluation, evaluation_f1
from config import entity_property_pair



class Classifier(nn.Module):
    """Classification head over a KcELECTRA encoder: concatenates the [CLS]
    representation with an averaged entity-span representation and maps the
    result to 2 output classes."""

    def __init__(self, base_model, num_labels, device, tokenizer):
        super(Classifier, self).__init__()
        self.num_labels = num_labels
        self.device = device

        # Backbone encoder; the embedding matrix is resized in case special
        # tokens were added to the tokenizer.
        self.electra = ElectraModel.from_pretrained('beomi/KcELECTRA-base')
        self.electra.resize_token_embeddings(len(tokenizer))

        # fc1 projects the [CLS] token, fc2 projects the averaged entity span,
        # and fc3 maps their concatenation to the 2 output classes.
        self.fc1 = nn.Linear(self.electra.config.hidden_size, 256)
        self.fc2 = nn.Linear(self.electra.config.hidden_size, 512)
        self.fc3 = nn.Linear(256 + 512, 2)

        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, entity_mask):
        outputs = self.electra(input_ids=input_ids,
                               attention_mask=attention_mask,
                               output_hidden_states=True)
        last_hidden_state = outputs.last_hidden_state  # [b, seq_len, hidden]

        # Average the hidden states over the entity span, then project to 512 dims.
        masked_last_hidden = self.entity_average(last_hidden_state, entity_mask)
        masked_last_hidden = self.fc2(masked_last_hidden)  # [b, 512]

        # Project every token to 256 dims and keep only the [CLS] position.
        last_hidden_state = self.fc1(last_hidden_state)  # [b, seq_len, 256]
        entity_outputs = torch.cat([last_hidden_state[:, 0, :], masked_last_hidden], dim=-1)  # [b, 768]

        outputs = torch.tanh(entity_outputs)
        outputs = self.dropout(outputs)
        outputs = self.fc3(outputs)  # [b, 2] logits

        return outputs
    
    @staticmethod
    def entity_average(hidden_output, e_mask):
        """Mean-pool hidden states over the token positions marked by e_mask."""
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # [b, 1, seq_len]
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1)  # [b, 1]

        # [b, 1, seq_len] x [b, seq_len, dim] = [b, 1, dim] -> [b, dim]
        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
        avg_vector = sum_vector.float() / length_tensor.float()  # broadcast over dim
        return avg_vector