|
from pathlib import Path |
|
import base64, io, torch |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download |
|
from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints |
|
|
|
|
|
|
|
|
|
def dl(repo, fname, rev=None):
    """Download ``fname`` from the Hugging Face Hub repo ``repo`` and return its local path.

    ``rev`` is forwarded as the ``revision`` to fetch; ``None`` lets
    ``hf_hub_download`` use its default revision.
    """
    local_file = hf_hub_download(repo_id=repo, filename=fname, revision=rev)
    return Path(local_file)
|
|
|
# Weight files for the upscaler. Each download passes an explicit commit hash as
# its revision, so the exact same files are fetched on every deployment.
# The downloads run at import time, in the order listed below.
checkpoints = ESRGANUpscalerCheckpoints(
    # Diffusion components, from the refiners/juggernaut.reborn.sd1_5.* repos.
    unet=dl(
        repo="refiners/juggernaut.reborn.sd1_5.unet",
        fname="model.safetensors",
        rev="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    clip_text_encoder=dl(
        repo="refiners/juggernaut.reborn.sd1_5.text_encoder",
        fname="model.safetensors",
        rev="744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    lda=dl(
        repo="refiners/juggernaut.reborn.sd1_5.autoencoder",
        fname="model.safetensors",
        rev="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    controlnet_tile=dl(
        repo="refiners/controlnet.sd1_5.tile",
        fname="model.safetensors",
        rev="48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    # ESRGAN weights (4x-UltraSharp).
    esrgan=dl(
        repo="philz1337x/upscaler",
        fname="4x-UltraSharp.pth",
        rev="011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    # Negative textual-inversion embedding. The key pattern selects which
    # entries of the .pt file are loaded (interpreted by ESRGANUpscalerCheckpoints).
    negative_embedding=dl(
        repo="philz1337x/embeddings",
        fname="JuggernautNegative-neg.pt",
        rev="203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
    negative_embedding_key="string_to_param.*",
    # Both LoRAs come from the same repo at the same pinned commit.
    loras={
        "more_details": dl(
            repo="philz1337x/loras",
            fname="more_details.safetensors",
            rev="a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
        "sdxl_render": dl(
            repo="philz1337x/loras",
            fname="SDXLrender_v2.0.safetensors",
            rev="a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
    },
)
|
|
|
|
|
|
|
|
|
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use bfloat16 only on a CUDA device that supports it; float32 everywhere else.
# ``device.type`` is checked first so CUDA capabilities are never queried on a
# CPU-only host: older PyTorch releases call ``current_device()`` inside
# ``torch.cuda.is_bf16_supported()``, which raises when CUDA is unavailable.
dtype = (
    torch.bfloat16
    if device.type == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)

# Single shared upscaler instance, built once at import and reused by every request.
enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Hugging Face Hosted Inference API entrypoint.

    Expects: {"image": "<BASE64-ENCODED>"}
        The string may also be a data URL ("data:image/png;base64,<...>").
        Everything up to and including the first comma is dropped before decoding.

    Returns: {enhanced_image, original_size, enhanced_size}
        ``enhanced_image`` is the upscaled result, base64-encoded as PNG.
        ``original_size`` is the decoded input's PIL ``size`` (width, height).
        ``enhanced_size`` is the upscaler output's ``size``.
        When the request is malformed, {"error": "<message>"} is returned instead.
    """

    def __init__(self, path="."):
        # Nothing to load per instance: the module-level ``enhancer`` is built at
        # import time. ``path`` is accepted for the handler signature but unused.
        pass

    def __call__(self, inputs: dict) -> dict:
        """Decode the request image, upscale it, and return the result as base64 PNG."""
        # Reject non-dict bodies up front. For a str body, ``"image" in inputs``
        # would be a substring test, and ``inputs["image"]`` would then raise.
        if not isinstance(inputs, dict) or "image" not in inputs:
            return {"error": "No image provided"}

        raw = inputs["image"]
        if not isinstance(raw, str):
            # Previously this crashed on ``.split`` with an uncaught AttributeError.
            return {
                "error": "Bad image payload: expected a base64 string, "
                f"got {type(raw).__name__}"
            }

        # Strip an optional data-URL header such as "data:image/png;base64,".
        payload = raw.split(",", 1)[-1]
        try:
            img = Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
        except Exception as e:
            # Any decode failure (bad base64, unrecognised image data, ...) is
            # reported back to the caller instead of being raised.
            return {"error": f"Bad image payload: {e}"}

        out = enhancer.upscale(img)

        with io.BytesIO() as buf:
            out.save(buf, format="PNG")
            out_b64 = base64.b64encode(buf.getvalue()).decode()

        return {
            "enhanced_image": out_b64,
            "original_size": img.size,
            "enhanced_size": out.size,
        }
|
|