Update handler.py
Browse files- handler.py +24 -65
handler.py
CHANGED
@@ -1,85 +1,44 @@
|
|
1 |
import base64
|
2 |
import io
|
3 |
-
import traceback
|
4 |
from PIL import Image
|
5 |
import torch
|
6 |
from diffusers import StableDiffusionImg2ImgPipeline
|
7 |
|
8 |
# Global pipeline instance
|
|
|
9 |
pipe = None
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, model_dir: str):
|
13 |
-
#
|
14 |
-
|
15 |
|
16 |
def init(self):
|
17 |
-
"""
|
18 |
-
Load the InstantID-enhanced Stable Diffusion img2img model once when the endpoint starts.
|
19 |
-
"""
|
20 |
global pipe
|
21 |
if pipe is None:
|
|
|
22 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
23 |
"karthikAI/InstantID-i2i",
|
24 |
revision="main",
|
25 |
-
torch_dtype=torch.float16
|
26 |
-
|
27 |
-
).to(self.device)
|
28 |
-
pipe.enable_attention_slicing()
|
29 |
|
30 |
def inference(self, model_inputs: dict) -> dict:
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
# 1. Decode incoming image
|
50 |
-
b64_img = model_inputs.get("inputs")
|
51 |
-
if not b64_img:
|
52 |
-
raise ValueError("No image data provided under 'inputs'.")
|
53 |
-
image_bytes = base64.b64decode(b64_img)
|
54 |
-
init_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
55 |
-
|
56 |
-
# 2. Extract parameters
|
57 |
-
params = model_inputs.get("parameters", {})
|
58 |
-
prompt = params.get("prompt", "")
|
59 |
-
strength = float(params.get("strength", 0.75))
|
60 |
-
guidance_scale = float(params.get("guidance_scale", 7.5))
|
61 |
-
num_steps = int(params.get("num_inference_steps", 50))
|
62 |
-
|
63 |
-
# 3. Run the img2img pipeline
|
64 |
-
result = pipe(
|
65 |
-
prompt=prompt,
|
66 |
-
image=init_img,
|
67 |
-
strength=strength,
|
68 |
-
guidance_scale=guidance_scale,
|
69 |
-
num_inference_steps=num_steps,
|
70 |
-
)
|
71 |
-
out_img = result.images[0]
|
72 |
-
|
73 |
-
# 4. Encode and return image
|
74 |
-
buffer = io.BytesIO()
|
75 |
-
out_img.save(buffer, format="PNG")
|
76 |
-
generated_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
77 |
-
return {"generated_image_base64": generated_b64}
|
78 |
-
|
79 |
-
except Exception as e:
|
80 |
-
# Return detailed error info for debugging
|
81 |
-
tb = traceback.format_exc()
|
82 |
-
return {
|
83 |
-
"error": str(e),
|
84 |
-
"traceback": tb
|
85 |
-
}
|
|
|
1 |
import base64
|
2 |
import io
|
|
|
3 |
from PIL import Image
|
4 |
import torch
|
5 |
from diffusers import StableDiffusionImg2ImgPipeline
|
6 |
|
7 |
# Global pipeline instance
|
8 |
+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
pipe = None
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, model_dir: str):
|
13 |
+
# model_dir is ignored; HF clones your repo here
|
14 |
+
pass
|
15 |
|
16 |
def init(self):
|
|
|
|
|
|
|
17 |
global pipe
|
18 |
if pipe is None:
|
19 |
+
# Load your InstantID img2img model
|
20 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
21 |
"karthikAI/InstantID-i2i",
|
22 |
revision="main",
|
23 |
+
torch_dtype=torch.float16
|
24 |
+
).to(torch_device)
|
|
|
|
|
25 |
|
26 |
def inference(self, model_inputs: dict) -> dict:
|
27 |
+
# 1) decode base64 image
|
28 |
+
b64 = model_inputs.get("inputs")
|
29 |
+
if b64 is None:
|
30 |
+
return {"error": "No 'inputs' key with base64 image provided."}
|
31 |
+
img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
|
32 |
+
|
33 |
+
# 2) extract prompt
|
34 |
+
prompt = model_inputs.get("parameters", {}).get("prompt", "")
|
35 |
+
|
36 |
+
# 3) minimal call: prompt + image only
|
37 |
+
out = pipe(prompt=prompt, image=img)
|
38 |
+
result_img = out.images[0]
|
39 |
+
|
40 |
+
# 4) encode output
|
41 |
+
buf = io.BytesIO()
|
42 |
+
result_img.save(buf, format="PNG")
|
43 |
+
b64_out = base64.b64encode(buf.getvalue()).decode()
|
44 |
+
return {"generated_image_base64": b64_out}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|