nlp_group_project / models /lstm_attention.py
DanilO0o's picture
added new model
edcd390
raw
history blame
1.2 kB
import torch
from torch import nn
class LSTMAttention(nn.Module):
    """Bidirectional LSTM text classifier with learned attention pooling.

    Token ids are embedded, encoded by a bidirectional LSTM, scored per
    position by a linear layer, softmax-normalised over the sequence
    dimension, and pooled into a single context vector that feeds a linear
    classifier.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of each LSTM direction; the encoder output has
            ``2 * hidden_dim`` features because the LSTM is bidirectional.
        output_dim: Number of output logits.
        dropout: Dropout probability applied to the context vector before the
            classifier (default 0.5, the previously hard-coded value).
        padding_idx: Optional padding token id forwarded to ``nn.Embedding``;
            ``None`` (default) keeps the original embedding behaviour.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 dropout=0.5, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        # One scalar score per time step, computed from the concatenated
        # forward/backward hidden states.
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Tensor of token ids, shape ``(batch_size, seq_len)``.
            attention_mask: Optional tensor of shape ``(batch_size, seq_len)``.
                Non-zero / True marks real tokens, zero / False marks padding;
                masked positions receive zero attention weight. When omitted,
                every position is attended to (original behaviour).

        Returns:
            A tuple ``(logits, attn_weights)`` with shapes
            ``(batch_size, output_dim)`` and ``(batch_size, seq_len)``.

        Raises:
            ValueError: If ``attention_mask`` is given and its shape differs
                from ``input_ids``.
        """
        # Embedding layer
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)

        # LSTM layer.
        # NOTE: padding tokens still flow through the LSTM (including the
        # backward direction); the mask below only keeps them out of pooling.
        lstm_out, _ = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim * 2)

        # Attention mechanism
        scores = self.attention(lstm_out)  # (batch_size, seq_len, 1)
        if attention_mask is not None:
            if attention_mask.shape != input_ids.shape:
                raise ValueError(
                    f"attention_mask shape {tuple(attention_mask.shape)} does not "
                    f"match input_ids shape {tuple(input_ids.shape)}"
                )
            keep = attention_mask.bool().unsqueeze(-1)  # (batch_size, seq_len, 1)
            # Use the most negative finite value instead of -inf so that a
            # fully-masked row produces uniform weights rather than NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=1)  # normalise over seq_len

        # Context vector: attention-weighted sum of encoder states
        context_vector = torch.sum(attn_weights * lstm_out, dim=1)  # (batch_size, hidden_dim * 2)

        # Classifier
        output = self.fc(self.dropout(context_vector))  # (batch_size, output_dim)
        return output, attn_weights.squeeze(-1)