nlp_group_project / models /bert_classifier.py
DanilO0o's picture
added new model
edcd390
raw
history blame
1.15 kB
import torch
from torch import nn
from transformers import AutoModel
class MyTinyBERT(nn.Module):
    """BERT-based text classifier for 10 classes.

    Wraps the pretrained "cointegrated/rubert-tiny2" encoder (hidden size 312)
    and adds a small MLP head on top of the L2-normalized [CLS] embedding.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
        # Full fine-tuning: keep every encoder parameter trainable.
        # NOTE(review): the original code additionally "unfroze" layers 7-11,
        # but that loop was dead code — every parameter was already set
        # trainable by the loop below, so re-setting True on a subset changed
        # nothing (and rubert-tiny2 appears to have only 3 transformer layers,
        # so the 'layer.7'..'layer.11' filter never matched — verify against
        # the model config). The redundant loop is removed; behavior is
        # unchanged.
        for param in self.bert.parameters():
            param.requires_grad = True
        # Classification head: 312 (encoder hidden size) -> 256 -> 10 classes.
        self.linear = nn.Sequential(
            nn.Linear(312, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 10),  # 10 output classes
        )

    def forward(self, input_dict):
        """Return class logits of shape (batch, 10).

        Args:
            input_dict: dict with "input_ids" and "attention_mask" tensors,
                forwarded verbatim to the underlying BERT encoder.
        """
        bert_out = self.bert(**input_dict)
        # Use the hidden state of the [CLS] token (position 0), L2-normalized
        # across the feature dimension (dim=1 is also F.normalize's default).
        cls_state = bert_out.last_hidden_state[:, 0, :]
        normed_bert_out = nn.functional.normalize(cls_state, dim=1)
        return self.linear(normed_bert_out)