File size: 760 Bytes
5683562
 
 
 
 
c58d472
 
5683562
 
 
 
 
 
 
 
 
 
c58d472
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn
from transformers import BertModel

class HybridModel(nn.Module):
    """IndoBERT encoder -> BiLSTM -> linear head with per-label sigmoid outputs.

    BERT is run under ``torch.no_grad()``, so only the LSTM and the classifier
    receive gradients; BERT acts as a fixed feature extractor.

    Args:
        dropout: Dropout probability applied to the pooled LSTM features.
        num_labels: Number of output labels (default 12, as before).
        lstm_hidden_size: Hidden size of each LSTM direction (default 128).
        pretrained_name: Model id passed to ``BertModel.from_pretrained``.
    """

    def __init__(
        self,
        dropout=0.3,
        num_labels=12,
        lstm_hidden_size=128,
        pretrained_name="indobenchmark/indobert-base-p1",
    ):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        # Take the LSTM input width from the encoder config instead of
        # hard-coding 768, so other BERT sizes also work.
        self.lstm = nn.LSTM(
            self.bert.config.hidden_size,
            lstm_hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward and backward final states are concatenated -> 2 * hidden.
        self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)

    def _pool_sequence(self, hidden_states, attention_mask):
        """Summarise each sequence with the BiLSTM's final hidden states.

        Padding is excluded by packing each row to the length given by
        ``attention_mask``. The result concatenates two states: the forward
        direction's state after the last real token, and the backward
        direction's state after reading back to the first token.

        Args:
            hidden_states: (batch, seq_len, input_size) tensor (batch_first).
            attention_mask: (batch, seq_len) tensor, 1 for real tokens and
                0 for padding. NOTE(review): assumes right padding (real
                tokens form a prefix of each row) -- confirm the tokenizer's
                padding_side.

        Returns:
            (batch, 2 * lstm_hidden_size) tensor.
        """
        # pack_padded_sequence needs int64 lengths on the CPU. Clamp to 1 so
        # an all-padding row does not raise on a zero length.
        lengths = attention_mask.long().sum(dim=1).clamp(min=1).to("cpu")
        packed = nn.utils.rnn.pack_padded_sequence(
            hidden_states, lengths, batch_first=True, enforce_sorted=False
        )
        # h_n is (num_layers * 2, batch, hidden). For unsorted packed input
        # the LSTM returns h_n in the original batch order.
        _, (h_n, _) = self.lstm(packed)
        # Last layer: index -2 is the forward direction, -1 the backward.
        return torch.cat([h_n[-2], h_n[-1]], dim=-1)

    def forward(self, input_ids, attention_mask):
        """Return per-label probabilities of shape (batch, num_labels).

        Sigmoid is applied independently to each logit (not softmax).
        """
        # Frozen encoder: no gradients flow into BERT.
        # NOTE(review): BERT's own dropout is still active whenever the
        # module is in train() mode.
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Previously lstm_out[:, -1, :] was used. For padded rows that reads
        # a PAD position, and its backward half has seen only one token.
        features = self._pool_sequence(outputs.last_hidden_state, attention_mask)
        x = self.dropout(features)
        return torch.sigmoid(self.classifier(x))