from pathlib import Path
import base64, io, torch
from PIL import Image
from huggingface_hub import hf_hub_download
from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
# ------------------------------------------------------------------
# 1️⃣ Download all required checkpoints once (cached by HF DL cache)
# ------------------------------------------------------------------
def dl(repo, fname, rev=None):
    return Path(
        hf_hub_download(repo_id=repo, filename=fname, revision=rev)
    )
checkpoints = ESRGANUpscalerCheckpoints(
    # SD-1.5 Multi-Upscaler core
    unet = dl("refiners/juggernaut.reborn.sd1_5.unet", "model.safetensors", "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2"),
    clip_text_encoder = dl("refiners/juggernaut.reborn.sd1_5.text_encoder", "model.safetensors", "744ad6a5c0437ec02ad826df9f6ede102bb27481"),
    lda = dl("refiners/juggernaut.reborn.sd1_5.autoencoder", "model.safetensors", "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19"),
    controlnet_tile = dl("refiners/controlnet.sd1_5.tile", "model.safetensors", "48ced6ff8bfa873a8976fa467c3629a240643387"),
    # ESRGAN 4× super-res
    esrgan = dl("philz1337x/upscaler", "4x-UltraSharp.pth", "011deacac8270114eb7d2eeff4fe6fa9a837be70"),
    # Negative prompt embedding
    negative_embedding = dl("philz1337x/embeddings", "JuggernautNegative-neg.pt", "203caa7e9cc2bc225031a4021f6ab1ded283454a"),
    negative_embedding_key = "string_to_param.*",
    # LoRAs
    loras = {
        "more_details": dl("philz1337x/loras", "more_details.safetensors", "a3802c0280c0d00c2ab18d37454a8744c44e474e"),
        "sdxl_render": dl("philz1337x/loras", "SDXLrender_v2.0.safetensors", "a3802c0280c0d00c2ab18d37454a8744c44e474e"),
    },
)
# ------------------------------------------------------------------
# 2️⃣ Instantiate the enhancer once (global singleton)
# ------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)
# ------------------------------------------------------------------
# 3️⃣ Inference Endpoints entry point
# ------------------------------------------------------------------
class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.
    Expects: {"image": "<BASE64-ENCODED>"}
    Returns: {"enhanced_image", "original_size", "enhanced_size"}
    """
    def __init__(self, path="."):
        pass  # all heavy lifting was done at module import time above

    def __call__(self, inputs: dict) -> dict:
        if "image" not in inputs:
            return {"error": "No image provided"}
        # Strip the optional "data:image/...;base64," prefix
        payload = inputs["image"].split(",", 1)[-1]
        try:
            img = Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
        except Exception as e:
            return {"error": f"Bad image payload: {e}"}
        # Run the two-stage upscaler
        out = enhancer.upscale(img)
        # Re-encode the result as a base64 PNG
        buf = io.BytesIO()
        out.save(buf, format="PNG")
        out_b64 = base64.b64encode(buf.getvalue()).decode()
        return {
            "enhanced_image": out_b64,
            "original_size": img.size,
            "enhanced_size": out.size,
        }
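# ------------------------------------------------------------------
# Optional local smoke test (a minimal sketch, not part of the
# Inference Endpoints contract). It assumes a local file "sample.jpg"
# exists and that the checkpoints above can be downloaded; adjust the
# path before running this file directly.
# ------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    # Encode the test image the same way a client would
    with open("sample.jpg", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    result = handler({"image": b64})
    if "error" in result:
        print("Handler error:", result["error"])
    else:
        print("original:", result["original_size"], "enhanced:", result["enhanced_size"])
        # Decode the returned base64 PNG and write it to disk
        with open("enhanced.png", "wb") as out_file:
            out_file.write(base64.b64decode(result["enhanced_image"]))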