FikriRiyadi's picture
Update model.py
c58d472 verified
raw
history blame contribute delete
760 Bytes
import torch
import torch.nn as nn
from transformers import BertModel
class HybridModel(nn.Module):
    """BERT encoder followed by a bidirectional LSTM and a linear head.

    The encoder's token-level hidden states are fed through a BiLSTM; the
    LSTM output at the last sequence position goes through dropout and a
    linear layer, and a sigmoid is applied to every output unit, so each
    returned value is an independent score in (0, 1).

    Args:
        dropout: Dropout probability applied to the pooled LSTM features.
        num_labels: Number of output units of the classifier (default 12).
        lstm_hidden_size: Hidden width of each LSTM direction; the
            classifier consumes ``2 * lstm_hidden_size`` features.
        pretrained_name: Checkpoint identifier passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(
        self,
        dropout: float = 0.3,
        num_labels: int = 12,
        lstm_hidden_size: int = 128,
        pretrained_name: str = "indobenchmark/indobert-base-p1",
    ) -> None:
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        # Read the encoder width from its config instead of hard-coding 768,
        # so swapping the checkpoint cannot silently mismatch the LSTM input.
        self.lstm = nn.LSTM(
            self.bert.config.hidden_size,
            lstm_hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward and backward LSTM features are concatenated, hence "* 2".
        self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return sigmoid scores for one batch.

        Args:
            input_ids: Token ids, forwarded unchanged to the BERT encoder.
            attention_mask: Attention mask, forwarded unchanged to the BERT
                encoder. It is not used by the LSTM or by the pooling step.

        Returns:
            Tensor with one row per batch item and one sigmoid score per
            classifier output unit.
        """
        # The encoder runs without gradient tracking, so only the LSTM and
        # classifier receive gradients from a loss computed on the output.
        with torch.no_grad():
            outputs = self.bert(
                input_ids=input_ids, attention_mask=attention_mask
            )
        lstm_out, _ = self.lstm(outputs.last_hidden_state)
        # NOTE(review): this pools the LSTM output at the final sequence
        # index regardless of attention_mask. If batches are padded on the
        # right, that index is a padding position, and the backward
        # direction has only processed that single step there. Kept as-is
        # because changing the pooling would alter the outputs of any
        # weights already trained with it -- confirm whether it is intended.
        x = self.dropout(lstm_out[:, -1, :])
        return torch.sigmoid(self.classifier(x))