import torch
import torch.nn as nn
from transformers import BertModel


class HybridModel(nn.Module):
    """IndoBERT encoder followed by a BiLSTM and a multi-label classification head."""

    def __init__(self, dropout=0.3):
        super(HybridModel, self).__init__()
        # Pre-trained IndoBERT encoder (hidden size 768)
        self.bert = BertModel.from_pretrained("indobenchmark/indobert-base-p1")
        # Bidirectional LSTM over the BERT token representations
        self.lstm = nn.LSTM(768, 128, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # 128 hidden units per direction -> 256 features, 12 labels
        self.classifier = nn.Linear(128 * 2, 12)

    def forward(self, input_ids, attention_mask):
        # BERT is frozen: no gradients flow into the encoder
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        lstm_out, _ = self.lstm(outputs.last_hidden_state)
        # Take the BiLSTM hidden state at the final timestep, then apply dropout
        x = self.dropout(lstm_out[:, -1, :])
        # Sigmoid yields independent per-label probabilities (multi-label setup)
        return torch.sigmoid(self.classifier(x))
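
# Minimal usage sketch (not from the original listing): it assumes the tokenizer for
# "indobenchmark/indobert-base-p1" is loaded via transformers' AutoTokenizer, and that
# the 0.5 decision threshold, max_length=128, and the example sentence are placeholder
# choices for illustration only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
    model = HybridModel(dropout=0.3)
    model.eval()

    # Tokenize a single example sentence into input_ids and attention_mask tensors
    batch = tokenizer(
        ["contoh kalimat untuk klasifikasi multi-label"],
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    with torch.no_grad():
        probs = model(batch["input_ids"], batch["attention_mask"])  # shape: (1, 12)

    # Threshold each label independently (hypothetical 0.5 cutoff)
    predictions = (probs > 0.5).int()
    print(predictions)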