# TBurdairon's picture
# Upload inference.py with huggingface_hub
# 7a3bb0d verified
from pathlib import Path
import base64, io, torch
from PIL import Image
from huggingface_hub import hf_hub_download
from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
# ------------------------------------------------------------------
# 1️⃣ Download all required checkpoints once (reused from the local Hugging Face Hub cache)
# ------------------------------------------------------------------
def dl(repo, fname, rev=None):
    """Fetch a single file from the Hugging Face Hub and return its local path.

    Args:
        repo: Hub repository id, forwarded to ``hf_hub_download`` as ``repo_id``.
        fname: Name of the file inside the repository.
        rev: Optional revision to fetch; forwarded unchanged (``None`` by default).

    Returns:
        A ``pathlib.Path`` wrapping the value returned by ``hf_hub_download``.
    """
    local_file = hf_hub_download(repo_id=repo, filename=fname, revision=rev)
    return Path(local_file)
# (repo, filename, revision) triples for every single-file checkpoint, keyed by
# the ESRGANUpscalerCheckpoints argument each one fills. Dict order is the
# download order.
_CHECKPOINT_SPECS = {
    # SD-1.5 Multi-Upscaler core
    "unet": (
        "refiners/juggernaut.reborn.sd1_5.unet",
        "model.safetensors",
        "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    "clip_text_encoder": (
        "refiners/juggernaut.reborn.sd1_5.text_encoder",
        "model.safetensors",
        "744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    "lda": (
        "refiners/juggernaut.reborn.sd1_5.autoencoder",
        "model.safetensors",
        "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    "controlnet_tile": (
        "refiners/controlnet.sd1_5.tile",
        "model.safetensors",
        "48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    # ESRGAN 4× super-res
    "esrgan": (
        "philz1337x/upscaler",
        "4x-UltraSharp.pth",
        "011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    # Negative prompt embedding
    "negative_embedding": (
        "philz1337x/embeddings",
        "JuggernautNegative-neg.pt",
        "203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
}

# LoRAs, keyed by the name they are registered under.
_LORA_SPECS = {
    "more_details": (
        "philz1337x/loras",
        "more_details.safetensors",
        "a3802c0280c0d00c2ab18d37454a8744c44e474e",
    ),
    "sdxl_render": (
        "philz1337x/loras",
        "SDXLrender_v2.0.safetensors",
        "a3802c0280c0d00c2ab18d37454a8744c44e474e",
    ),
}

# Resolve every spec to a local path via dl() and bundle the results.
checkpoints = ESRGANUpscalerCheckpoints(
    **{field: dl(*spec) for field, spec in _CHECKPOINT_SPECS.items()},
    negative_embedding_key="string_to_param.*",
    loras={lora_name: dl(*spec) for lora_name, spec in _LORA_SPECS.items()},
)
# ------------------------------------------------------------------
# 2️⃣ Instantiate the enhancer once (global singleton)
# ------------------------------------------------------------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only ask about bf16 support on the CUDA path: on CPU-only hosts
# torch.cuda.is_bf16_supported() may raise in some torch releases (it inspects
# the current CUDA device), and the CPU fallback should stay in float32.
dtype = (
    torch.bfloat16
    if device.type == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)
# Built once at import time and shared by every request.
enhancer = ESRGANUpscaler(
    checkpoints=checkpoints,
    device=device,
    dtype=dtype,
)
# ------------------------------------------------------------------
# 3️⃣ Hosted-Inference entry-point
# ------------------------------------------------------------------
class EndpointHandler:
    """
    Hugging Face Hosted Inference API entrypoint.

    Expects: {"image": "<BASE64-ENCODED>"}; an optional
    ``data:image/...;base64,`` prefix on the string is stripped.
    Returns: {enhanced_image, original_size, enhanced_size} on success, where
    ``enhanced_image`` is the base64-encoded PNG of the upscaled image, or
    {"error": "<message>"} when the request carries no usable image.
    """

    def __init__(self, path="."):
        # `path` is accepted but not used here.
        pass  # all heavy work done globally above

    def __call__(self, inputs: dict) -> dict:
        """Upscale the base64 image found in ``inputs["image"]``.

        Args:
            inputs: Request payload; must provide an ``"image"`` entry holding
                a base64 string.

        Returns:
            A dict with ``enhanced_image`` (base64 PNG), ``original_size`` and
            ``enhanced_size`` (the ``.size`` of the input and output images),
            or a dict with a single ``error`` key for malformed requests.
        """
        # EAFP lookup: a missing key and a non-subscriptable payload (e.g.
        # None) are both reported as "no image" instead of raising.
        try:
            raw = inputs["image"]
        except (KeyError, TypeError, IndexError):
            return {"error": "No image provided"}

        # Guard before .split(): a non-string value (None, int, bytes, ...)
        # would otherwise raise outside the error handling below.
        if not isinstance(raw, str):
            return {
                "error": "Bad image payload: expected a base64 string, "
                f"got {type(raw).__name__}"
            }

        # strip optional data:image/...;base64, prefix
        payload = raw.split(",", 1)[-1]
        try:
            img = Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
        except Exception as e:
            # Request boundary: any decode/parse failure becomes an error reply.
            return {"error": f"Bad image payload: {e}"}

        # run the two-stage upscaler
        out = enhancer.upscale(img)

        # Serialize the result as PNG and base64-encode it for the response.
        buf = io.BytesIO()
        out.save(buf, format="PNG")
        out_b64 = base64.b64encode(buf.getvalue()).decode()
        return {
            "enhanced_image": out_b64,
            "original_size": img.size,
            "enhanced_size": out.size,
        }