Upload inference.py with huggingface_hub
inference.py  +54 -25  CHANGED
@@ -1,50 +1,79 @@
 from pathlib import Path
-import torch
+import base64, io, torch
 from PIL import Image
-import
-
+from huggingface_hub import hf_hub_download
 from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints

-#
+# ------------------------------------------------------------------
+# 1️⃣ Download all required checkpoints once (cached by HF DL cache)
+# ------------------------------------------------------------------
+def dl(repo, fname, rev=None):
+    return Path(
+        hf_hub_download(repo_id=repo, filename=fname, revision=rev)
+    )
+
 checkpoints = ESRGANUpscalerCheckpoints(
-
+    # SD-1.5 Multi-Upscaler core
+    unet = dl("refiners/juggernaut.reborn.sd1_5.unet", "model.safetensors", "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2"),
+    clip_text_encoder = dl("refiners/juggernaut.reborn.sd1_5.text_encoder", "model.safetensors", "744ad6a5c0437ec02ad826df9f6ede102bb27481"),
+    lda = dl("refiners/juggernaut.reborn.sd1_5.autoencoder", "model.safetensors", "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19"),
+    controlnet_tile = dl("refiners/controlnet.sd1_5.tile", "model.safetensors", "48ced6ff8bfa873a8976fa467c3629a240643387"),
+
+    # ESRGAN 4× super-res
+    esrgan = dl("philz1337x/upscaler", "4x-UltraSharp.pth", "011deacac8270114eb7d2eeff4fe6fa9a837be70"),
+
+    # Negative prompt embedding
+    negative_embedding = dl("philz1337x/embeddings", "JuggernautNegative-neg.pt", "203caa7e9cc2bc225031a4021f6ab1ded283454a"),
+    negative_embedding_key = "string_to_param.*",
+
+    # LoRAs
+    loras = {
+        "more_details": dl("philz1337x/loras", "more_details.safetensors", "a3802c0280c0d00c2ab18d37454a8744c44e474e"),
+        "sdxl_render" : dl("philz1337x/loras", "SDXLrender_v2.0.safetensors", "a3802c0280c0d00c2ab18d37454a8744c44e474e"),
+    },
 )

+# ------------------------------------------------------------------
+# 2️⃣ Instantiate the enhancer once (global singleton)
+# ------------------------------------------------------------------
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-dtype
+dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

 enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)

-#
+# ------------------------------------------------------------------
+# 3️⃣ Hosted-Inference entry-point
+# ------------------------------------------------------------------
 class EndpointHandler:
+    """
+    Hugging Face Hosted Inference API entrypoint.
+    Expects: {"image": "<BASE64-ENCODED>"}
+    Returns: {enhanced_image, original_size, enhanced_size}
+    """
+
     def __init__(self, path="."):
-        pass  #
+        pass  # all heavy work done globally above

     def __call__(self, inputs: dict) -> dict:
-        """
-        Expected payload:
-            {"image": "<BASE64-STRING>"}
-        Returns:
-            { "enhanced_image": "<BASE64-PNG>", "original_size": [w,h], "enhanced_size": [w,h] }
-        """
         if "image" not in inputs:
             return {"error": "No image provided"}

-        #
-
-
-
-
+        # strip optional data:image/...;base64, prefix
+        payload = inputs["image"].split(",", 1)[-1]
+        try:
+            img = Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
+        except Exception as e:
+            return {"error": f"Bad image payload: {e}"}

-        # run
-
+        # run the two-stage upscaler
+        out = enhancer.upscale(img)

         buf = io.BytesIO()
-
-
+        out.save(buf, format="PNG")
+        out_b64 = base64.b64encode(buf.getvalue()).decode()

         return {
-            "enhanced_image":
+            "enhanced_image": out_b64,
             "original_size": img.size,
-            "enhanced_size":
+            "enhanced_size": out.size,
         }
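
For reference, a minimal local smoke test of the handler added above might look like the sketch below. It assumes the new file is saved as inference.py next to enhancer.py and that test.png is any small RGB image you supply; both filenames are placeholders, and importing the module will trigger the checkpoint downloads and model construction at the top of the file.

import base64

from inference import EndpointHandler  # placeholder module name for the file above

handler = EndpointHandler()

# Build the payload the handler expects: {"image": "<BASE64-ENCODED>"}
with open("test.png", "rb") as f:
    payload = {"image": base64.b64encode(f.read()).decode()}

result = handler(payload)

if "error" in result:
    print("handler error:", result["error"])
else:
    print("original_size:", result["original_size"])
    print("enhanced_size:", result["enhanced_size"])
    # The enhanced image comes back as a base64-encoded PNG; write it to disk.
    with open("enhanced.png", "wb") as f:
        f.write(base64.b64decode(result["enhanced_image"]))

Keeping the downloads and the ESRGANUpscaler construction at module level means that cost is paid once when the endpoint starts rather than on every request, which is why __init__ is a no-op.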