File size: 1,558 Bytes
4736ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch.nn as nn
from torch import cat
from transformers import DistilBertModel

class JobFakeModel(nn.Module):
    """Score an example from three text fields, each encoded by its own transformer.

    Each of the three inputs given to :meth:`forward` is run through a separate
    encoder (``head1``, ``head2``, ``head3``; independently loaded weights).
    The outputs are then processed in three steps:

    * each encoder's last hidden state is mean-pooled over the sequence axis;
    * the three pooled vectors are concatenated;
    * the result is passed through a small MLP ending in a single output unit.

    No sigmoid is applied here, so the output is an unnormalized score.
    Presumably it is a fake/real logit; confirm against the training loss.

    Args:
        base_model: Name of the encoder family to load. Only ``"distilbert"``
            is currently supported.
        freeze_base: If truthy, gradients are disabled for every encoder
            parameter so that only ``fc`` is trained.

    Raises:
        ValueError: If ``base_model`` is not a supported name.
    """

    # Supported ``base_model`` names -> pretrained checkpoint loaded for each head.
    _PRETRAINED = {"distilbert": "distilbert-base-uncased"}

    def __init__(self, base_model, freeze_base):
        super().__init__()
        self.base_model = base_model
        # Input width: one pooled vector per encoder. 768 matches the hidden
        # size of distilbert-base-uncased. fc is built before the heads, as in
        # the original code, so RNG consumption and parameter order are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(768 * 3, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        self.head1, self.head2, self.head3 = self._create_base_model()

        if freeze_base:
            for head in (self.head1, self.head2, self.head3):
                head.requires_grad_(False)

    def forward(self, x, y, z):
        """Encode three tokenized fields and score them jointly.

        Args:
            x, y, z: Keyword-argument mappings (e.g. tokenizer output) that are
                unpacked into ``head1``, ``head2`` and ``head3`` respectively.
                If a mapping contains ``"attention_mask"``, masked-out
                (padding) positions are excluded from the mean pooling.

        Returns:
            Output of ``fc`` on the concatenated pooled vectors; its last
            dimension is 1.
        """
        pooled = [
            self._encode(head, inputs)
            for head, inputs in ((self.head1, x), (self.head2, y), (self.head3, z))
        ]
        return self.fc(cat(pooled, dim=1))

    def _encode(self, head, inputs):
        """Run one encoder and mean-pool its last hidden state over dim 1."""
        hidden = head(**inputs).last_hidden_state
        return self._masked_mean(hidden, inputs.get("attention_mask"))

    @staticmethod
    def _masked_mean(hidden, attention_mask=None):
        """Average ``hidden`` over dim 1, counting only positions where the mask is non-zero.

        Args:
            hidden: Tensor assumed to be ``(batch, seq_len, dim)``. This matches
                the Hugging Face ``last_hidden_state`` layout.
            attention_mask: Optional ``(batch, seq_len)`` tensor of 0/1 values.
                ``None`` falls back to a plain mean over dim 1.

        Returns:
            ``(batch, dim)`` tensor. Rows whose mask is all zeros come out as
            zeros instead of NaN.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # Clamp so an all-padding row divides by a tiny number, not zero.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def _create_base_model(self):
        """Load three independent copies of the pretrained encoder.

        Returns:
            A ``(model1, model2, model3)`` tuple of ``DistilBertModel`` instances.
            Each copy is loaded separately, so the heads do not share weights.

        Raises:
            ValueError: If ``self.base_model`` is not a supported name.
        """
        checkpoint = (
            self._PRETRAINED.get(self.base_model)
            if isinstance(self.base_model, str)
            else None
        )
        if checkpoint is None:
            raise ValueError("Model not supported")
        return tuple(DistilBertModel.from_pretrained(checkpoint) for _ in range(3))