"""Hugging Face Hosted Inference handler for the ESRGAN / SD-1.5 image enhancer.

Importing this module downloads every required checkpoint and builds a single
module-level ``enhancer`` instance; ``EndpointHandler`` then serves requests
with that shared instance.
"""

from pathlib import Path
import base64
import io

import torch
from PIL import Image
from huggingface_hub import hf_hub_download

from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints


# ------------------------------------------------------------------
# 1. Download all required checkpoints once (cached by HF DL cache)
# ------------------------------------------------------------------
def dl(repo, fname, rev=None):
    """Fetch one file from the Hugging Face Hub and return its local path.

    Args:
        repo: Hub repository id, e.g. ``"refiners/controlnet.sd1_5.tile"``.
        fname: Name of the file inside the repository.
        rev: Optional revision (commit hash here) to pin; ``None`` leaves the
            choice to ``hf_hub_download``.

    Returns:
        ``Path`` wrapping whatever location ``hf_hub_download`` reports.
    """
    return Path(hf_hub_download(repo_id=repo, filename=fname, revision=rev))


checkpoints = ESRGANUpscalerCheckpoints(
    # SD-1.5 Multi-Upscaler core
    unet=dl(
        "refiners/juggernaut.reborn.sd1_5.unet",
        "model.safetensors",
        "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    clip_text_encoder=dl(
        "refiners/juggernaut.reborn.sd1_5.text_encoder",
        "model.safetensors",
        "744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    lda=dl(
        "refiners/juggernaut.reborn.sd1_5.autoencoder",
        "model.safetensors",
        "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    controlnet_tile=dl(
        "refiners/controlnet.sd1_5.tile",
        "model.safetensors",
        "48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    # ESRGAN 4x super-res
    esrgan=dl(
        "philz1337x/upscaler",
        "4x-UltraSharp.pth",
        "011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    # Negative prompt embedding
    negative_embedding=dl(
        "philz1337x/embeddings",
        "JuggernautNegative-neg.pt",
        "203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
    negative_embedding_key="string_to_param.*",
    # LoRAs (both files live in the same repo, hence the shared revision)
    loras={
        "more_details": dl(
            "philz1337x/loras",
            "more_details.safetensors",
            "a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
        "sdxl_render": dl(
            "philz1337x/loras",
            "SDXLrender_v2.0.safetensors",
            "a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
    },
)

# ------------------------------------------------------------------
# 2. Instantiate the enhancer once (global singleton)
# ------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only probe bf16 support when a CUDA device was actually selected: on
# CPU-only hosts some torch releases query the current CUDA device inside
# is_bf16_supported() and raise instead of returning False.
if device.type == "cuda" and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32

enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)


# ------------------------------------------------------------------
# 3. Hosted-Inference entry-point
# ------------------------------------------------------------------
class EndpointHandler:
    """Hugging Face Hosted Inference API entrypoint.

    Expects: ``{"image": "..."}`` where the value is a base64-encoded image
    string, optionally prefixed with ``data:image/...;base64,``.

    Returns: ``{"enhanced_image", "original_size", "enhanced_size"}`` on
    success, or ``{"error": "..."}`` when the request is unusable.
    """

    def __init__(self, path="."):
        # `path` is accepted but unused: all heavy work is done at module
        # import time above.
        pass

    def __call__(self, inputs: dict) -> dict:
        """Enhance the image carried by *inputs*.

        Args:
            inputs: Request payload; must be a mapping with an ``"image"``
                key holding a base64 string (data-URI prefix allowed).

        Returns:
            A dict with the PNG-encoded result as base64 text under
            ``"enhanced_image"`` plus the ``size`` of the input and output
            images, or a dict with a single ``"error"`` message.
        """
        if not isinstance(inputs, dict) or "image" not in inputs:
            return {"error": "No image provided"}

        encoded = inputs["image"]
        if not isinstance(encoded, str):
            # Previously this surfaced as an unhandled AttributeError from
            # `.split`; report it the same way as any other bad payload.
            return {
                "error": (
                    "Bad image payload: expected a base64 string, "
                    f"got {type(encoded).__name__}"
                )
            }

        # strip optional data:image/...;base64, prefix
        payload = encoded.split(",", 1)[-1]

        # Broad catch is deliberate: this is the request boundary, and any
        # decode/parse failure should become an error response.
        try:
            raw_bytes = base64.b64decode(payload)
            img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Bad image payload: {e}"}

        # run the two-stage upscaler
        out = enhancer.upscale(img)

        # Serialize the result as PNG and ship it back as base64 text.
        buf = io.BytesIO()
        out.save(buf, format="PNG")
        out_b64 = base64.b64encode(buf.getvalue()).decode()

        return {
            "enhanced_image": out_b64,
            "original_size": img.size,
            "enhanced_size": out.size,
        }