File size: 3,483 Bytes
a4196d6
7a3bb0d
a4196d6
7a3bb0d
a4196d6
 
7a3bb0d
 
 
 
 
 
 
 
a4196d6
7a3bb0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4196d6
 
7a3bb0d
 
 
a4196d6
7a3bb0d
a4196d6
8d81eef
a4196d6
7a3bb0d
 
 
32aea44
7a3bb0d
 
 
 
 
 
32aea44
7a3bb0d
32aea44
 
 
 
 
7a3bb0d
 
 
 
 
 
32aea44
7a3bb0d
 
32aea44
 
7a3bb0d
 
32aea44
 
7a3bb0d
32aea44
7a3bb0d
32aea44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from pathlib import Path
import base64, io, torch
from PIL import Image
from huggingface_hub import hf_hub_download
from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints

# ------------------------------------------------------------------
# 1️⃣  Download all required checkpoints once (cached by HF DL cache)
# ------------------------------------------------------------------
def dl(repo, fname, rev=None):
    """Download ``fname`` from Hub repo ``repo`` and return its local path.

    ``rev`` optionally pins the revision to fetch (this file pins commit
    hashes); ``None`` leaves the choice to ``hf_hub_download``. Repeated calls
    are served from the local HF download cache instead of re-downloading.
    """
    cached_file = hf_hub_download(repo_id=repo, filename=fname, revision=rev)
    return Path(cached_file)

# Every checkpoint is pinned to an explicit Hub revision so the deployed
# weights cannot silently change underneath the endpoint.
checkpoints = ESRGANUpscalerCheckpoints(
    # SD-1.5 Multi-Upscaler core: UNet, text encoder, autoencoder, tile ControlNet
    unet=dl(
        repo="refiners/juggernaut.reborn.sd1_5.unet",
        fname="model.safetensors",
        rev="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    clip_text_encoder=dl(
        repo="refiners/juggernaut.reborn.sd1_5.text_encoder",
        fname="model.safetensors",
        rev="744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    lda=dl(
        repo="refiners/juggernaut.reborn.sd1_5.autoencoder",
        fname="model.safetensors",
        rev="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    controlnet_tile=dl(
        repo="refiners/controlnet.sd1_5.tile",
        fname="model.safetensors",
        rev="48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    # ESRGAN 4x super-resolution weights
    esrgan=dl(
        repo="philz1337x/upscaler",
        fname="4x-UltraSharp.pth",
        rev="011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    # Negative prompt embedding and the key pattern used to read it
    negative_embedding=dl(
        repo="philz1337x/embeddings",
        fname="JuggernautNegative-neg.pt",
        rev="203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
    negative_embedding_key="string_to_param.*",
    # LoRAs, keyed by the name the enhancer refers to them with
    loras=dict(
        more_details=dl(
            repo="philz1337x/loras",
            fname="more_details.safetensors",
            rev="a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
        sdxl_render=dl(
            repo="philz1337x/loras",
            fname="SDXLrender_v2.0.safetensors",
            rev="a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
    ),
)

# ------------------------------------------------------------------
# 2️⃣  Instantiate the enhancer once (global singleton)
# ------------------------------------------------------------------
# Run on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only ask about bf16 support when CUDA is actually the chosen device:
# torch.cuda.is_bf16_supported() can raise on hosts without a usable GPU in
# some PyTorch versions (and reports True unconditionally on ROCm builds),
# which must not push the CPU path into bfloat16.
dtype = (
    torch.bfloat16
    if device.type == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)

# Built once at import time so every request reuses the same loaded models.
enhancer = ESRGANUpscaler(
    device=device,
    dtype=dtype,
    checkpoints=checkpoints,
)

# ------------------------------------------------------------------
# 3️⃣  Hosted-Inference entry-point
# ------------------------------------------------------------------
class EndpointHandler:
    """
    Hugging Face Hosted Inference API entrypoint.

    Accepted request bodies (``<B64>`` may carry a ``data:image/...;base64,``
    prefix, which is stripped):
        {"image": "<B64>"}
        {"inputs": {"image": "<B64>"}}
        {"inputs": "<B64>"}
    A top-level "image" key takes precedence over an "inputs" envelope.

    Returns on success: {enhanced_image, original_size, enhanced_size}, where
    ``enhanced_image`` is a base64-encoded PNG and the sizes are PIL
    ``(width, height)`` tuples.
    Returns on bad input: {"error": "<message>"}. Exceptions raised by the
    upscaler itself are not caught here.
    """

    def __init__(self, path="."):
        # ``path`` is accepted for API compatibility but unused: checkpoints
        # and the enhancer are built once at module import time.
        pass

    @staticmethod
    def _extract_image_field(inputs: object) -> object:
        """Return the raw image value from a request body, or None if absent."""
        if not isinstance(inputs, dict):
            return None
        if "image" in inputs:
            return inputs["image"]
        # Standard HF request envelope: {"inputs": ...}
        wrapped = inputs.get("inputs")
        if isinstance(wrapped, dict):
            return wrapped.get("image")
        if isinstance(wrapped, str):
            return wrapped
        return None

    def __call__(self, inputs: dict) -> dict:
        raw = self._extract_image_field(inputs)
        if raw is None:
            return {"error": "No image provided"}
        # Reject non-string payloads up front instead of crashing on .split().
        if not isinstance(raw, str):
            return {
                "error": f"Bad image payload: expected a base64 string, "
                f"got {type(raw).__name__}"
            }

        # strip optional data:image/...;base64, prefix
        payload = raw.split(",", 1)[-1]
        try:
            image_bytes = base64.b64decode(payload)
            # convert() forces a full decode, so corrupt data fails here and
            # the returned copy stays valid after the source image is closed.
            with Image.open(io.BytesIO(image_bytes)) as src:
                img = src.convert("RGB")
        except Exception as e:
            return {"error": f"Bad image payload: {e}"}

        # run the two-stage upscaler
        out = enhancer.upscale(img)

        buf = io.BytesIO()
        out.save(buf, format="PNG")
        out_b64 = base64.b64encode(buf.getvalue()).decode()

        return {
            "enhanced_image": out_b64,
            "original_size": img.size,
            "enhanced_size": out.size,
        }