# Hosting-page header (not part of the module): uploaded by khalednabawi11,
# commit "Upload modeling.py" (a572204, verified).
import torch
import torch.nn as nn
from torchvision import models
from transformers import PreTrainedModel, PretrainedConfig
class ImageCaptioningConfig(PretrainedConfig):
    """Hyper-parameter container for ``ImageCaptioningModel``.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any extra keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        feature_dim: Size of the image feature vector fed to the decoder.
        embedding_dim: Size of the token embeddings (and of the projected
            image feature).
        hidden_dim: Hidden size of the decoder LSTM.
        vocab_size: Number of tokens in the caption vocabulary.
        num_layers: Number of stacked LSTM layers in the decoder.
        dropout: Dropout probability handed to the decoder.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "image-captioning"

    def __init__(
        self,
        feature_dim: int = 1024,
        embedding_dim: int = 256,
        hidden_dim: int = 512,
        vocab_size: int = 5000,
        num_layers: int = 1,
        dropout: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Dimensions of the tensors flowing through the decoder.
        self.feature_dim = feature_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Vocabulary size, LSTM depth and regularisation strength.
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.dropout = dropout
class LSTMDecoder(nn.Module):
    """LSTM caption decoder conditioned on a single image feature vector.

    The image feature is projected to the embedding size and prepended to the
    embedded caption tokens as the first element of the sequence. The combined
    sequence is run through a bidirectional LSTM, followed by dropout and a
    linear layer producing one score per vocabulary entry at every position.

    NOTE(review): because the LSTM is bidirectional, the output at each
    position also depends on *later* caption tokens. Confirm this is intended
    for caption generation; it is left as-is here because changing it would
    alter the parameter shapes of existing checkpoints.

    Args:
        feature_dim: Size of the incoming image feature vector.
        embedding_dim: Size of token embeddings and of the projected feature.
        hidden_dim: Hidden size of the LSTM (per direction).
        vocab_size: Number of tokens in the vocabulary.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the LSTM output and, when
            ``num_layers > 1``, between stacked LSTM layers.
    """

    def __init__(self, feature_dim, embedding_dim, hidden_dim, vocab_size, num_layers=1, dropout=0.5):
        super().__init__()
        self.feature_project = nn.Linear(feature_dim, embedding_dim)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # nn.LSTM only applies its own dropout *between* stacked layers. With a
        # single layer the value has no effect and PyTorch emits a UserWarning
        # on every construction, so pass 0 in that case.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )

        self.dropout = nn.Dropout(dropout)
        # Forward and backward hidden states are concatenated, hence "* 2".
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, image_features, captions):
        """Score every vocabulary token at every sequence position.

        Args:
            image_features: Image features, expected as
                ``(batch, feature_dim)``.
            captions: Token ids, expected as ``(batch, seq_len)``.

        Returns:
            Tensor of vocabulary scores. With the expected input shapes this
            is ``(batch, seq_len + 1, vocab_size)``; the extra leading
            position corresponds to the prepended image feature.
        """
        # (batch, feature_dim) -> (batch, 1, embedding_dim)
        image_token = self.feature_project(image_features).unsqueeze(1)
        token_embeddings = self.embedding(captions)

        # The image acts as the first "token" of the sequence.
        sequence = torch.cat((image_token, token_embeddings), dim=1)

        lstm_out, _ = self.lstm(sequence)
        lstm_out = self.dropout(lstm_out)
        return self.fc(lstm_out)
class ImageCaptioningModel(PreTrainedModel):
    """Image captioning model: frozen DenseNet121 encoder + LSTM decoder.

    The encoder is a torchvision DenseNet121 whose classifier is replaced by
    ``nn.Identity`` so that it returns the pooled feature vector. It is never
    trained: its parameters do not require gradients, it is run under
    ``torch.no_grad()``, and it is kept in eval mode even when the surrounding
    model is switched to training mode.

    Args:
        config: Hyper-parameters for the decoder. ``config.feature_dim`` must
            match the size of the feature vector produced by the encoder.
    """

    config_class = ImageCaptioningConfig

    def __init__(self, config: ImageCaptioningConfig):
        super().__init__(config)

        # Encoder: DenseNet121 without its classifier. This loads torchvision's
        # stock pretrained weights; CheXNet-specific weights, if intended,
        # presumably come from a checkpoint loaded on top (TODO confirm).
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=`; kept here for compatibility with older versions.
        self.encoder = models.densenet121(pretrained=True)
        self.encoder.classifier = nn.Identity()

        # Freeze the encoder: no gradients, and inference behaviour for
        # BatchNorm/dropout layers.
        self.encoder.requires_grad_(False)
        self.encoder.eval()

        # Decoder: LSTM
        self.decoder = LSTMDecoder(
            feature_dim=config.feature_dim,
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            vocab_size=config.vocab_size,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )

    def train(self, mode: bool = True):
        """Set training mode on the model while keeping the encoder in eval mode.

        ``nn.Module.train`` propagates the mode to every sub-module, which
        would put the encoder's BatchNorm layers back into training mode and
        let their running statistics change during forward passes (this
        happens even under ``torch.no_grad()``). Re-applying ``eval()`` to the
        encoder keeps it truly frozen.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, image, caption):
        """Encode ``image`` and score caption tokens with the decoder.

        Args:
            image: Batch of images accepted by the DenseNet121 encoder.
            caption: Batch of caption token ids passed to the decoder.

        Returns:
            The decoder output (per-position vocabulary scores).
        """
        # The encoder is frozen, so skip building an autograd graph for it.
        with torch.no_grad():
            image_features = self.encoder(image)
        return self.decoder(image_features, caption)