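"""Bidirectional LSTM with attention pooling for sequence classification.

Each timestep of the LSTM output is scored by a linear layer, the scores
are softmax-normalized over the sequence, and the outputs are pooled into
a single context vector by a weighted sum. The forward pass returns both
the logits and the per-token attention weights.
"""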
import torch
from torch import nn

class LSTMAttention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(0.5)  # applied to the pooled context vector before classification

    def forward(self, input_ids):
        # Embedding layer
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)

        # LSTM layer
        lstm_out, _ = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim*2)

        # Attention mechanism: score each timestep, softmax over the sequence
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)  # (batch_size, seq_len, 1)

        # Context vector: attention-weighted sum of the LSTM outputs
        context_vector = torch.sum(attn_weights * lstm_out, dim=1)  # (batch_size, hidden_dim*2)

        # Classifier head
        output = self.fc(self.dropout(context_vector))  # (batch_size, output_dim)

        return output, attn_weights.squeeze(-1)  # (batch_size, output_dim), (batch_size, seq_len)
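

if __name__ == "__main__":
    # Minimal smoke test. The hyperparameters and shapes below are
    # illustrative assumptions, not values taken from the original file.
    model = LSTMAttention(vocab_size=10_000, embedding_dim=128, hidden_dim=256, output_dim=2)
    input_ids = torch.randint(0, 10_000, (4, 32))  # batch of 4 sequences, 32 tokens each
    logits, attn = model(input_ids)
    print(logits.shape)  # torch.Size([4, 2])
    print(attn.shape)    # torch.Size([4, 32]) -- one weight per token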