HybridTransformer-MFIF: Focal Transformer & CrossViT Hybrid for Multi-Focus Image Fusion

A state-of-the-art PyTorch implementation combining Focal Transformer and CrossViT architectures for multi-focus image fusion (MFIF). This hybrid model merges images captured with different focal planes into a single, all-in-focus output.

🔗 Project Resources

| Platform | Description | Link |
| --- | --- | --- |
| 🚀 Interactive Demo | Try the model online with your own images | Launch Demo |
| 🤗 Model Repository | Download pre-trained weights and config | This Repository |
| 📊 Training Tutorial | Complete pipeline with GPU acceleration | Kaggle Notebook |
| 📝 Source Code | Full implementation and documentation | GitHub Repository |
| 📦 Training Dataset | Lytro Multi-Focus dataset | Kaggle Dataset |

Model Details

Model Description

HybridTransformer-MFIF is a novel deep learning architecture that addresses the multi-focus image fusion task by combining two powerful transformer-based approaches:

  • 🎯 Focal Transformer: Provides adaptive spatial attention with multi-scale focal windows for enhanced feature extraction
  • 🔄 CrossViT: Enables cross-attention between near-focus and far-focus images for optimal information fusion
  • ⚡ Hybrid Integration: Sequential processing pipeline optimized specifically for image fusion tasks

The model takes two input images of the same scene with different focal planes and produces a single output image that preserves the best-focused regions from both inputs.

  • Model type: Vision Transformer (Hybrid Architecture)
  • Framework: PyTorch
  • License: MIT
  • Repository: GitHub

Uses

Direct Use

The model is designed for multi-focus image fusion applications:

import torch
from transformers import pipeline

# Load the model
fusion_pipeline = pipeline(
    "image-to-image",
    model="divitmittal/HybridTransformer-MFIF",
    device=0 if torch.cuda.is_available() else -1
)

# Fuse two images with different focus regions
result = fusion_pipeline({
    "near_focus": "path/to/near_focus_image.jpg",
    "far_focus": "path/to/far_focus_image.jpg"
})

Training Details

Training Data

The model was trained on the Lytro Multi-Focus Dataset:

  • Dataset: 20 image pairs (near-focus + far-focus) from Lytro camera
  • Resolution: 520×520 pixels, resized to 224×224 for training
  • Format: RGB color images in JPEG format
  • Augmentation: Random horizontal flip, rotation (±10°), color jittering (a paired-augmentation sketch follows this list)
  • Split: 80% training, 20% validation (using Triple Series for validation)
  • Normalization: ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
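
Because both focal slices of a pair must stay pixel-aligned, random augmentations have to be applied with shared parameters. The following is a minimal sketch of such a paired pipeline with torchvision; the exact augmentation policy and parameter values of the released training code are assumptions here (color jittering is omitted for brevity).

import random
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF

IMAGENET_NORM = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

def paired_train_transform(near: Image.Image, far: Image.Image):
    """Apply identical random augmentation to both focal slices of a pair."""
    near, far = TF.resize(near, [224, 224]), TF.resize(far, [224, 224])

    # Shared horizontal flip keeps the two images aligned.
    if random.random() < 0.5:
        near, far = TF.hflip(near), TF.hflip(far)

    # Shared rotation in the ±10° range listed above.
    angle = random.uniform(-10.0, 10.0)
    near, far = TF.rotate(near, angle), TF.rotate(far, angle)

    # Convert to tensors and normalize with ImageNet statistics.
    return IMAGENET_NORM(TF.to_tensor(near)), IMAGENET_NORM(TF.to_tensor(far))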

Training Procedure

Training Hyperparameters

  • Optimizer: AdamW
  • Learning Rate: 1e-4 with cosine annealing
  • Batch Size: 8 (adjustable based on available memory)
  • Epochs: 50 with early stopping (patience=15)
  • Weight Decay: 1e-4
  • Gradient Clipping: L2 norm clipping at 1.0
  • Mixed Precision: Enabled (AMP) for faster training (see the training-step sketch after this list)
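
The hyperparameters above translate into a fairly standard PyTorch loop. Below is a minimal sketch assuming model, train_loader, and fusion_criterion are supplied by the caller; early stopping is omitted for brevity, and this is not the released training script.

import torch
from torch.cuda.amp import GradScaler, autocast

def train(model, train_loader, fusion_criterion, epochs=50, device="cuda"):
    """AdamW + cosine annealing + AMP + L2 gradient clipping, per the list above."""
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler(enabled=device.startswith("cuda"))

    for _ in range(epochs):
        model.train()
        for near, far, target in train_loader:
            near, far, target = near.to(device), far.to(device), target.to(device)
            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=device.startswith("cuda")):
                fused = model(near, far)
                loss = fusion_criterion(fused, target, near, far)
            scaler.scale(loss).backward()
            # Unscale before clipping so the 1.0 threshold applies to real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()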

Model Architecture

  • Input Size: 224×224×3
  • Patch Size: 16×16
  • Embedding Dimension: 768
  • CrossViT Blocks: 4 layers
  • Focal Transformer Blocks: 6 layers
  • Attention Heads: 12
  • Focal Window Size: 9×9
  • Focal Levels: 3
  • Total Parameters: ~73M (a configuration sketch follows this list)
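
For reference, the numbers above can be collected into a single configuration object. The field names below are illustrative assumptions, not the released config schema.

from dataclasses import dataclass

@dataclass
class HybridMFIFConfig:
    img_size: int = 224          # input resolution (224×224×3)
    patch_size: int = 16         # 16×16 patches -> 14×14 token grid
    embed_dim: int = 768
    cross_vit_blocks: int = 4
    focal_blocks: int = 6
    num_heads: int = 12
    focal_window: int = 9
    focal_levels: int = 3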

Loss Function

Custom multi-component loss combining the terms below (a weighted-sum sketch follows the list):

  • L1 Loss (α=1.0): Pixel-wise reconstruction
  • SSIM Loss (β=0.5): Structural similarity preservation
  • Perceptual Loss (γ=0.3): VGG-based feature matching
  • Gradient Loss (δ=0.2): Edge preservation
  • Focus Map Loss (ε=0.1): Focus quality enhancement
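
A sketch of how the weighted sum could be assembled is shown below. The L1 and gradient terms are written out; the SSIM, perceptual (VGG), and focus-map terms are passed in as callables because their exact definitions in the released code are not reproduced here.

import torch
import torch.nn.functional as F

def gradient_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Edge preservation: L1 distance between horizontal/vertical image gradients."""
    dpx, dpy = pred[..., :, 1:] - pred[..., :, :-1], pred[..., 1:, :] - pred[..., :-1, :]
    dtx, dty = target[..., :, 1:] - target[..., :, :-1], target[..., 1:, :] - target[..., :-1, :]
    return F.l1_loss(dpx, dtx) + F.l1_loss(dpy, dty)

def fusion_loss(fused, target, near, far, ssim_term, perceptual_term, focus_term):
    """Weighted sum of the five components listed above."""
    return (1.0 * F.l1_loss(fused, target)           # α: pixel-wise reconstruction
            + 0.5 * ssim_term(fused, target)         # β: e.g. 1 - SSIM
            + 0.3 * perceptual_term(fused, target)   # γ: VGG feature matching
            + 0.2 * gradient_loss(fused, target)     # δ: edge preservation
            + 0.1 * focus_term(fused, near, far))    # ε: focus-quality term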

Evaluation

Testing Data, Factors & Metrics

Testing Data

  • Primary: Lytro Multi-Focus Dataset (Triple Series, 4 image sets)
  • Secondary: Standard MFIF benchmarks for comparison
  • Evaluation Protocol: Hold-out test set with no overlap with training data

Evaluation Metrics

The model is evaluated using standard fusion-quality metrics (a PSNR sketch follows the table):

| Metric | Description | Range | Higher is Better |
| --- | --- | --- | --- |
| PSNR | Peak Signal-to-Noise Ratio | 0-∞ dB | ✓ |
| SSIM | Structural Similarity Index | 0-1 | ✓ |
| QABF | Quality Assessment Based on Features | 0-1 | ✓ |
| VIF | Visual Information Fidelity | 0-1 | ✓ |
| MI | Mutual Information | 0-∞ | ✓ |
| SF | Spatial Frequency | 0-∞ | ✓ |
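
As an example of the simplest metric above, PSNR can be computed directly in PyTorch; the fusion-specific measures (QABF, MI, SF, VIF) are usually taken from dedicated MFIF evaluation toolkits. A minimal sketch, assuming images scaled to [0, 1]:

import torch

def psnr(fused: torch.Tensor, reference: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """Peak Signal-to-Noise Ratio in dB between a fused image and a reference."""
    mse = torch.mean((fused - reference) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)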

Results

Quantitative Performance

| Metric | Value | Unit | Benchmark Comparison |
| --- | --- | --- | --- |
| PSNR | 28.5 | dB | State-of-the-art |
| SSIM | 0.92 | index | Excellent |
| QABF | 0.85 | index | High quality |
| VIF | 0.78 | index | Very good |
| SF | 12.3 | score | Superior |

Computational Performance

  • Inference Time: ~150ms per image pair on GPU (see the timing sketch after this list)
  • Memory Usage: ~4GB VRAM for 224×224 images
  • Model Size: 294MB (73M parameters)
  • Supported Hardware: CUDA-enabled GPUs, CPU fallback available
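
The latency figure above can be reproduced roughly as follows; correct GPU timing needs a warm-up pass and explicit synchronization. Here model, near_t, and far_t (preprocessed input tensors) are assumed to be provided by the caller.

import time
import torch

def time_fusion(model, near_t, far_t, device="cuda"):
    """Measure a single fusion forward pass in milliseconds."""
    model = model.to(device).eval()
    near_t, far_t = near_t.to(device), far_t.to(device)
    with torch.no_grad():
        _ = model(near_t, far_t)                 # warm-up pass
        if device.startswith("cuda"):
            torch.cuda.synchronize()
        start = time.perf_counter()
        _ = model(near_t, far_t)
        if device.startswith("cuda"):
            torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000.0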

Technical Specifications

Model Architecture

The FocalCrossViTHybrid architecture consists of:

1. Patch Embedding Layer

  • Converts input images (224×224×3) into patch tokens (14×14×768)
  • Shared embedding for both near-focus and far-focus inputs
  • Learnable positional encoding added to patches (see the sketch after this list)
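
A minimal ViT-style patch embedding matching these shapes is sketched below; it illustrates the standard mechanism rather than the released module. The same instance would be applied to both the near-focus and far-focus input.

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """224×224×3 image -> 196 tokens of dimension 768 with learnable positions."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2            # 14 * 14 = 196
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):                    # x: (B, 3, 224, 224)
        x = self.proj(x)                     # (B, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)     # (B, 196, 768)
        return x + self.pos_embed            # add learnable positional encoding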

2. CrossViT Processing (4 blocks)

  • Cross-Attention Mechanism: Enables information exchange between near/far features
  • Multi-Head Attention: 12 attention heads for diverse feature interactions
  • MLP Layers: Feed-forward networks with GELU activation
  • Residual Connections: Skip connections for gradient flow (a simplified block sketch follows this list)
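
The block below is a simplified illustration of the cross-attention exchange described above: tokens from one focal stream query the tokens of the other, followed by a GELU MLP, with residual connections around both. It is an assumption about the structure, not the released CrossViT implementation.

import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """One direction of near/far token exchange (illustrative)."""
    def __init__(self, dim=768, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, q_tokens, kv_tokens):
        # Queries from one focal slice, keys/values from the other.
        attn_out, _ = self.attn(self.norm_q(q_tokens),
                                self.norm_kv(kv_tokens),
                                self.norm_kv(kv_tokens))
        x = q_tokens + attn_out                 # residual over cross-attention
        return x + self.mlp(self.norm_mlp(x))   # residual over the MLP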

3. Focal Transformer Processing (6 blocks)

  • Focal Modulation: Multi-scale spatial attention with learnable focal windows
  • Hierarchical Processing: Progressive feature refinement
  • Adaptive Focus: Dynamic attention based on spatial content
  • Window Sizes: 9×9 base window with 3 focal levels (a simplified sketch follows this list)
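
A heavily simplified illustration of multi-scale focal aggregation is given below: the feature map is summarized with depthwise convolutions at growing receptive fields (three levels starting from the 9×9 base window) and recombined through learned per-level gates. The level spacing and gating scheme are assumptions for intuition only, not the released focal-modulation module.

import torch
import torch.nn as nn

class SimpleFocalAggregation(nn.Module):
    """Multi-level spatial context mixing over a (B, C, H, W) feature map."""
    def __init__(self, dim=768, focal_window=9, focal_levels=3):
        super().__init__()
        self.levels = nn.ModuleList([
            nn.Conv2d(dim, dim,
                      kernel_size=focal_window + 2 * lvl,
                      padding=(focal_window + 2 * lvl) // 2,
                      groups=dim)                       # depthwise, growing window
            for lvl in range(focal_levels)
        ])
        self.gates = nn.Conv2d(dim, focal_levels, kernel_size=1)

    def forward(self, x):
        gates = torch.softmax(self.gates(x), dim=1)     # per-pixel weights over levels
        out = torch.zeros_like(x)
        for lvl, conv in enumerate(self.levels):
            out = out + gates[:, lvl:lvl + 1] * conv(x)
        return out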

4. Fusion and Decoder

  • Feature Fusion: Learned combination of processed features
  • Upsampling Decoder: Series of transposed convolutions
  • Output Generation: Sigmoid activation for final image output (a decoder sketch follows this list)
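
A plausible decoder shape is sketched below: the 14×14 token grid is reshaped into a feature map and upsampled 16× with transposed convolutions, ending in a sigmoid so the output lands in [0, 1]. The intermediate channel widths are assumptions, not the released layer sizes.

import torch
import torch.nn as nn

class UpsampleDecoder(nn.Module):
    """(B, 196, 768) fused tokens -> (B, 3, 224, 224) image."""
    def __init__(self, embed_dim=768):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, 4, stride=2, padding=1),  # 14 -> 28
            nn.GELU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),        # 28 -> 56
            nn.GELU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),         # 56 -> 112
            nn.GELU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),           # 112 -> 224
            nn.Sigmoid(),
        )

    def forward(self, tokens):
        b, n, c = tokens.shape
        h = w = int(n ** 0.5)                                # 14 × 14 grid
        x = tokens.transpose(1, 2).reshape(b, c, h, w)
        return self.net(x)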

Software Requirements

  • Python: ≥3.8
  • PyTorch: ≥2.0.0
  • torchvision: ≥0.15.0
  • PIL/Pillow: For image processing
  • NumPy: For numerical operations

Hardware Requirements

  • Minimum: 8GB RAM, CPU inference supported
  • Recommended: 16GB RAM, NVIDIA GPU with 4GB+ VRAM
  • Optimal: NVIDIA RTX 3080/4080 or similar for fast inference

How to Use

Quick Start

import torch
from PIL import Image
from transformers import pipeline

# Initialize the fusion pipeline
fusion_model = pipeline(
    "image-to-image",
    model="divitmittal/HybridTransformer-MFIF",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)

# Load your images
near_focus_img = Image.open("near_focus.jpg")
far_focus_img = Image.open("far_focus.jpg")

# Perform fusion
fused_result = fusion_model({
    "near_focus": near_focus_img,
    "far_focus": far_focus_img
})

# Save the result
fused_result.save("fused_output.jpg")

Advanced Usage

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoConfig

# Load model configuration and weights
config = AutoConfig.from_pretrained("divitmittal/HybridTransformer-MFIF")
model = AutoModel.from_pretrained("divitmittal/HybridTransformer-MFIF")
model.eval()

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load and preprocess the two focal slices
near_focus_img = Image.open("near_focus.jpg").convert("RGB")
far_focus_img = Image.open("far_focus.jpg").convert("RGB")
near_tensor = transform(near_focus_img).unsqueeze(0)
far_tensor = transform(far_focus_img).unsqueeze(0)

# Inference
with torch.no_grad():
    fused_tensor = model(near_tensor, far_tensor)

# Post-process output (sigmoid output lies in [0, 1]; clamp for safety)
fused_image = transforms.ToPILImage()(fused_tensor.squeeze(0).clamp(0, 1))

Limitations and Bias

Known Limitations

  • Input Constraints: Requires exactly two input images with different focus regions
  • Resolution: Optimized for 224×224 input; larger images may need preprocessing
  • Scene Types: Best performance on natural scenes; may struggle with highly synthetic content
  • Computational Cost: Requires significant GPU memory for optimal performance

Potential Biases

  • Dataset Bias: Trained primarily on Lytro camera data; may not generalize perfectly to all camera types
  • Content Bias: Performance may vary based on scene complexity and focus distribution
  • Color Space: Optimized for RGB color images; grayscale performance not extensively tested

Ethical Considerations

  • Intended Use: Research and legitimate photography applications
  • Misuse Prevention: Should not be used to create misleading or deceptive images
  • Privacy: Users should ensure they have rights to process uploaded images
  • Transparency: Model limitations should be communicated when deployed in applications

If you find this model useful, please consider ❤️ liking the repository!
