shredder-31 committed on
Commit 35ee781 · verified · 1 Parent(s): 76bddca

Update trainning.py

Files changed (1)
  1. trainning.py +239 -197
trainning.py CHANGED
@@ -1,10 +1,15 @@
  import torch
  import string
  import torch.nn.functional as F
  import torch.nn as nn
  import torchvision.models as models

- def decoder(indices, vocab):
      tokens = [vocab.lookup_token(idx) for idx in indices]
      words = []
      current_word = []
@@ -25,203 +30,134 @@ def decoder(indices, vocab):
  def beam_search_caption(model, images, vocab, decoder, device="cpu",
                          start_token="<sos>", end_token="<eos>",
                          beam_width=3, max_seq_length=100):
-     """
-     Generates captions for images using beam search.
-
-     Args:
-         model (ImgCap): The image captioning model.
-         images (torch.Tensor): Batch of images.
-         vocab (Vocab): Vocabulary object.
-         decoder (function): Function to decode indices to words.
-         device (str): Device to perform computation on.
-         start_token (str): Start-of-sequence token.
-         end_token (str): End-of-sequence token.
-         beam_width (int): Number of beams to keep.
-         max_seq_length (int): Maximum length of the generated caption.
-
-     Returns:
-         list: Generated captions for each image in the batch.
-     """
-     model.eval()
-
-     with torch.no_grad():
-         start_index = vocab[start_token]
-         end_index = vocab[end_token]
-         images = images.to(device)
-         batch_size = images.size(0)
-
-         # Ensure batch_size is 1 for beam search (one image at a time)
-         if batch_size != 1:
-             raise ValueError("Beam search currently supports batch_size=1.")
-
-         cnn_feature = model.cnn(images)  # Shape: (1, 1024)
-         lstm_input = model.lstm.projection(cnn_feature).unsqueeze(1)  # Shape: (1, 1, 1024)
-         state = None  # Initial LSTM state
-
-         # Initialize the beam with the start token
-         sequences = [([start_index], 0.0, lstm_input, state)]  # List of tuples: (sequence, score, input, state)
-
-         completed_sequences = []
-
-         for _ in range(max_seq_length):
-             all_candidates = []
-
-             # Iterate over all current sequences in the beam
-             for seq, score, lstm_input, state in sequences:
-                 # If the last token is the end token, add the sequence to completed_sequences
-                 if seq[-1] == end_index:
-                     completed_sequences.append((seq, score))
-                     continue
-
-                 # Pass the current input and state through the LSTM
-                 lstm_out, state_new = model.lstm.lstm(lstm_input, state)  # lstm_out: (1, 1, 1024)
-
-                 # Pass the LSTM output through the fully connected layer to get logits
-                 output = model.lstm.fc(lstm_out.squeeze(1))  # Shape: (1, vocab_size)
-
-                 # Compute log probabilities
-                 log_probs = F.log_softmax(output, dim=1)  # Shape: (1, vocab_size)
-
-                 # Get the top beam_width tokens and their log probabilities
-                 top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # Each of shape: (1, beam_width)
-
-                 # Iterate over the top tokens to create new candidate sequences
-                 for i in range(beam_width):
-                     token = top_indices[0, i].item()
-                     token_log_prob = top_log_probs[0, i].item()
-
-                     # Create a new sequence by appending the current token
-                     new_seq = seq + [token]
-                     new_score = score + token_log_prob
-
-                     # Get the embedding of the new token
-                     token_tensor = torch.tensor([token], device=device)
-                     new_lstm_input = model.lstm.embedding(token_tensor).unsqueeze(1)  # Shape: (1, 1, 1024)
-
-                     # Clone the new state to ensure each beam has its own state
-                     if state_new is not None:
-                         new_state = (state_new[0].clone(), state_new[1].clone())
-                     else:
-                         new_state = None
-
-                     # Add the new candidate to all_candidates
-                     all_candidates.append((new_seq, new_score, new_lstm_input, new_state))
-
-             # If no candidates are left to process, break out of the loop
-             if not all_candidates:
-                 break

-             # Sort all candidates by score in descending order
-             ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
-
-             # Select the top beam_width sequences to form the new beam
-             sequences = ordered[:beam_width]

-             # If enough completed sequences are found, stop early
-             if len(completed_sequences) >= beam_width:
-                 break

-         # If no sequences have completed, use the current sequences
-         if len(completed_sequences) == 0:
-             completed_sequences = sequences

-         # Select the sequence with the highest score
-         best_seq, best_score = max(completed_sequences, key=lambda x: x[1])

-         if best_seq[0] == start_index:
-             best_seq = best_seq[1:]

-         best_caption = decoder(best_seq, vocab)

-         return best_caption


- def generate_caption(model, images, vocab, decoder, device="cpu", start_token="<sos>", end_token="<eos>", max_seq_length=100, top_k=2):
-     model.eval()

-     with torch.no_grad():
-         start_index = vocab[start_token]
-         end_index = vocab[end_token]
-         images = images.to(device)
-         batch_size = images.size(0)

-         end_token_appear = {i: False for i in range(batch_size)}
-         captions = [[] for _ in range(batch_size)]

-         cnn_feature = model.cnn(images)
-         lstm_input = model.lstm.projection(cnn_feature).unsqueeze(1)  # (B, 1, hidden_size)

-         state = None

-         for i in range(max_seq_length):
-             lstm_out, state = model.lstm.lstm(lstm_input, state)
-             output = model.lstm.fc(lstm_out.squeeze(1))

-             top_k_probs, top_k_indices = torch.topk(F.softmax(output, dim=1), top_k, dim=1)
-             top_k_probs = top_k_probs / torch.sum(top_k_probs, dim=1, keepdim=True)
-             top_k_samples = torch.multinomial(top_k_probs, 1).squeeze()

-             predicted_word_indices = top_k_indices[range(batch_size), top_k_samples]

-             lstm_input = model.lstm.embedding(predicted_word_indices).unsqueeze(1)  # (B, 1, hidden_size)

-             for j in range(batch_size):
-                 if end_token_appear[j]:
-                     continue

-                 word = vocab.lookup_token(predicted_word_indices[j].item())
-                 if word == end_token:
-                     end_token_appear[j] = True

-                 captions[j].append(predicted_word_indices[j].item())

-         captions = [decoder(caption, vocab) for caption in captions]

-         return captions




  class ResNet50(nn.Module):
      def __init__(self):
          super(ResNet50, self).__init__()
          self.ResNet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
-
-         self.ResNet50.fc = nn.Sequential(
-             nn.Linear(2048, 1024),
-             nn.ReLU(),
-             nn.Dropout(0.5),
-             nn.Linear(1024, 1024),
-             nn.ReLU(),
-         )
-
-         for k,v in self.ResNet50.named_parameters(recurse=True):
-             if 'fc' in k:
-                 v.requires_grad = True
-             else:
-                 v.requires_grad = False
-
-     def forward(self,x):
-         return self.ResNet50(x)
-
- ## lSTM (Decoder)

  class lstm(nn.Module):
-     def __init__(self, input_size, hidden_size, number_layers, embedding_dim, vocab_size):
          super(lstm, self).__init__()

-         self.input_size = input_size
          self.hidden_size = hidden_size
-         self.number_layers = number_layers
-         self.embedding_dim = embedding_dim
-         self.vocab_size = vocab_size
-
          self.embedding = nn.Embedding(vocab_size, hidden_size)
-         self.projection = nn.Linear(input_size, hidden_size)
-         self.relu = nn.ReLU()

          self.lstm = nn.LSTM(
-             input_size=hidden_size,
              hidden_size=hidden_size,
              num_layers=number_layers,
              dropout=0.5,
@@ -230,36 +166,49 @@ class lstm(nn.Module):

          self.fc = nn.Linear(hidden_size, vocab_size)

-     def forward(self, x, captions):
-         projected_image = self.projection(x).unsqueeze(dim=1)
-         embeddings = self.embedding(captions[:, :-1])

-         # Concatenate the image feature as frist step with word embeddings
-         lstm_input = torch.cat((projected_image, embeddings), dim=1)
-         # print(torch.all(projected_image[:, 0, :] == lstm_input[:, 0, :]))  # check

-         lstm_out, _ = self.lstm(lstm_input)
-         logits = self.fc(lstm_out)

-         return logits

- ## ImgCap

  class ImgCap(nn.Module):
-     def __init__(self, cnn_feature_size, lstm_hidden_size, num_layers, vocab_size, embedding_dim):
          super(ImgCap, self).__init__()
-
          self.cnn = ResNet50()
-
-         self.lstm = lstm(input_size=cnn_feature_size,
-                          hidden_size=lstm_hidden_size,
-                          number_layers=num_layers,
-                          embedding_dim=embedding_dim,
-                          vocab_size=vocab_size)

      def forward(self, images, captions):
          cnn_features = self.cnn(images)
-         output = self.lstm(cnn_features, captions)
          return output

      def generate_caption(self, images, vocab, decoder, device="cpu", start_token="<sos>", end_token="<eos>", max_seq_length=100):
@@ -271,31 +220,124 @@ class ImgCap(nn.Module):
          images = images.to(device)
          batch_size = images.size(0)

-             end_token_appear = {i: False for i in range(batch_size)}
-             captions = [[] for _ in range(batch_size)]

-             cnn_feature = self.cnn(images)
-             lstm_input = self.lstm.projection(cnn_feature).unsqueeze(1)  # (B, 1, hidden_size)

-             state = None

-             for i in range(max_seq_length):
-                 lstm_out, state = self.lstm.lstm(lstm_input, state)
-                 output = self.lstm.fc(lstm_out.squeeze(1))
-                 predicted_word_indices = torch.argmax(output, dim=1)
-                 lstm_input = self.lstm.embedding(predicted_word_indices).unsqueeze(1)  # (B, 1, hidden_size)

-                 for j in range(batch_size):
-                     if end_token_appear[j]:
                          continue

-                     word = vocab.lookup_token(predicted_word_indices[j].item())
-                     if word == end_token:
-                         end_token_appear[j] = True

-                     captions[j].append(predicted_word_indices[j].item())

-             captions = [decoder(caption) for caption in captions]

-             return captions
 
 
  import torch
+ import pickle
  import string
  import torch.nn.functional as F
  import torch.nn as nn
  import torchvision.models as models

+ def decoder(indices):
+
+     with open(f"/teamspace/studios/this_studio/ImgCap/vocab.pkl", 'rb') as f:
+         vocab = pickle.load(f)
+
      tokens = [vocab.lookup_token(idx) for idx in indices]
      words = []
      current_word = []

  def beam_search_caption(model, images, vocab, decoder, device="cpu",
                          start_token="<sos>", end_token="<eos>",
                          beam_width=3, max_seq_length=100):
+     model.eval()

+     with torch.no_grad():
+         start_index = vocab[start_token]
+         end_index = vocab[end_token]
+         images = images.to(device)
+         batch_size = images.size(0)

+         # Ensure batch_size is 1 for beam search (one image at a time)
+         if batch_size != 1:
+             raise ValueError("Beam search currently supports batch_size=1.")

+         cnn_features = model.cnn(images)  # (B, 49, 2048)
+         h, c = model.lstm.init_hidden_state(batch_size)
+         word_input = torch.full((batch_size,), start_index, dtype=torch.long).to(device)

+         embeddings = model.lstm.embedding(word_input)
+         context, _ = model.lstm.attention(cnn_features, h[-1])
+         lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)

+         sequences = [([start_index], 0.0, lstm_input, (h, c))]  # List of tuples: (sequence, score, input, state)

+         completed_sequences = []

+         for _ in range(max_seq_length):
+             all_candidates = []

+             for seq, score, lstm_input, (h, c) in sequences:
+                 if seq[-1] == end_index:
+                     completed_sequences.append((seq, score))
+                     continue

+                 lstm_out, (h_new, c_new) = model.lstm.lstm(lstm_input, (h, c))  # lstm_out: (1, 1, 1024)

+                 output = model.lstm.fc(lstm_out.squeeze(1))  # Shape: (1, vocab_size)

+                 log_probs = F.log_softmax(output, dim=1)  # Shape: (1, vocab_size)

+                 top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # Each of shape: (1, beam_width)

+                 for i in range(beam_width):
+                     token = top_indices[0, i].item()
+                     token_log_prob = top_log_probs[0, i].item()

+                     new_seq = seq + [token]
+                     new_score = score + token_log_prob

+                     token_tensor = torch.tensor([token], device=device)
+                     embeddings = model.lstm.embedding(token_tensor)
+                     context, _ = model.lstm.attention(cnn_features, h_new[-1])
+                     new_lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)

+                     if h_new is not None and c_new is not None:
+                         h_new, c_new = (h_new.clone(), c_new.clone())
+                     else:
+                         h_new, c_new = None, None

+                     all_candidates.append((new_seq, new_score, new_lstm_input, (h_new, c_new)))

+             if not all_candidates:
+                 break

+             ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)

+             sequences = ordered[:beam_width]

+             if len(completed_sequences) >= beam_width:
+                 break

+         if len(completed_sequences) == 0:
+             completed_sequences = sequences
+
+         best_seq = max(completed_sequences, key=lambda x: x[1])
+         best_caption = decoder(best_seq[0])

+         return best_caption


+ ## ResNet50 (CNN Encoder)
  class ResNet50(nn.Module):
      def __init__(self):
          super(ResNet50, self).__init__()
          self.ResNet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+         self.features = nn.Sequential(*list(self.ResNet50.children())[:-2])
+         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+
+         for param in self.ResNet50.parameters():
+             param.requires_grad = False
+
+     def forward(self, x):
+         x = self.features(x)
+         x = self.avgpool(x)
+         B, C, H, W = x.size()
+         x = x.view(B, C, -1)    # Flatten spatial dimensions: (B, 2048, 49)
+         x = x.permute(0, 2, 1)  # (B, 49, 2048) - 49 spatial locations
+         return x
+
+ class Attention(nn.Module):
+     def __init__(self, feature_size, hidden_size):
+         super(Attention, self).__init__()
+         self.attention = nn.Linear(feature_size + hidden_size, hidden_size)
+         self.attn_weights = nn.Linear(hidden_size, 1)
+
+     def forward(self, features, hidden_state):  # features: (B, 49, 2048), hidden_state: (B, hidden_size)
+         hidden_state = hidden_state.unsqueeze(1).repeat(1, features.size(1), 1)  # (B, 49, hidden_size)
+         combined = torch.cat((features, hidden_state), dim=2)  # (B, 49, feature_size + hidden_size)
+         attn_hidden = torch.tanh(self.attention(combined))  # (B, 49, hidden_size)
+         attention_logits = self.attn_weights(attn_hidden).squeeze(2)  # (B, 49)
+         attention_weights = torch.softmax(attention_logits, dim=1)  # (B, 49)
+         context = (features * attention_weights.unsqueeze(2)).sum(dim=1)  # (B, 2048)
+         return context, attention_weights
+
+ # Attention without learnable paramters:
147
+ # logits = torch.matmul(features, hidden_state.unsqueeze(2)) # (B, 49, 1) - Batch Matriax
148
+ # attention_weights = torch.softmax(logits, dim=1).squeeze(2) # (B, 49)
149
+ # context = (features * attention_weights.unsqueeze(2)).sum(dim=1) # (B, 2048)
150
 
151
  class lstm(nn.Module):
152
+ def __init__(self, feature_size, hidden_size, number_layers, embedding_dim, vocab_size):
153
  super(lstm, self).__init__()
154
 
 
155
  self.hidden_size = hidden_size
 
 
 
 
156
  self.embedding = nn.Embedding(vocab_size, hidden_size)
157
+ self.attention = Attention(feature_size, hidden_size)
 
158
 
159
  self.lstm = nn.LSTM(
160
+ input_size=hidden_size + feature_size, # input: concatenated context and word embedding
161
  hidden_size=hidden_size,
162
  num_layers=number_layers,
163
  dropout=0.5,
 
166
 
167
  self.fc = nn.Linear(hidden_size, vocab_size)
168
 
169
+ def forward(self, features, captions=None, max_seq_len=None, teacher_forcing_ratio=0.90):
 
 
170
 
171
+ batch_size = features.size(0)
172
+ max_seq_len = max_seq_len if max_seq_len is not None else captions.size(1)
173
+ h, c = self.init_hidden_state(batch_size)
174
 
175
+ outputs = torch.zeros(batch_size, max_seq_len, self.fc.out_features).to(features.device)
176
+ word_input = torch.tensor(2, dtype=torch.long).expand(batch_size).to(features.device) # vocab["<sos>"] ---> 2
177
+
178
+ for t in range(1, max_seq_len):
179
+ embeddings = self.embedding(word_input)
180
+ context, _ = self.attention(features, h[-1])
181
+ lstm_input_step = torch.cat([embeddings, context], dim=1).unsqueeze(1) # Combine context + word embedding
182
+
183
+ out, (h, c) = self.lstm(lstm_input_step, (h, c))
184
+ output = self.fc(out.squeeze(1))
185
+ outputs[:, t, :] = output
186
+
187
+ top1 = output.argmax(1)
188
 
189
+ if captions is not None and torch.rand(1).item() < teacher_forcing_ratio:
190
+ word_input = captions[:, t]
191
+ else:
192
+ word_input = top1
193
+
194
+ return outputs
195
+
196
+ def init_hidden_state(self, batch_size):
197
+ device = next(self.parameters()).device
198
+ h0 = torch.zeros(self.lstm.num_layers, batch_size, self.hidden_size).to(device)
199
+ c0 = torch.zeros(self.lstm.num_layers, batch_size, self.hidden_size).to(device)
200
+ return (h0, c0)
201
 
 
202
 
203
  class ImgCap(nn.Module):
204
+ def __init__(self, feature_size, lstm_hidden_size, num_layers, vocab_size, embedding_dim):
205
  super(ImgCap, self).__init__()
 
206
  self.cnn = ResNet50()
207
+ self.lstm = lstm(feature_size, lstm_hidden_size, num_layers, embedding_dim, vocab_size)
 
 
 
 
 
208
 
209
  def forward(self, images, captions):
210
  cnn_features = self.cnn(images)
211
+ output = self.lstm(cnn_features, captions)
212
  return output
213
 
214
  def generate_caption(self, images, vocab, decoder, device="cpu", start_token="<sos>", end_token="<eos>", max_seq_length=100):
 
220
  images = images.to(device)
221
  batch_size = images.size(0)
222
 
223
+ captions = [[start_index,] for _ in range(batch_size)]
224
+ end_token_appear = [False] * batch_size
225
+
226
+ cnn_features = self.cnn(images) # (B, 49, 2048)
227
+
228
+ h, c = self.lstm.init_hidden_state(batch_size)
229
+
230
+ word_input = torch.full((batch_size,), start_index, dtype=torch.long).to(device)
231
+
232
+ for t in range(max_seq_length):
233
+
234
+ embeddings = self.lstm.embedding(word_input)
235
+ context, _ = self.lstm.attention(cnn_features, h[-1]) # Attention context
236
+ lstm_input_step = torch.cat([embeddings, context], dim=1).unsqueeze(1) # Combine context + word embedding
237
+
238
+ out, (h, c) = self.lstm.lstm(lstm_input_step, (h, c))
239
+
240
+ output = self.lstm.fc(out.squeeze(1)) # (B, vocab_size)
241
+
242
+ # Get the predicted word (greedy search)
243
+ predicted_word_indices = torch.argmax(output, dim=1) # (B,)
244
+ word_input = predicted_word_indices
245
+
246
+
247
+ for i in range(batch_size):
248
+ if not end_token_appear[i]:
249
+ predicted_word = vocab.lookup_token(predicted_word_indices[i].item())
250
+ if predicted_word == end_token:
251
+ captions[i].append(predicted_word_indices[i].item())
252
+ end_token_appear[i] = True
253
+ else:
254
+ captions[i].append(predicted_word_indices[i].item())
255
+
256
+
257
+ if all(end_token_appear): # Stop if all captions have reached the <eos> token
258
+ break
259
+
260
+ captions = [decoder(caption) for caption in captions]
261
+
262
+ return captions
263
+
264
+ def beam_search_caption(self, images, vocab, decoder, device="cpu",
265
+ start_token="<sos>", end_token="<eos>",
266
+ beam_width=3, max_seq_length=100):
267
+ self.eval()
268
+
269
+ with torch.no_grad():
270
+ start_index = vocab[start_token]
271
+ end_index = vocab[end_token]
272
+ images = images.to(device)
273
+ batch_size = images.size(0)
274
+
275
+ # Ensure batch_size is 1 for beam search (one image at a time)
276
+ if batch_size != 1:
277
+ raise ValueError("Beam search currently supports batch_size=1.")
278
+
279
+ cnn_features = self.cnn(images) # (B, 49, 2048)
280
+ h, c = self.lstm.init_hidden_state(batch_size)
281
+ word_input = torch.full((batch_size,), start_index, dtype=torch.long).to(device)
282
+
283
+ embeddings = self.lstm.embedding(word_input)
284
+ context, _ = self.lstm.attention(cnn_features, h[-1])
285
+ lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
286
+
287
 
288
+ sequences = [([start_index], 0.0, lstm_input, (h, c))] # List of tuples: (sequence, score, input, state)
 
289
 
290
+ completed_sequences = []
291
 
292
+ for _ in range(max_seq_length):
293
+ all_candidates = []
 
 
 
294
 
295
+ for seq, score, lstm_input, (h,c) in sequences:
296
+ if seq[-1] == end_index:
297
+ completed_sequences.append((seq, score))
298
  continue
299
 
300
+                     lstm_out, (h_new, c_new) = self.lstm.lstm(lstm_input, (h, c))  # lstm_out: (1, 1, 1024)

+                     output = self.lstm.fc(lstm_out.squeeze(1))  # Shape: (1, vocab_size)

+                     log_probs = F.log_softmax(output, dim=1)  # Shape: (1, vocab_size)

+                     top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # Each of shape: (1, beam_width)
+
+                     for i in range(beam_width):
+                         token = top_indices[0, i].item()
+                         token_log_prob = top_log_probs[0, i].item()
+
+                         new_seq = seq + [token]
+                         new_score = score + token_log_prob
+
+                         token_tensor = torch.tensor([token], device=device)
+                         embeddings = self.lstm.embedding(token_tensor)
+                         context, _ = self.lstm.attention(cnn_features, h_new[-1])
+                         new_lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
+
+                         if h_new is not None and c_new is not None:
+                             h_new, c_new = (h_new.clone(), c_new.clone())
+                         else:
+                             h_new, c_new = None, None
+
+                         all_candidates.append((new_seq, new_score, new_lstm_input, (h_new, c_new)))
+
+                 if not all_candidates:
+                     break
+
+                 ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+
+                 sequences = ordered[:beam_width]
+
+                 if len(completed_sequences) >= beam_width:
+                     break
+
+             if len(completed_sequences) == 0:
+                 completed_sequences = sequences
+
+             best_seq = max(completed_sequences, key=lambda x: x[1])
+             best_caption = decoder(best_seq[0])

+             return best_caption
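
Not part of the commit, but for orientation: below is a minimal usage sketch of the updated attention-based ImgCap. The sizes (hidden size 512, 2 LSTM layers, embedding dim 512), the module name trainning, and the pickled torchtext-style vocabulary (the same path hard-coded in decoder()) are assumptions for illustration, not values taken from this repository.

# Hypothetical usage sketch -- not part of this commit.
# Assumed: module importable as "trainning", a torchtext-style Vocab pickle,
# and illustrative layer sizes.
import pickle
import torch
from trainning import ImgCap, decoder

with open("/teamspace/studios/this_studio/ImgCap/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)               # supports vocab["<sos>"] and vocab.lookup_token(idx)

model = ImgCap(
    feature_size=2048,                   # channel dim of the ResNet50 feature map (B, 49, 2048)
    lstm_hidden_size=512,                # illustrative value
    num_layers=2,                        # >1 so the LSTM's dropout=0.5 takes effect
    vocab_size=len(vocab),
    embedding_dim=512,                   # illustrative; the embedding layer itself uses hidden_size
)
model.eval()

image = torch.randn(1, 3, 224, 224)      # one preprocessed image; beam search expects batch_size=1
caption = model.beam_search_caption(image, vocab, decoder, device="cpu", beam_width=3)
print(caption)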