Spaces:
Sleeping
Sleeping
File size: 1,203 Bytes
edcd390 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from torch import nn
class LSTMAttention(nn.Module):
    """Bidirectional LSTM classifier with learned attention pooling.

    Token ids are embedded and run through a single-layer bidirectional LSTM.
    The per-timestep outputs are pooled into one context vector using weights
    from a linear scoring layer followed by a softmax over the sequence. The
    context vector goes through dropout and a final linear layer. No
    activation is applied to the final output.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 dropout=0.5, padding_idx=None):
        """Build the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of each token embedding.
            hidden_dim: Hidden size of each LSTM direction. The concatenated
                bidirectional output has size ``2 * hidden_dim``.
            output_dim: Number of output scores per example.
            dropout: Dropout probability applied to the pooled context vector
                (default 0.5, matching the previous hard-coded value).
            padding_idx: Optional padding token id forwarded to
                ``nn.Embedding``. ``None`` (the default) keeps the previous
                behaviour.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Maps each timestep's (2 * hidden_dim) LSTM output to one scalar score.
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        """Compute output scores and per-token attention weights.

        Args:
            input_ids: Integer tensor of token ids, expected shape
                (batch_size, seq_len).
            attention_mask: Optional tensor of shape (batch_size, seq_len).
                Nonzero/True marks real tokens; zero/False marks padding.
                Masked positions receive zero attention weight. When omitted,
                every position is attended to, as before.

        Returns:
            A tuple ``(output, attn_weights)``:
            - ``output`` has shape (batch_size, output_dim).
            - ``attn_weights`` has shape (batch_size, seq_len) and sums to 1
              along the sequence axis.
        """
        # Embedding layer
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)

        # LSTM layer.
        # NOTE: padded positions are still fed through the LSTM, so the backward
        # direction reads padding before the real tokens. The mask below only
        # keeps them out of the attention pooling.
        lstm_out, _ = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim * 2)

        # Attention mechanism
        scores = self.attention(lstm_out)  # (batch_size, seq_len, 1)
        if attention_mask is not None:
            keep = attention_mask.bool().unsqueeze(-1)  # (batch_size, seq_len, 1)
            # Use finfo.min rather than -inf so that a fully masked row gives
            # uniform weights instead of NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=1)  # normalise over seq_len

        # Context vector: attention-weighted sum of LSTM outputs over time.
        context_vector = torch.sum(attn_weights * lstm_out, dim=1)  # (batch_size, hidden_dim * 2)

        # Classifier
        output = self.fc(self.dropout(context_vector))  # (batch_size, output_dim)
        return output, attn_weights.squeeze(-1)
|