Update trainning.py
trainning.py (CHANGED): +239 -197
@@ -1,10 +1,15 @@
 import torch
+import pickle
 import string
 import torch.nn.functional as F
 import torch.nn as nn
 import torchvision.models as models

-def decoder(indices, vocab):
+def decoder(indices):
+
+    with open(f"/teamspace/studios/this_studio/ImgCap/vocab.pkl", 'rb') as f:
+        vocab = pickle.load(f)
+
     tokens = [vocab.lookup_token(idx) for idx in indices]
     words = []
     current_word = []
@@ -25,203 +30,134 @@ def decoder(indices, vocab):
 def beam_search_caption(model, images, vocab, decoder, device="cpu",
                         start_token="<sos>", end_token="<eos>",
                         beam_width=3, max_seq_length=100):
-    """
-    Generates captions for images using beam search.
-
-    Args:
-        model (ImgCap): The image captioning model.
-        images (torch.Tensor): Batch of images.
-        vocab (Vocab): Vocabulary object.
-        decoder (function): Function to decode indices to words.
-        device (str): Device to perform computation on.
-        start_token (str): Start-of-sequence token.
-        end_token (str): End-of-sequence token.
-        beam_width (int): Number of beams to keep.
-        max_seq_length (int): Maximum length of the generated caption.
-
-    Returns:
-        list: Generated captions for each image in the batch.
-    """
-    model.eval()
-
-    with torch.no_grad():
-        start_index = vocab[start_token]
-        end_index = vocab[end_token]
-        images = images.to(device)
-        batch_size = images.size(0)
-
-        # Ensure batch_size is 1 for beam search (one image at a time)
-        if batch_size != 1:
-            raise ValueError("Beam search currently supports batch_size=1.")
-
-        cnn_feature = model.cnn(images)  # Shape: (1, 1024)
-        lstm_input = model.lstm.projection(cnn_feature).unsqueeze(1)  # Shape: (1, 1, 1024)
-        state = None  # Initial LSTM state
-
-        # Initialize the beam with the start token
-        sequences = [([start_index], 0.0, lstm_input, state)]  # List of tuples: (sequence, score, input, state)
-
-        completed_sequences = []
-
-        for _ in range(max_seq_length):
-            all_candidates = []
-
-            # Iterate over all current sequences in the beam
-            for seq, score, lstm_input, state in sequences:
-                # If the last token is the end token, add the sequence to completed_sequences
-                if seq[-1] == end_index:
-                    completed_sequences.append((seq, score))
-                    continue
-
-                # Pass the current input and state through the LSTM
-                lstm_out, state_new = model.lstm.lstm(lstm_input, state)  # lstm_out: (1, 1, 1024)
-
-                # Pass the LSTM output through the fully connected layer to get logits
-                output = model.lstm.fc(lstm_out.squeeze(1))  # Shape: (1, vocab_size)
-
-                # Compute log probabilities
-                log_probs = F.log_softmax(output, dim=1)  # Shape: (1, vocab_size)
-
-                # Get the top beam_width tokens and their log probabilities
-                top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # Each of shape: (1, beam_width)
-
-                # Iterate over the top tokens to create new candidate sequences
-                for i in range(beam_width):
-                    token = top_indices[0, i].item()
-                    token_log_prob = top_log_probs[0, i].item()
-
-                    # Create a new sequence by appending the current token
-                    new_seq = seq + [token]
-                    new_score = score + token_log_prob
-
-                    # Get the embedding of the new token
-                    token_tensor = torch.tensor([token], device=device)
-                    new_lstm_input = model.lstm.embedding(token_tensor).unsqueeze(1)  # Shape: (1, 1, 1024)
-
-                    # Clone the new state to ensure each beam has its own state
-                    if state_new is not None:
-                        new_state = (state_new[0].clone(), state_new[1].clone())
-                    else:
-                        new_state = None
-
-                    # Add the new candidate to all_candidates
-                    all_candidates.append((new_seq, new_score, new_lstm_input, new_state))
-
-            # If no candidates are left to process, break out of the loop
-            if not all_candidates:
-                break
-
-        if best_seq[0] == start_index:
-            best_seq = best_seq[1:]
-
-        start_index = vocab[start_token]
-        end_index = vocab[end_token]
-        images = images.to(device)
-        batch_size = images.size(0)
-
-        captions = [[] for _ in range(batch_size)]
-
-        lstm_input = model.lstm.projection(cnn_feature).unsqueeze(1)  # (B, 1, hidden_size)
-
-        top_k_samples = torch.multinomial(top_k_probs, 1).squeeze()
-
-        if end_token_appear[j]:
-            continue
-
-        end_token_appear[j] = True
+    model.eval()
+
+    with torch.no_grad():
+        start_index = vocab[start_token]
+        end_index = vocab[end_token]
+        images = images.to(device)
+        batch_size = images.size(0)
+
+        # Ensure batch_size is 1 for beam search (one image at a time)
+        if batch_size != 1:
+            raise ValueError("Beam search currently supports batch_size=1.")
+
+        cnn_features = model.cnn(images)  # (B, 49, 2048)
+        h, c = model.lstm.init_hidden_state(batch_size)
+        word_input = torch.full((batch_size,), start_index, dtype=torch.long).to(device)
+
+        embeddings = model.lstm.embedding(word_input)
+        context, _ = model.lstm.attention(cnn_features, h[-1])
+        lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
+
+        sequences = [([start_index], 0.0, lstm_input, (h, c))]  # List of tuples: (sequence, score, input, state)
+
+        completed_sequences = []
+
+        for _ in range(max_seq_length):
+            all_candidates = []
+
+            for seq, score, lstm_input, (h, c) in sequences:
+                if seq[-1] == end_index:
+                    completed_sequences.append((seq, score))
+                    continue
+
+                lstm_out, (h_new, c_new) = model.lstm.lstm(lstm_input, (h, c))  # lstm_out: (1, 1, 1024)
+
+                output = model.lstm.fc(lstm_out.squeeze(1))  # Shape: (1, vocab_size)
+
+                log_probs = F.log_softmax(output, dim=1)  # Shape: (1, vocab_size)
+
+                top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # Each of shape: (1, beam_width)
+
+                for i in range(beam_width):
+                    token = top_indices[0, i].item()
+                    token_log_prob = top_log_probs[0, i].item()
+
+                    new_seq = seq + [token]
+                    new_score = score + token_log_prob
+
+                    token_tensor = torch.tensor([token], device=device)
+                    embeddings = model.lstm.embedding(token_tensor)
+                    context, _ = model.lstm.attention(cnn_features, h_new[-1])
+                    new_lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
+
+                    if h_new is not None and c_new is not None:
+                        h_new, c_new = (h_new.clone(), c_new.clone())
+                    else:
+                        h_new, c_new = None, None
+
+                    all_candidates.append((new_seq, new_score, new_lstm_input, (h_new, c_new)))
+
+            if not all_candidates:
+                break
+
+            ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+
+            sequences = ordered[:beam_width]
+
+            if len(completed_sequences) >= beam_width:
+                break
+
+        if len(completed_sequences) == 0:
+            completed_sequences = sequences
+
+        best_seq = max(completed_sequences, key=lambda x: x[1])
+        best_caption = decoder(best_seq[0])
+
+        return best_caption
+
+## ResNet50 (CNN Encoder)
 class ResNet50(nn.Module):
     def __init__(self):
         super(ResNet50, self).__init__()
         self.ResNet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+        self.features = nn.Sequential(*list(self.ResNet50.children())[:-2])
+        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+
+        for param in self.ResNet50.parameters():
+            param.requires_grad = False
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.avgpool(x)
+        B, C, H, W = x.size()
+        x = x.view(B, C, -1)    # Flatten spatial dimensions: (B, 2048, 49)
+        x = x.permute(0, 2, 1)  # (B, 49, 2048) - 49 spatial locations
+        return x
+
+class Attention(nn.Module):
+    def __init__(self, feature_size, hidden_size):
+        super(Attention, self).__init__()
+        self.attention = nn.Linear(feature_size + hidden_size, hidden_size)
+        self.attn_weights = nn.Linear(hidden_size, 1)
+
+    def forward(self, features, hidden_state):  # features: (B, 49, 2048), hidden_state: (B, hidden_size)
+        hidden_state = hidden_state.unsqueeze(1).repeat(1, features.size(1), 1)  # (B, 49, hidden_size)
+        combined = torch.cat((features, hidden_state), dim=2)              # (B, 49, feature_size + hidden_size)
+        attn_hidden = torch.tanh(self.attention(combined))                 # (B, 49, hidden_size)
+        attention_logits = self.attn_weights(attn_hidden).squeeze(2)       # (B, 49)
+        attention_weights = torch.softmax(attention_logits, dim=1)         # (B, 49)
+        context = (features * attention_weights.unsqueeze(2)).sum(dim=1)   # (B, 2048)
+        return context, attention_weights
+
+        # Attention without learnable parameters:
+        # logits = torch.matmul(features, hidden_state.unsqueeze(2))         # (B, 49, 1) - batch matrix multiply
+        # attention_weights = torch.softmax(logits, dim=1).squeeze(2)        # (B, 49)
+        # context = (features * attention_weights.unsqueeze(2)).sum(dim=1)   # (B, 2048)

 class lstm(nn.Module):
-    def __init__(self, input_size, hidden_size, number_layers, embedding_dim, vocab_size):
+    def __init__(self, feature_size, hidden_size, number_layers, embedding_dim, vocab_size):
         super(lstm, self).__init__()

-        self.input_size = input_size
         self.hidden_size = hidden_size
-        self.number_layers = number_layers
-        self.embedding_dim = embedding_dim
-        self.vocab_size = vocab_size
-
         self.embedding = nn.Embedding(vocab_size, hidden_size)
-        self.relu = nn.ReLU()
+        self.attention = Attention(feature_size, hidden_size)

         self.lstm = nn.LSTM(
-            input_size=hidden_size,
+            input_size=hidden_size + feature_size,  # input: concatenated context and word embedding
             hidden_size=hidden_size,
             num_layers=number_layers,
             dropout=0.5,
@@ -230,36 +166,49 @@ class lstm(nn.Module):

         self.fc = nn.Linear(hidden_size, vocab_size)

-    def forward(self, x, captions):
-        projected_image = self.projection(x).unsqueeze(dim=1)
-        embeddings = self.embedding(captions[:, :-1])
-
-## ImgCap
+    def forward(self, features, captions=None, max_seq_len=None, teacher_forcing_ratio=0.90):
+
+        batch_size = features.size(0)
+        max_seq_len = max_seq_len if max_seq_len is not None else captions.size(1)
+        h, c = self.init_hidden_state(batch_size)
+
+        outputs = torch.zeros(batch_size, max_seq_len, self.fc.out_features).to(features.device)
+        word_input = torch.tensor(2, dtype=torch.long).expand(batch_size).to(features.device)  # vocab["<sos>"] ---> 2
+
+        for t in range(1, max_seq_len):
+            embeddings = self.embedding(word_input)
+            context, _ = self.attention(features, h[-1])
+            lstm_input_step = torch.cat([embeddings, context], dim=1).unsqueeze(1)  # Combine context + word embedding
+
+            out, (h, c) = self.lstm(lstm_input_step, (h, c))
+            output = self.fc(out.squeeze(1))
+            outputs[:, t, :] = output
+
+            top1 = output.argmax(1)
+
+            if captions is not None and torch.rand(1).item() < teacher_forcing_ratio:
+                word_input = captions[:, t]
+            else:
+                word_input = top1
+
+        return outputs
+
+    def init_hidden_state(self, batch_size):
+        device = next(self.parameters()).device
+        h0 = torch.zeros(self.lstm.num_layers, batch_size, self.hidden_size).to(device)
+        c0 = torch.zeros(self.lstm.num_layers, batch_size, self.hidden_size).to(device)
+        return (h0, c0)

 class ImgCap(nn.Module):
-    def __init__(self, cnn_feature_size, lstm_hidden_size, num_layers, embedding_dim, vocab_size):
+    def __init__(self, feature_size, lstm_hidden_size, num_layers, vocab_size, embedding_dim):
         super(ImgCap, self).__init__()
         self.cnn = ResNet50()
-
-        self.lstm = lstm(input_size=cnn_feature_size,
-                         hidden_size=lstm_hidden_size,
-                         number_layers=num_layers,
-                         embedding_dim=embedding_dim,
-                         vocab_size=vocab_size)
+        self.lstm = lstm(feature_size, lstm_hidden_size, num_layers, embedding_dim, vocab_size)

     def forward(self, images, captions):
         cnn_features = self.cnn(images)
         output = self.lstm(cnn_features, captions)
         return output

     def generate_caption(self, images, vocab, decoder, device="cpu", start_token="<sos>", end_token="<eos>", max_seq_length=100):
@@ -271,31 +220,124 @@ class ImgCap(nn.Module):
             images = images.to(device)
             batch_size = images.size(0)

-            lstm_input = self.lstm.projection(cnn_feature).unsqueeze(1)  # (B, 1, hidden_size)
-
-                output = self.lstm.fc(lstm_out.squeeze(1))
-                predicted_word_indices = torch.argmax(output, dim=1)
-                lstm_input = self.lstm.embedding(predicted_word_indices).unsqueeze(1)  # (B, 1, hidden_size)
-
-                    if word == end_token:
-                        end_token_appear[j] = True
+            captions = [[start_index,] for _ in range(batch_size)]
+            end_token_appear = [False] * batch_size
+
+            cnn_features = self.cnn(images)  # (B, 49, 2048)
+
+            h, c = self.lstm.init_hidden_state(batch_size)
+
+            word_input = torch.full((batch_size,), start_index, dtype=torch.long).to(device)
+
+            for t in range(max_seq_length):
+
+                embeddings = self.lstm.embedding(word_input)
+                context, _ = self.lstm.attention(cnn_features, h[-1])  # Attention context
+                lstm_input_step = torch.cat([embeddings, context], dim=1).unsqueeze(1)  # Combine context + word embedding
+
+                out, (h, c) = self.lstm.lstm(lstm_input_step, (h, c))
+
+                output = self.lstm.fc(out.squeeze(1))  # (B, vocab_size)
+
+                # Get the predicted word (greedy search)
+                predicted_word_indices = torch.argmax(output, dim=1)  # (B,)
+                word_input = predicted_word_indices
+
+                for i in range(batch_size):
+                    if not end_token_appear[i]:
+                        predicted_word = vocab.lookup_token(predicted_word_indices[i].item())
+                        if predicted_word == end_token:
+                            captions[i].append(predicted_word_indices[i].item())
+                            end_token_appear[i] = True
+                        else:
+                            captions[i].append(predicted_word_indices[i].item())
+
+                if all(end_token_appear):  # Stop if all captions have reached the <eos> token
+                    break
+
+            captions = [decoder(caption) for caption in captions]
+
+            return captions
+
+    def beam_search_caption(self, images, vocab, decoder, device="cpu",
+                            start_token="<sos>", end_token="<eos>",
+                            beam_width=3, max_seq_length=100):
+        self.eval()
+
+        with torch.no_grad():
+            start_index = vocab[start_token]
+            end_index = vocab[end_token]
+            images = images.to(device)
+            batch_size = images.size(0)
+
+            # Ensure batch_size is 1 for beam search (one image at a time)
+            if batch_size != 1:
+                raise ValueError("Beam search currently supports batch_size=1.")
+
+            cnn_features = self.cnn(images)  # (B, 49, 2048)
+            h, c = self.lstm.init_hidden_state(batch_size)
+            word_input = torch.full((batch_size,), start_index, dtype=torch.long).to(device)
+
+            embeddings = self.lstm.embedding(word_input)
+            context, _ = self.lstm.attention(cnn_features, h[-1])
+            lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
+
+            sequences = [([start_index], 0.0, lstm_input, (h, c))]  # List of tuples: (sequence, score, input, state)
+
+            completed_sequences = []
+
+            for _ in range(max_seq_length):
+                all_candidates = []
+
+                for seq, score, lstm_input, (h, c) in sequences:
+                    if seq[-1] == end_index:
+                        completed_sequences.append((seq, score))
+                        continue
+
+                    lstm_out, (h_new, c_new) = self.lstm.lstm(lstm_input, (h, c))  # lstm_out: (1, 1, 1024)
+
+                    output = self.lstm.fc(lstm_out.squeeze(1))  # Shape: (1, vocab_size)
+
+                    log_probs = F.log_softmax(output, dim=1)  # Shape: (1, vocab_size)
+
+                    top_log_probs, top_indices = log_probs.topk(beam_width, dim=1)  # Each of shape: (1, beam_width)
+
+                    for i in range(beam_width):
+                        token = top_indices[0, i].item()
+                        token_log_prob = top_log_probs[0, i].item()
+
+                        new_seq = seq + [token]
+                        new_score = score + token_log_prob
+
+                        token_tensor = torch.tensor([token], device=device)
+                        embeddings = self.lstm.embedding(token_tensor)
+                        context, _ = self.lstm.attention(cnn_features, h_new[-1])
+                        new_lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
+
+                        if h_new is not None and c_new is not None:
+                            h_new, c_new = (h_new.clone(), c_new.clone())
+                        else:
+                            h_new, c_new = None, None
+
+                        all_candidates.append((new_seq, new_score, new_lstm_input, (h_new, c_new)))
+
+                if not all_candidates:
+                    break
+
+                ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+
+                sequences = ordered[:beam_width]
+
+                if len(completed_sequences) >= beam_width:
+                    break
+
+            if len(completed_sequences) == 0:
+                completed_sequences = sequences
+
+            best_seq = max(completed_sequences, key=lambda x: x[1])
+            best_caption = decoder(best_seq[0])
+
+            return best_caption
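
A minimal usage sketch (not part of the diff above) showing how the updated pieces fit together. It assumes trainning.py is importable as a module, that the pickled vocab behaves like a torchtext Vocab (supports vocab["<sos>"], len(vocab), and lookup_token), and that the size values below are illustrative choices rather than values fixed by the repo; only feature_size=2048 is dictated by the ResNet50 encoder output (B, 49, 2048).

import pickle
import torch
from trainning import ImgCap, beam_search_caption, decoder  # assumes trainning.py is on the import path

# Load the same pickled vocabulary that decoder() reads internally.
with open("/teamspace/studios/this_studio/ImgCap/vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

model = ImgCap(feature_size=2048,       # ResNet50 encoder emits (B, 49, 2048)
               lstm_hidden_size=1024,   # illustrative
               num_layers=2,            # illustrative
               vocab_size=len(vocab),
               embedding_dim=1024).to("cpu")

# Beam search expects a single image; the adaptive pooling in ResNet50 keeps
# the spatial grid at 7x7 regardless of the exact input resolution.
images = torch.randn(1, 3, 224, 224)

caption = beam_search_caption(model, images, vocab, decoder,
                              device="cpu", beam_width=3, max_seq_length=30)
print(caption)  # gibberish until the model is trained; this only exercises the wiring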