TBurdairon commited on
Commit
7a3bb0d
·
verified ·
1 Parent(s): 32aea44

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +54 -25
inference.py CHANGED
@@ -1,50 +1,79 @@
1
  from pathlib import Path
2
- import torch
3
  from PIL import Image
4
- import base64, io
5
-
6
  from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
7
 
8
- # -------- initialise model once at cold-start --------
 
 
 
 
 
 
 
9
  checkpoints = ESRGANUpscalerCheckpoints(
10
- esrgan=Path("checkpoints/4x-UltraSharp.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  )
12
 
 
 
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
15
 
16
  enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)
17
 
18
- # -------- Hugging Face Hosted Inference API Entry Point --------
 
 
19
  class EndpointHandler:
 
 
 
 
 
 
20
  def __init__(self, path="."):
21
- pass # path is not used since model is already initialized globally
22
 
23
  def __call__(self, inputs: dict) -> dict:
24
- """
25
- Expected payload:
26
- {"image": "<BASE64-STRING>"}
27
- Returns:
28
- { "enhanced_image": "<BASE64-PNG>", "original_size": [w,h], "enhanced_size": [w,h] }
29
- """
30
  if "image" not in inputs:
31
  return {"error": "No image provided"}
32
 
33
- # decode base64
34
- data = inputs["image"]
35
- if data.startswith("data:image"):
36
- data = data.split(",")[1]
37
- img = Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")
 
38
 
39
- # run ESRGAN ➜ SD-1.5 upscale
40
- result = enhancer.upscale(img)
41
 
42
  buf = io.BytesIO()
43
- result.save(buf, format="PNG")
44
- result_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
45
 
46
  return {
47
- "enhanced_image": result_b64,
48
  "original_size": img.size,
49
- "enhanced_size": result.size,
50
  }
 
1
  from pathlib import Path
2
+ import base64, io, torch
3
  from PIL import Image
4
+ from huggingface_hub import hf_hub_download
 
5
  from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
6
 
7
+ # ------------------------------------------------------------------
8
+ # 1️⃣ Download all required checkpoints once (cached by HF DL cache)
9
+ # ------------------------------------------------------------------
10
+ def dl(repo, fname, rev=None):
11
+ return Path(
12
+ hf_hub_download(repo_id=repo, filename=fname, revision=rev)
13
+ )
14
+
15
  checkpoints = ESRGANUpscalerCheckpoints(
16
+ # SD-1.5 Multi-Upscaler core
17
+ unet = dl("refiners/juggernaut.reborn.sd1_5.unet", "model.safetensors", "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2"),
18
+ clip_text_encoder = dl("refiners/juggernaut.reborn.sd1_5.text_encoder", "model.safetensors", "744ad6a5c0437ec02ad826df9f6ede102bb27481"),
19
+ lda = dl("refiners/juggernaut.reborn.sd1_5.autoencoder", "model.safetensors", "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19"),
20
+ controlnet_tile = dl("refiners/controlnet.sd1_5.tile", "model.safetensors", "48ced6ff8bfa873a8976fa467c3629a240643387"),
21
+
22
+ # ESRGAN 4× super-res
23
+ esrgan = dl("philz1337x/upscaler", "4x-UltraSharp.pth", "011deacac8270114eb7d2eeff4fe6fa9a837be70"),
24
+
25
+ # Negative prompt embedding
26
+ negative_embedding = dl("philz1337x/embeddings", "JuggernautNegative-neg.pt", "203caa7e9cc2bc225031a4021f6ab1ded283454a"),
27
+ negative_embedding_key = "string_to_param.*",
28
+
29
+ # LoRAs
30
+ loras = {
31
+ "more_details": dl("philz1337x/loras", "more_details.safetensors", "a3802c0280c0d00c2ab18d37454a8744c44e474e"),
32
+ "sdxl_render" : dl("philz1337x/loras", "SDXLrender_v2.0.safetensors", "a3802c0280c0d00c2ab18d37454a8744c44e474e"),
33
+ },
34
  )
35
 
36
+ # ------------------------------------------------------------------
37
+ # 2️⃣ Instantiate the enhancer once (global singleton)
38
+ # ------------------------------------------------------------------
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
41
 
42
  enhancer = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)
43
 
44
+ # ------------------------------------------------------------------
45
+ # 3️⃣ Hosted-Inference entry-point
46
+ # ------------------------------------------------------------------
47
  class EndpointHandler:
48
+ """
49
+ Hugging Face Hosted Inference API entrypoint.
50
+ Expects: {"image": "<BASE64-ENCODED>"}
51
+ Returns: {enhanced_image, original_size, enhanced_size}
52
+ """
53
+
54
  def __init__(self, path="."):
55
+ pass # all heavy work done globally above
56
 
57
  def __call__(self, inputs: dict) -> dict:
 
 
 
 
 
 
58
  if "image" not in inputs:
59
  return {"error": "No image provided"}
60
 
61
+ # strip optional data:image/...;base64, prefix
62
+ payload = inputs["image"].split(",", 1)[-1]
63
+ try:
64
+ img = Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
65
+ except Exception as e:
66
+ return {"error": f"Bad image payload: {e}"}
67
 
68
+ # run the two-stage upscaler
69
+ out = enhancer.upscale(img)
70
 
71
  buf = io.BytesIO()
72
+ out.save(buf, format="PNG")
73
+ out_b64 = base64.b64encode(buf.getvalue()).decode()
74
 
75
  return {
76
+ "enhanced_image": out_b64,
77
  "original_size": img.size,
78
+ "enhanced_size": out.size,
79
  }