HybridTransformer-MFIF: Focal Transformer & CrossViT Hybrid for Multi-Focus Image Fusion

A state-of-the-art PyTorch implementation combining Focal Transformer and CrossViT architectures for multi-focus image fusion (MFIF). This hybrid model merges images captured at different focal planes into a single, all-in-focus output.
Model Details
Model Description
HybridTransformer-MFIF is a novel deep learning architecture that addresses the multi-focus image fusion task by combining two powerful transformer-based approaches:
- Focal Transformer: Provides adaptive spatial attention with multi-scale focal windows for enhanced feature extraction
- CrossViT: Enables cross-attention between near-focus and far-focus images for optimal information fusion
- Hybrid Integration: Sequential processing pipeline optimized specifically for image fusion tasks
The model takes two input images of the same scene with different focal planes and produces a single output image that preserves the best-focused regions from both inputs.
- Model type: Vision Transformer (Hybrid Architecture)
- Framework: PyTorch
- License: MIT
- Repository: GitHub
Uses
Direct Use
The model is designed for multi-focus image fusion applications:
```python
import torch
from transformers import pipeline

# Load the model
fusion_pipeline = pipeline(
    "image-to-image",
    model="divitmittal/HybridTransformer-MFIF",
    device=0 if torch.cuda.is_available() else -1
)

# Fuse two images with different focus regions
result = fusion_pipeline({
    "near_focus": "path/to/near_focus_image.jpg",
    "far_focus": "path/to/far_focus_image.jpg"
})
```
Intended Use Cases
- Mobile Photography: Combine multiple shots with different focus points
- Scientific Imaging: Merge microscopy images with varying focal depths
- Landscape Photography: Create fully focused images from focus-bracketed shots
- Document Processing: Ensure all text regions are in perfect focus
- Creative Photography: Artistic control over focus blending and depth
Out-of-Scope Use
- Single image super-resolution or enhancement
- General image-to-image translation tasks
- Real-time video processing (model optimized for static images)
- Fusion of more than two input images simultaneously
Training Details
Training Data
The model was trained on the Lytro Multi-Focus Dataset:
- Dataset: 20 image pairs (near-focus + far-focus) from a Lytro camera
- Resolution: 520×520 pixels, resized to 224×224 for training
- Format: RGB color images in JPEG format
- Augmentation: Random horizontal flip, rotation (±10°), color jittering (see the transform sketch below)
- Split: 80% training, 20% validation (using the Triple Series for validation)
- Normalization: ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
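The augmentation and normalization above can be expressed as a torchvision pipeline. This is a sketch, not the released training code; the exact flip probability and jitter strengths are assumptions.

```python
from torchvision import transforms

# Assumed training-time preprocessing based on the description above. In
# practice the same random flip/rotation parameters must be applied to the
# near-focus image, the far-focus image, and the ground-truth target.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```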
Training Procedure
Training Hyperparameters
- Optimizer: AdamW
- Learning Rate: 1e-4 with cosine annealing
- Batch Size: 8 (adjustable based on available memory)
- Epochs: 50 with early stopping (patience=15)
- Weight Decay: 1e-4
- Gradient Clipping: L2 norm clipping at 1.0
- Mixed Precision: Enabled (AMP) for faster training
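A minimal sketch wiring these hyperparameters into a training step; `model`, `criterion`, and the data loader are placeholders, and the full loop (validation, early stopping) is not shown.

```python
import torch

def make_optim(model):
    # AdamW + cosine annealing + AMP scaler, per the hyperparameters above.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    scaler = torch.cuda.amp.GradScaler()
    return optimizer, scheduler, scaler

def train_one_epoch(model, criterion, loader, optimizer, scheduler, scaler):
    model.train()
    for near, far, target in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():          # mixed-precision forward pass
            loss = criterion(model(near, far), target)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)               # unscale so clipping sees true gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # L2 clip at 1.0
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()                             # anneal once per epoch
```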
Model Architecture
- Input Size: 224×224×3
- Patch Size: 16×16
- Embedding Dimension: 768
- CrossViT Blocks: 4 layers
- Focal Transformer Blocks: 6 layers
- Attention Heads: 12
- Focal Window Size: 9×9
- Focal Levels: 3
- Total Parameters: ~73M
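A hypothetical configuration object collecting the hyperparameters above; the field names are illustrative and do not necessarily match the model's actual config schema.

```python
from dataclasses import dataclass

@dataclass
class HybridMFIFConfig:
    img_size: int = 224        # input resolution
    in_channels: int = 3       # RGB
    patch_size: int = 16       # 16x16 patches -> 14x14 token grid
    embed_dim: int = 768
    cross_vit_blocks: int = 4
    focal_blocks: int = 6
    num_heads: int = 12
    focal_window: int = 9      # base focal window size
    focal_levels: int = 3
```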
Loss Function
Custom multi-component loss combining:
- L1 Loss (α=1.0): Pixel-wise reconstruction
- SSIM Loss (β=0.5): Structural similarity preservation
- Perceptual Loss (γ=0.3): VGG-based feature matching
- Gradient Loss (δ=0.2): Edge preservation
- Focus Map Loss (ε=0.1): Focus quality enhancement
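The weighted combination above can be summarized in a short sketch; the individual loss terms are assumed to be callables, and their implementations are not shown here.

```python
def total_loss(fused, target, l1, ssim, perceptual, gradient, focus_map):
    # Weighted sum of the five components with the published coefficients.
    return (1.0 * l1(fused, target)            # pixel-wise reconstruction
            + 0.5 * ssim(fused, target)        # structural similarity
            + 0.3 * perceptual(fused, target)  # VGG feature matching
            + 0.2 * gradient(fused, target)    # edge preservation
            + 0.1 * focus_map(fused, target))  # focus quality
```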
Evaluation
Testing Data, Factors & Metrics
Testing Data
- Primary: Lytro Multi-Focus Dataset (Triple Series, 4 image sets)
- Secondary: Standard MFIF benchmarks for comparison
- Evaluation Protocol: Hold-out test set with no overlap with training data
Evaluation Metrics
The model is evaluated using comprehensive fusion quality metrics:
Metric | Description | Range | Higher is Better |
---|---|---|---|
PSNR | Peak Signal-to-Noise Ratio | 0-∞ dB | ✓ |
SSIM | Structural Similarity Index | 0-1 | ✓ |
QABF | Quality Assessment Based on Features | 0-1 | ✓ |
VIF | Visual Information Fidelity | 0-1 | ✓ |
MI | Mutual Information | 0-∞ | ✓ |
SF | Spatial Frequency | 0-∞ | ✓ |
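As a reference, the first two metrics can be computed with scikit-image as sketched below; the fusion-specific metrics (QABF, VIF, MI, SF) require dedicated implementations and are not shown.

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def psnr_ssim(fused: np.ndarray, reference: np.ndarray):
    # Both images are expected as uint8 HxWxC arrays of the same shape.
    psnr = peak_signal_noise_ratio(reference, fused, data_range=255)
    ssim = structural_similarity(reference, fused, channel_axis=-1, data_range=255)
    return psnr, ssim
```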
Results
Quantitative Performance
Metric | Value | Unit | Benchmark Comparison |
---|---|---|---|
PSNR | 28.5 | dB | State-of-the-art |
SSIM | 0.92 | index | Excellent |
QABF | 0.85 | index | High quality |
VIF | 0.78 | index | Very good |
SF | 12.3 | score | Superior |
Computational Performance
- Inference Time: ~150ms per image pair (GPU)
- Memory Usage: ~4GB VRAM for 224×224 images
- Model Size: 294MB (73M parameters)
- Supported Hardware: CUDA-enabled GPUs, CPU fallback available
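A rough way to check the per-pair latency on your own hardware; `model` and the two input tensors are assumed to be prepared as in the usage examples below.

```python
import time
import torch

def time_fusion(model, near, far, warmup=5):
    # Times a single fused forward pass on GPU; requires CUDA.
    model = model.cuda().eval()
    near, far = near.cuda(), far.cuda()
    with torch.no_grad():
        for _ in range(warmup):          # warm-up iterations
            model(near, far)
        torch.cuda.synchronize()
        start = time.perf_counter()
        model(near, far)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000.0   # milliseconds per pair
```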
Technical Specifications
Model Architecture

The FocalCrossViTHybrid architecture consists of:
1. Patch Embedding Layer
- Converts input images (224×224×3) into patch tokens (14×14×768)
- Shared embedding for both near-focus and far-focus inputs
- Learnable positional encoding added to patches
2. CrossViT Processing (4 blocks)
- Cross-Attention Mechanism: Enables information exchange between near/far features
- Multi-Head Attention: 12 attention heads for diverse feature interactions
- MLP Layers: Feed-forward networks with GELU activation
- Residual Connections: Skip connections for gradient flow
3. Focal Transformer Processing (6 blocks)
- Focal Modulation: Multi-scale spatial attention with learnable focal windows
- Hierarchical Processing: Progressive feature refinement
- Adaptive Focus: Dynamic attention based on spatial content
- Window Sizes: 9×9 base window with 3 focal levels
4. Fusion and Decoder
- Feature Fusion: Learned combination of processed features
- Upsampling Decoder: Series of transposed convolutions
- Output Generation: Sigmoid activation for final image output
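For orientation, the tensor shapes through the four stages above are sketched below for a single image pair; the stage names in the comments are descriptive, not the actual class names.

```python
import torch

near = torch.randn(1, 3, 224, 224)   # near-focus input
far = torch.randn(1, 3, 224, 224)    # far-focus input

# 1. Patch embedding: 224 / 16 = 14 patches per side -> 196 tokens of dim 768
#    near_tokens, far_tokens: (1, 196, 768), plus learnable positional encoding
# 2. CrossViT blocks (x4): cross-attention exchanges information between the
#    two token streams; shapes stay (1, 196, 768)
# 3. Focal Transformer blocks (x6): focal modulation over the 14x14 token grid;
#    shapes stay (1, 196, 768)
# 4. Fusion + decoder: tokens reshaped to (1, 768, 14, 14), then transposed
#    convolutions upsample back to (1, 3, 224, 224) with a sigmoid output
```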
Software Requirements
- Python: ≥3.8
- PyTorch: ≥2.0.0
- torchvision: ≥0.15.0
- PIL/Pillow: For image processing
- NumPy: For numerical operations
Hardware Requirements
- Minimum: 8GB RAM, CPU inference supported
- Recommended: 16GB RAM, NVIDIA GPU with 4GB+ VRAM
- Optimal: NVIDIA RTX 3080/4080 or similar for fast inference
How to Use
Quick Start
```python
import torch
from PIL import Image
from transformers import pipeline

# Initialize the fusion pipeline
fusion_model = pipeline(
    "image-to-image",
    model="divitmittal/HybridTransformer-MFIF",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)

# Load your images
near_focus_img = Image.open("near_focus.jpg")
far_focus_img = Image.open("far_focus.jpg")

# Perform fusion
fused_result = fusion_model({
    "near_focus": near_focus_img,
    "far_focus": far_focus_img
})

# Save the result
fused_result.save("fused_output.jpg")
```
Advanced Usage
```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoConfig

# Load model configuration and weights
config = AutoConfig.from_pretrained("divitmittal/HybridTransformer-MFIF")
model = AutoModel.from_pretrained("divitmittal/HybridTransformer-MFIF")
model.eval()

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess the input pair
near_focus_img = Image.open("near_focus.jpg").convert("RGB")
far_focus_img = Image.open("far_focus.jpg").convert("RGB")
near_tensor = transform(near_focus_img).unsqueeze(0)
far_tensor = transform(far_focus_img).unsqueeze(0)

# Inference
with torch.no_grad():
    fused_tensor = model(near_tensor, far_tensor)

# Post-process output (the sigmoid head keeps values in [0, 1])
fused_image = transforms.ToPILImage()(fused_tensor.squeeze(0))
```
Limitations and Bias
Known Limitations
- Input Constraints: Requires exactly two input images with different focus regions
- Resolution: Optimized for 224×224 input; larger images may need preprocessing (see the sketch after this list)
- Scene Types: Best performance on natural scenes; may struggle with highly synthetic content
- Computational Cost: Requires significant GPU memory for optimal performance
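For inputs larger than 224×224, one simple (lossy) workaround, not part of the released code, is to downsample before fusion and upsample the result back to the original size.

```python
import torch.nn.functional as F

def fuse_large(model, near, far):
    # near, far: (1, 3, H, W) tensors; fuses at 224x224 and restores H x W.
    # Fine detail will be softer than native-resolution fusion.
    h, w = near.shape[-2:]
    near_s = F.interpolate(near, size=(224, 224), mode="bilinear", align_corners=False)
    far_s = F.interpolate(far, size=(224, 224), mode="bilinear", align_corners=False)
    fused_s = model(near_s, far_s)
    return F.interpolate(fused_s, size=(h, w), mode="bilinear", align_corners=False)
```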
Potential Biases
- Dataset Bias: Trained primarily on Lytro camera data; may not generalize perfectly to all camera types
- Content Bias: Performance may vary based on scene complexity and focus distribution
- Color Space: Optimized for RGB color images; grayscale performance not extensively tested
Ethical Considerations
- Intended Use: Research and legitimate photography applications
- Misuse Prevention: Should not be used to create misleading or deceptive images
- Privacy: Users should ensure they have rights to process uploaded images
- Transparency: Model limitations should be communicated when deployed in applications
Citation
If you use this model in your research, please cite:
```bibtex
@software{mittal2024hybridtransformer,
  title  = {HybridTransformer-MFIF: Focal Transformer and CrossViT Hybrid for Multi-Focus Image Fusion},
  author = {Mittal, Divit},
  year   = {2024},
  url    = {https://github.com/DivitMittal/HybridTransformer-MFIF},
  note   = {PyTorch implementation with pre-trained models available at HuggingFace Model Hub}
}
```
Project Resources
Platform | Description | Link |
---|---|---|
Interactive Demo | Try the model online with your own images | Launch Demo |
Model Repository | Download pre-trained weights and config | This Repository |
Training Tutorial | Complete pipeline with GPU acceleration | Kaggle Notebook |
Source Code | Full implementation and documentation | GitHub Repository |
Training Dataset | Lytro Multi-Focus dataset | Kaggle Dataset |
Built with ❤️ for the computer vision community
If you find this model useful, please consider starring the repository!