import torch
import torch.nn as nn
from torchvision import models
from transformers import PreTrainedModel, PretrainedConfig


class ImageCaptioningConfig(PretrainedConfig):
    """Configuration for :class:`ImageCaptioningModel`.

    Args:
        feature_dim: Size of the image feature vector fed to the decoder.
            1024 matches the pooled output of torchvision's DenseNet121.
        embedding_dim: Size of the token embeddings; the image feature is
            projected to this size and used as the first LSTM input step.
        hidden_dim: Hidden size of each LSTM direction.
        vocab_size: Number of caption tokens (also the size of the logits).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the LSTM outputs, and between
            LSTM layers when ``num_layers > 1``.
        bidirectional: Whether the decoder LSTM is bidirectional. Defaults to
            ``True`` so existing configs/checkpoints load unchanged; see the
            note on :class:`LSTMDecoder` before relying on it.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "image-captioning"

    def __init__(
        self,
        feature_dim=1024,
        embedding_dim=256,
        hidden_dim=512,
        vocab_size=5000,
        num_layers=1,
        dropout=0.5,
        bidirectional=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional


class LSTMDecoder(nn.Module):
    """LSTM caption decoder conditioned on a single image feature vector.

    The projected image feature is prepended to the caption embeddings as an
    extra first time step, so for captions of length ``seq_len`` the output
    has ``seq_len + 1`` time steps.

    NOTE(review): with ``bidirectional=True`` the backward direction reads the
    tokens *after* position t when producing the logits for position t. Under
    teacher forcing the model can therefore see the token it is asked to
    predict, and the decoder cannot be run one token at a time at inference.
    Use ``bidirectional=False`` for autoregressive captioning (this changes
    the ``fc`` weight shape, so it needs retraining).
    """

    def __init__(self, feature_dim, embedding_dim, hidden_dim, vocab_size,
                 num_layers=1, dropout=0.5, bidirectional=True):
        super().__init__()
        self.feature_project = nn.Linear(feature_dim, embedding_dim)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM applies `dropout` only *between* stacked layers; with a single
        # layer it has no effect and PyTorch emits a UserWarning, so pass 0 then.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, vocab_size)

    def forward(self, image_features, captions):
        """Return per-step vocabulary logits.

        Args:
            image_features: ``(batch, feature_dim)`` float tensor.
            captions: ``(batch, seq_len)`` integer token ids.

        Returns:
            ``(batch, seq_len + 1, vocab_size)`` logits; step 0 corresponds to
            the image feature.
        """
        # (batch, feature_dim) -> (batch, 1, embedding_dim): one extra time step.
        image_step = self.feature_project(image_features).unsqueeze(1)
        embeddings = self.embedding(captions)
        embeddings = torch.cat((image_step, embeddings), dim=1)
        lstm_out, _ = self.lstm(embeddings)
        lstm_out = self.dropout(lstm_out)
        return self.fc(lstm_out)


def _load_densenet121():
    """Return torchvision's ImageNet-pretrained DenseNet121.

    Uses the ``weights=`` API when the installed torchvision provides it and
    falls back to the deprecated ``pretrained=True`` flag otherwise; both load
    the same ImageNet weights.
    """
    weights_enum = getattr(models, "DenseNet121_Weights", None)
    if weights_enum is not None:
        return models.densenet121(weights=weights_enum.IMAGENET1K_V1)
    return models.densenet121(pretrained=True)


class ImageCaptioningModel(PreTrainedModel):
    """Frozen DenseNet121 image encoder followed by a trainable LSTM decoder."""

    config_class = ImageCaptioningConfig

    def __init__(self, config: ImageCaptioningConfig):
        super().__init__(config)

        # Encoder: DenseNet121 (the CheXNet architecture) with torchvision's
        # ImageNet weights -- CheXNet's chest X-ray weights are NOT loaded here.
        # The classifier is removed so the encoder returns pooled features.
        self.encoder = _load_densenet121()
        self.encoder.classifier = nn.Identity()
        # Freeze the encoder: no parameter updates, and BatchNorm kept in eval
        # mode (enforced again in `train` below, since model.train() would
        # otherwise switch it back).
        self.encoder.requires_grad_(False)
        self.encoder.eval()

        # Decoder: LSTM
        self.decoder = LSTMDecoder(
            feature_dim=config.feature_dim,
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            vocab_size=config.vocab_size,
            num_layers=config.num_layers,
            dropout=config.dropout,
            bidirectional=getattr(config, "bidirectional", True),
        )

    def train(self, mode: bool = True):
        """Set train/eval mode for the decoder while keeping the encoder in eval.

        ``torch.no_grad()`` in :meth:`forward` does not stop BatchNorm layers
        from updating their running statistics in training mode, so the frozen
        encoder must stay in eval mode regardless of ``mode``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, image, caption):
        """Encode ``image`` with the frozen encoder and decode ``caption``.

        Args:
            image: image batch for DenseNet121 -- presumably
                ``(batch, 3, H, W)`` normalized as for ImageNet; TODO confirm.
            caption: ``(batch, seq_len)`` integer token ids.

        Returns:
            ``(batch, seq_len + 1, vocab_size)`` logits from the decoder.
        """
        with torch.no_grad():
            image_features = self.encoder(image)
        return self.decoder(image_features, caption)