Vision-GPT: Image Captioning Model

A multimodal model combining a Vision Transformer (ViT-B/16) encoder with GPT-2 for image captioning, trained on the Flickr8K dataset.

Model Description

This model generates natural language captions for images by:

  1. Encoding images using a pre-trained ViT-B/16 vision encoder
  2. Projecting visual features into GPT-2's embedding space
  3. Generating captions autoregressively with GPT-2 (a minimal sketch follows)
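
A minimal sketch of how these three steps might fit together, assuming Hugging Face's ViTModel and GPT2LMHeadModel classes (the actual model class and checkpoint layout used here may differ):

import torch
import torch.nn as nn
from transformers import ViTModel, GPT2LMHeadModel

class VisionGPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        # Map ViT features (hidden_size=768) into GPT-2's embedding space (n_embd=768)
        self.projection = nn.Linear(self.vit.config.hidden_size,
                                    self.gpt2.config.n_embd)

    def forward(self, pixel_values, input_ids):
        # 1. Encode the image into a sequence of patch embeddings
        visual = self.vit(pixel_values=pixel_values).last_hidden_state
        # 2. Project visual features into GPT-2's embedding space
        prefix = self.projection(visual)
        # 3. Prepend the visual prefix to the caption token embeddings and decode
        token_embeds = self.gpt2.transformer.wte(input_ids)
        inputs_embeds = torch.cat([prefix, token_embeds], dim=1)
        return self.gpt2(inputs_embeds=inputs_embeds)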

Training Details

  • Dataset: Flickr8K (all splits: train, validation, test)
  • Vision Encoder: ViT-B/16 (frozen)
  • Language Model: GPT-2 (frozen backbone, trainable projection)
  • Training: Only the vision-to-text projection layer is trained (see the sketch below)
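
A sketch of that freezing scheme, assuming the hypothetical VisionGPT class above (the authors' actual training loop is not included in this card):

model = VisionGPT()

# Freeze both pre-trained backbones
for param in model.vit.parameters():
    param.requires_grad = False
for param in model.gpt2.parameters():
    param.requires_grad = False

# Optimize only the vision-to-text projection (learning rate is illustrative)
optimizer = torch.optim.AdamW(model.projection.parameters(), lr=1e-4)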

Model Versions

📦 FP32 (Full Precision)

  • Size: ~0.83 GB
  • Precision: 32-bit floating point
  • Use case: Maximum accuracy, research

📦 FP16 (Half Precision)

  • Size: ~0.42 GB (~50% smaller than FP32; see the estimate below)
  • Precision: 16-bit floating point
  • Use case: Faster inference, reduced memory
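
As a rough sanity check on those sizes: ViT-B/16 has ~86M parameters and GPT-2 ~124M, so ~210M in total; at 4 bytes per FP32 parameter that is ~0.84 GB, and about half that in FP16. You can verify on a loaded model:

num_params = sum(p.numel() for p in model.parameters())
print(f"FP32: {num_params * 4 / 1e9:.2f} GB, FP16: {num_params * 2 / 1e9:.2f} GB")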

Usage

Installation

pip install torch torchvision transformers pillow huggingface_hub

Loading the Model (FP32)

import torch
from transformers import GPT2Tokenizer
from PIL import Image
from torchvision import transforms

# Load checkpoint
checkpoint = torch.load("model_fp32/model_checkpoint.pth", map_location="cpu")

# Load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("model_fp32/tokenizer")

# Load your model architecture (you need to define this)
# model = YourVisionGPTModel(config)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.eval()

print("Checkpoint and tokenizer loaded successfully!")
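
If you use the hypothetical VisionGPT sketch from the Model Description, the commented lines above might become the following (whether the checkpoint keys line up depends on the authors' actual class definition):

model = VisionGPT()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()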

Loading the Model (FP16)

# Load FP16 checkpoint
checkpoint = torch.load("model_fp16/model_checkpoint.pth", map_location="cpu")

# Load model and convert to FP16
# model = YourVisionGPTModel(config)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.half()  # Convert to FP16
# model.eval()

# For inference with FP16, also convert input images to FP16
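
For example, assuming image_tensor is produced by the preprocessing step below:

image_tensor = image_tensor.half()  # match the model's FP16 weights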

Image Preprocessing

image_transform = transforms.Compose([
    transforms.Lambda(lambda x: x.convert('RGB')),  # ensure 3 channels (handles grayscale/RGBA)
    transforms.Resize((224, 224)),                  # ViT-B/16 input resolution
    transforms.ToTensor(),
    transforms.Normalize(                           # ImageNet normalization statistics
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Load and preprocess image
image = Image.open("your_image.jpg")
image_tensor = image_transform(image).unsqueeze(0)  # Add batch dimension

Generate Caption

# Generate caption
with torch.no_grad():
    generated_ids = model.generate(
        image_tensor,
        max_length=50,
        num_beams=5,      # beam search width
        temperature=0.7   # note: with HF-style generate, only used when do_sample=True
    )
    
    # Decode caption
    caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Generated caption: {caption}")

Model Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Input Image    β”‚
β”‚   (224x224)     β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚
         β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   ViT-B/16      β”‚
β”‚   (frozen)      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚
         β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Projection     β”‚
β”‚  (trainable)    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚
         β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚    GPT-2        β”‚
β”‚   (frozen)      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚
         β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Caption Output β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Limitations

  • Trained only on Flickr8K (limited domain)
  • English captions only
  • Input images are resized to a fixed 224x224 (aspect ratio is not preserved)
  • May generate generic captions for out-of-domain images

Citation

If you use this model, please cite:

@misc{vision-gpt-flickr8k,
  author = {gurumurthy3},
  title = {Vision-GPT: Image Captioning with ViT and GPT-2},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/gurumurthy3/vision-gpt-flickr8k}}
}

License

MIT License

Acknowledgments

  • Vision Transformer (ViT): Dosovitskiy et al.
  • GPT-2: OpenAI
  • Flickr8K Dataset: Hodosh et al.