MIRNet low-light image enhancement
MIRNet-based low-light image enhancer specialized in restoring dark images from events (concerts, parties, clubs...).
Project source code and further documentation
Documentation on pre-training, fine-tuning, model architecture, and usage, along with all source code used for building the model and running inference, can be found in the GitHub repository of the project.
This page currently stores the PyTorch model weights and model definition. A Hugging Face pipeline will be implemented in the future.
Using the model
To use the model, you need the model folder in your project folder. You can download it from this repository or from GitHub.
Then, the following code downloads the model weights from Hugging Face and loads them in PyTorch for downstream use:
import torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
from model.MIRNet.model import MIRNet
# Pick the best available device: CUDA GPU, then Apple MPS, then CPU
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="dblasko/mirnet-low-light-img-enhancement", filename="mirnet_finetuned.pth"
)
# Load the model
model = MIRNet().to(device)
model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])
# Use the model, for example for inference on an image
model.eval()
with torch.no_grad():
    img = Image.open("image_path.png").convert("RGB")
    img_tensor = T.Compose(
        [
            T.Resize(400),  # Adjust image resizing depending on hardware
            T.ToTensor(),
            T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ]
    )(img).unsqueeze(0)
    img_tensor = img_tensor.to(device)

    # Crop height and width down to the nearest multiple of 8
    if img_tensor.shape[2] % 8 != 0:
        img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
    if img_tensor.shape[3] % 8 != 0:
        img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]

    output = model(img_tensor)
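The two `if` checks before the forward pass trim the image height and width down to the nearest multiple of 8. This page does not state the reason; a plausible one (an assumption, not confirmed here) is that the network downsamples its input internally. The same arithmetic can be factored into a small helper. This is a minimal sketch, and `cropped_size` is a hypothetical name that is not part of the repository:

```python
def cropped_size(height, width, multiple=8):
    """Return the largest (height, width) that does not exceed the inputs
    and where both values are divisible by `multiple`."""
    return height - height % multiple, width - width % multiple


# Example: a 400 x 601 image loses one column, the height is unchanged
print(cropped_size(400, 601))  # (400, 600)
```

With this helper, the cropping step could be written as `h, w = cropped_size(*img_tensor.shape[2:])` followed by `img_tensor = img_tensor[:, :, :h, :w]`, which gives the same result as the two `if` blocks.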