File size: 760 Bytes
5683562 c58d472 5683562 c58d472 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
import torch.nn as nn
from transformers import BertModel
class HybridModel(nn.Module):
    """Frozen BERT encoder followed by a bidirectional LSTM and a linear head.

    Token states from BERT's ``last_hidden_state`` are run through a BiLSTM.
    The final forward and backward LSTM states are concatenated, passed
    through dropout and a linear layer, and squashed with a sigmoid, so each
    of the ``num_labels`` outputs is an independent score in [0, 1].

    Because the output is already a sigmoid, train it with a loss that
    expects probabilities (e.g. ``nn.BCELoss``), not ``nn.BCEWithLogitsLoss``.

    Args:
        dropout: Dropout probability applied to the pooled BiLSTM features.
        model_name: Checkpoint loaded with ``BertModel.from_pretrained`` when
            ``bert`` is not given.
        lstm_hidden_size: Hidden size of each LSTM direction.
        num_labels: Number of output scores.
        bert: Optional pre-built encoder to use instead of loading
            ``model_name``. It must expose ``config.hidden_size``. When called
            with ``input_ids`` and ``attention_mask``, it must return an object
            with a ``last_hidden_state`` attribute.
    """

    def __init__(
        self,
        dropout=0.3,
        *,
        model_name="indobenchmark/indobert-base-p1",
        lstm_hidden_size=128,
        num_labels=12,
        bert=None,
    ):
        super().__init__()
        self.bert = bert if bert is not None else BertModel.from_pretrained(model_name)
        # Read the encoder width from its config rather than hard-coding 768.
        self.lstm = nn.LSTM(
            self.bert.config.hidden_size,
            lstm_hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward and backward final states are concatenated, hence 2x.
        self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return sigmoid scores of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Mask passed to the encoder; its row sums are used
                as sequence lengths. Assumed to be right-padded (real tokens
                first, then padding) — TODO confirm for the data pipeline.
        """
        # BERT is used as a frozen feature extractor: no gradients reach it.
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        # Pack so each sequence stops at its last real token instead of
        # running over padding. pack_padded_sequence needs int64 lengths on
        # CPU, and every length must be >= 1.
        lengths = attention_mask.sum(dim=1).clamp(min=1).to("cpu", torch.int64)
        packed = nn.utils.rnn.pack_padded_sequence(
            hidden, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)

        # h_n[-2] / h_n[-1] are the top layer's final forward / backward
        # states, returned in the original batch order. Unlike
        # lstm_out[:, -1, :], the backward half has seen the whole sequence
        # and neither half has seen padding.
        features = torch.cat((h_n[-2], h_n[-1]), dim=1)
        x = self.dropout(features)
        return torch.sigmoid(self.classifier(x))
|