# Uploaded by TBurdairon with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 8d81eef (verified).
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import torch
from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
MultiUpscaler,
UpscalerCheckpoints,
)
from esrgan_model import UpscalerESRGAN
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint bundle for the ESRGAN-assisted upscaler.

    Inherits every field of ``UpscalerCheckpoints`` (the SD-1.5 MultiUpscaler
    checkpoints) and adds one more entry for the ESRGAN weights. All fields
    are keyword-only.
    """

    # Path to the ESRGAN checkpoint file.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """Multi-stage image enhancer built on top of ``MultiUpscaler``.

    1. Runs ESRGAN 4× super-resolution first (tiling to avoid VRAM overflow),
    2. Passes the up-scaled image to the Stable-Diffusion 1.5 MultiUpscaler
       for refinement.
    """

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Initialise the SD-1.5 upscaler, then load the ESRGAN model.

        Args:
            checkpoints: Checkpoint paths; ``checkpoints.esrgan`` is handed
                to ``UpscalerESRGAN``, the whole object to the parent class.
            device: Device forwarded to the parent constructor.
            dtype: Dtype forwarded to the parent constructor.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # device/dtype are read back from ``self`` rather than from the
        # arguments — presumably the parent constructor stores them there.
        self.esrgan = UpscalerESRGAN(
            checkpoints.esrgan, device=self.device, dtype=self.dtype
        )

    # NOTE(review): the original comment states this is called automatically
    # by HF when the model is moved to another device — not verifiable from
    # this file; confirm against the hosting runtime.
    def to(self, device: torch.device, dtype: torch.dtype) -> "ESRGANUpscaler":
        """Move both stages (ESRGAN and SD-1.5) to ``device`` / ``dtype``.

        Args:
            device: Target device for both models.
            dtype: Target dtype for both models.

        Returns:
            ``self``, so that ``upscaler = upscaler.to(...)`` keeps working
            instead of rebinding the name to ``None``.
        """
        self.esrgan.to(device=device, dtype=dtype)
        # ``self.sd.to`` is treated as returning the moved object, hence the
        # re-assignment.
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
        return self

    # ---- hook that runs *before* SD-1.5 up-scaling ----
    def pre_upscale(
        self,
        image: Image.Image,
        upscale_factor: float,
        **_: Any,
    ) -> Image.Image:
        """Apply ESRGAN to ``image``, then delegate to the parent hook.

        Args:
            image: Input image.
            upscale_factor: Total requested upscale factor.
            **_: Extra keyword arguments are accepted and ignored; they are
                not forwarded to the parent hook.

        Returns:
            The image returned by the parent ``pre_upscale``.
        """
        # ESRGAN first; the SD-1.5 stage then handles the residual upscale.
        image = self.esrgan.upscale_with_tiling(image)
        # Dividing by 4 assumes the ESRGAN pass enlarged the image 4×
        # (per the class description) — TODO confirm against the ESRGAN model.
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)