Vision-GPT: Image Captioning Model
A multimodal model that combines a Vision Transformer (ViT-B/16) with GPT-2 for image captioning, trained on the Flickr8K dataset.
Model Description
This model generates natural language captions for images by:
- Encoding images using a pre-trained ViT-B/16 vision encoder
- Projecting visual features into GPT-2's embedding space
- Generating captions autoregressively with GPT-2
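The list above maps onto a small set of components. Below is a minimal, illustrative PyTorch sketch of such an architecture; the class name (VisionGPT), the specific pretrained checkpoints ("google/vit-base-patch16-224-in21k", "gpt2"), and the exact wiring are assumptions, since the repository ships only weights and your class must match the original training code.

import torch
import torch.nn as nn
from transformers import ViTModel, GPT2LMHeadModel

class VisionGPT(nn.Module):  # hypothetical class name, for illustration only
    def __init__(self):
        super().__init__()
        # Frozen ViT-B/16 image encoder (checkpoint name is an assumption)
        self.encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # Frozen GPT-2 language model
        self.decoder = GPT2LMHeadModel.from_pretrained("gpt2")
        for p in self.encoder.parameters():
            p.requires_grad = False
        for p in self.decoder.parameters():
            p.requires_grad = False
        # Trainable projection from the ViT hidden size into GPT-2's embedding space
        self.projection = nn.Linear(self.encoder.config.hidden_size,
                                    self.decoder.config.n_embd)

    def forward(self, pixel_values, input_ids):
        # Encode the image and project the patch features
        visual = self.projection(self.encoder(pixel_values=pixel_values).last_hidden_state)
        # Prepend the projected visual tokens to the caption token embeddings
        text = self.decoder.transformer.wte(input_ids)
        return self.decoder(inputs_embeds=torch.cat([visual, text], dim=1))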
Training Details
- Dataset: Flickr8K (all splits: train, validation, test)
- Vision Encoder: ViT-B/16 (frozen)
- Language Model: GPT-2 (frozen backbone, trainable projection)
- Training: Only the vision-to-text projection layer is trained
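As noted above, only the projection layer receives gradients. A minimal sketch of that setup, assuming the VisionGPT class sketched in the Model Description (the learning rate is illustrative, not the value used for this checkpoint):

model = VisionGPT()
# Only the projection layer's parameters are passed to the optimizer
optimizer = torch.optim.AdamW(model.projection.parameters(), lr=1e-4)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}")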
Model Versions
FP32 (Full Precision)
- Size: ~0.83 GB
- Precision: 32-bit floating point
- Use case: Maximum accuracy, research
FP16 (Half Precision)
- Size: ~0.42 GB (about half of the FP32 checkpoint)
- Precision: 16-bit floating point
- Use case: Faster inference, reduced memory (~50% smaller)
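To confirm which variant you downloaded, you can check the checkpoint size on disk. The paths below assume the repository layout used in the loading examples that follow:

import os
for path in ["model_fp32/model_checkpoint.pth", "model_fp16/model_checkpoint.pth"]:
    print(f"{path}: {os.path.getsize(path) / 1e9:.2f} GB")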
Usage
Installation
pip install torch torchvision transformers pillow huggingface_hub
Loading the Model (FP32)
import torch
from transformers import GPT2Tokenizer
from PIL import Image
from torchvision import transforms
# Load checkpoint
checkpoint = torch.load("model_fp32/model_checkpoint.pth", map_location="cpu")
# Load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("model_fp32/tokenizer")
# Define your model architecture (it is not shipped with the checkpoint and must match the original training code; see the sketch in Model Description)
# model = YourVisionGPTModel(config)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.eval()
print("Model loaded successfully!")
Loading the Model (FP16)
# Load FP16 checkpoint
checkpoint = torch.load("model_fp16/model_checkpoint.pth", map_location="cpu")
# Load model and convert to FP16
# model = YourVisionGPTModel(config)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.half() # Convert to FP16
# model.eval()
# For inference with FP16, also convert input images to FP16
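A minimal sketch of the conversion mentioned in the comment above; it assumes model has already been loaded and converted with model.half(), and that image_tensor comes from the preprocessing section below.

image_tensor = image_tensor.half()  # match the FP16 model's dtype
with torch.no_grad():
    generated_ids = model.generate(image_tensor, max_length=50, num_beams=5)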
Image Preprocessing
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])
# Load and preprocess image
image = Image.open("your_image.jpg")
image_tensor = image_transform(image).unsqueeze(0) # Add batch dimension
Generate Caption
# Generate caption
with torch.no_grad():
    # Forward pass
    generated_ids = model.generate(
        image_tensor,
        max_length=50,
        num_beams=5,
        temperature=0.7
    )

# Decode caption
caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated caption: {caption}")
Model Architecture
┌─────────────────┐
│   Input Image   │
│    (224x224)    │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│    ViT-B/16     │
│    (frozen)     │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│   Projection    │
│   (trainable)   │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│      GPT-2      │
│    (frozen)     │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│ Caption Output  │
└─────────────────┘
Limitations
- Trained only on Flickr8K (limited domain)
- English captions only
- Input images are resized to 224x224 (the fixed input resolution of ViT-B/16)
- May generate generic captions for out-of-domain images
Citation
If you use this model, please cite:
@misc{vision-gpt-flickr8k,
  author = {gurumurthy3},
  title = {Vision-GPT: Image Captioning with ViT and GPT-2},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/gurumurthy3/vision-gpt-flickr8k}}
}
License
MIT License
Acknowledgments
- Vision Transformer (ViT): Dosovitskiy et al.
- GPT-2: OpenAI
- Flickr8K Dataset: Hodosh et al.