import torch.nn as nn
from torch import cat
from transformers import DistilBertModel


class JobFakeModel(nn.Module):
    def __init__(self, base_model, freeze_base):
        super(JobFakeModel, self).__init__()
        self.base_model = base_model
        # Classification head: the three pooled 768-dim encoder outputs are
        # concatenated (768 * 3) and reduced to a single logit.
        self.fc = nn.Sequential(
            nn.Linear(768 * 3, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        # One pretrained encoder per input field.
        self.head1, self.head2, self.head3 = self._create_base_model()
        # Optionally freeze the encoders so only the classification head is trained.
        if freeze_base:
            for head in (self.head1, self.head2, self.head3):
                for param in head.parameters():
                    param.requires_grad = False

    def forward(self, x, y, z):
        # Each input is a dict of tokenizer outputs (input_ids, attention_mask).
        # Mean-pool the last hidden state over the sequence dimension.
        x = self.head1(**x).last_hidden_state.mean(dim=1)
        y = self.head2(**y).last_hidden_state.mean(dim=1)
        z = self.head3(**z).last_hidden_state.mean(dim=1)
        # Concatenate the three pooled representations and classify.
        output = cat([x, y, z], dim=1)
        output = self.fc(output)
        return output

    def _create_base_model(self):
        if self.base_model == "distilbert":
            model1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
            model2 = DistilBertModel.from_pretrained("distilbert-base-uncased")
            model3 = DistilBertModel.from_pretrained("distilbert-base-uncased")
            return model1, model2, model3
        else:
            raise ValueError("Model not supported")
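For reference, a minimal usage sketch. It assumes the three inputs are separate text fields of a job posting, each tokenized independently with the matching DistilBERT tokenizer; the field names and the "fake posting" interpretation of the single output logit are illustrative assumptions, not confirmed by the class itself.

# Minimal usage sketch -- field names and label semantics are assumptions.
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = JobFakeModel(base_model="distilbert", freeze_base=True)
model.eval()

posting = {
    "title": "Remote data entry clerk",
    "description": "Earn money from home, no experience required.",
    "requirements": "Must be available immediately.",
}

# Each field becomes one dict of input_ids / attention_mask tensors,
# which forward() unpacks into its DistilBERT encoder.
x = tokenizer(posting["title"], return_tensors="pt", truncation=True, padding=True)
y = tokenizer(posting["description"], return_tensors="pt", truncation=True, padding=True)
z = tokenizer(posting["requirements"], return_tensors="pt", truncation=True, padding=True)

logit = model(x, y, z)            # shape: (1, 1)
prob = logit.sigmoid().item()     # assuming BCE-with-logits training, this reads as P(fake)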