import torch
import torch.nn as nn
from torchvision import models
from transformers import PreTrainedModel, PretrainedConfig


class ImageCaptioningConfig(PretrainedConfig):
    model_type = "image-captioning"

    def __init__(self, feature_dim=1024, embedding_dim=256, hidden_dim=512,
                 vocab_size=5000, num_layers=1, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.dropout = dropout


class LSTMDecoder(nn.Module):
    def __init__(self, feature_dim, embedding_dim, hidden_dim, vocab_size, num_layers=1, dropout=0.5):
        super().__init__()
        # Project image features into the word-embedding space so they can be
        # fed to the LSTM as the first step of the sequence.
        self.feature_project = nn.Linear(feature_dim, embedding_dim)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Note: with num_layers=1 the LSTM's dropout argument has no effect
        # (PyTorch warns about this); regularization comes from self.dropout
        # applied to the LSTM outputs below. bidirectional=True also means each
        # output position sees later caption tokens, so this decoder fits
        # teacher-forced training rather than strict left-to-right generation.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        # The two LSTM directions are concatenated, hence hidden_dim * 2 inputs.
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, image_features, captions):
        # (batch, feature_dim) -> (batch, 1, embedding_dim): the projected image
        # acts as the first "token" of the input sequence.
        image_features = self.feature_project(image_features).unsqueeze(1)
        # (batch, caption_len) -> (batch, caption_len, embedding_dim)
        embeddings = self.embedding(captions)
        embeddings = torch.cat((image_features, embeddings), dim=1)
        lstm_out, _ = self.lstm(embeddings)
        lstm_out = self.dropout(lstm_out)
        # (batch, caption_len + 1, vocab_size)
        outputs = self.fc(lstm_out)
        return outputs
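
# Illustrative shape check for the decoder in isolation (a sketch, not part of
# the original module). The batch size (4) and caption length (20) below are
# arbitrary; the dimensions match the config defaults:
#
#     decoder = LSTMDecoder(feature_dim=1024, embedding_dim=256,
#                           hidden_dim=512, vocab_size=5000)
#     feats = torch.randn(4, 1024)            # (batch, feature_dim)
#     caps = torch.randint(0, 5000, (4, 20))  # (batch, caption_len)
#     out = decoder(feats, caps)              # -> (4, 21, 5000)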


class ImageCaptioningModel(PreTrainedModel):
    config_class = ImageCaptioningConfig

    def __init__(self, config: ImageCaptioningConfig):
        super().__init__(config)

        # Pretrained DenseNet-121 encoder with the classifier head replaced by
        # an identity, so it outputs 1024-dimensional pooled features (matching
        # the default config.feature_dim). Newer torchvision releases prefer
        # weights=models.DenseNet121_Weights.DEFAULT over pretrained=True.
        self.encoder = models.densenet121(pretrained=True)
        self.encoder.classifier = nn.Identity()
        # Freeze the encoder: forward() already runs it under torch.no_grad(),
        # and disabling gradients keeps the optimizer from touching it.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.eval()

        self.decoder = LSTMDecoder(
            feature_dim=config.feature_dim,
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            vocab_size=config.vocab_size,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )

    def forward(self, image, caption):
        # Encode images without tracking gradients; only the decoder is trained.
        with torch.no_grad():
            image_features = self.encoder(image)
        output = self.decoder(image_features, caption)
        return output
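

# Minimal end-to-end sketch (assumed usage, not part of the original module):
# build a model from the default config and run one teacher-forced forward pass
# on random data. The 224x224 image size and caption length of 20 are arbitrary
# choices for illustration. Note that constructing the model downloads the
# DenseNet-121 weights on first run.
if __name__ == "__main__":
    config = ImageCaptioningConfig()
    model = ImageCaptioningModel(config)

    images = torch.randn(2, 3, 224, 224)                     # (batch, C, H, W)
    captions = torch.randint(0, config.vocab_size, (2, 20))  # (batch, caption_len)

    logits = model(images, captions)
    print(logits.shape)  # expected: torch.Size([2, 21, 5000])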