from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)

from esrgan_model import UpscalerESRGAN

# Scale already applied by the ESRGAN stage; the SD-1.5 stage is asked only
# for the remaining factor (requested factor divided by this value).
_ESRGAN_SCALE_FACTOR = 4


@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Extends the SD-1.5 MultiUpscaler checkpoints to hold an extra ESRGAN file."""

    # Passed as the first argument to ``UpscalerESRGAN`` by ``ESRGANUpscaler``.
    esrgan: Path


class ESRGANUpscaler(MultiUpscaler):
    """
    Multi-stage image enhancer that:
      1. Runs ESRGAN 4× super-resolution first (tiling to avoid VRAM overflow),
      2. Passes the up-scaled image to Stable-Diffusion 1.5 MultiUpscaler for refinement.
    """

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Build the SD-1.5 upscaler, then the ESRGAN model alongside it.

        Args:
            checkpoints: Checkpoint locations, including the ESRGAN file.
            device: Device forwarded to the parent constructor.
            dtype: Dtype forwarded to the parent constructor.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # Use the device/dtype stored on the instance after the parent
        # constructor has run, so both stages agree on placement.
        self.esrgan = UpscalerESRGAN(
            checkpoints.esrgan,
            device=self.device,
            dtype=self.dtype,
        )

    def to(self, device: torch.device, dtype: torch.dtype) -> "ESRGANUpscaler":
        """Move both stages to ``device``/``dtype`` and record the new placement.

        NOTE(review): the original comment says this is called automatically
        by HF (presumably Hugging Face tooling) when the model is moved to
        another device — confirm against the deployment environment.

        Args:
            device: Target device for the ESRGAN and SD models.
            dtype: Target dtype for the ESRGAN and SD models.

        Returns:
            This instance, so the ``model = model.to(...)`` idiom works.
        """
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
        # Returning self (instead of the implicit None) keeps reassignment
        # such as ``upscaler = upscaler.to(...)`` from discarding the object.
        return self

    def pre_upscale(
        self,
        image: Image.Image,
        upscale_factor: float,
        **_: Any,
    ) -> Image.Image:
        """Hook that runs *before* SD-1.5 up-scaling.

        Runs the ESRGAN stage on ``image`` first, then delegates to the parent
        hook with the residual factor (``upscale_factor / 4``).

        Args:
            image: Input image.
            upscale_factor: Total requested upscale factor.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The image returned by the parent ``pre_upscale`` hook.
        """
        # 4× ESRGAN first, then the SD-1.5 stage handles the residual upscale.
        image = self.esrgan.upscale_with_tiling(image)
        return super().pre_upscale(
            image=image,
            upscale_factor=upscale_factor / _ESRGAN_SCALE_FACTOR,
        )