Update handler.py
Browse files- handler.py +41 -29
handler.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import base64
|
2 |
import io
|
|
|
3 |
from PIL import Image
|
4 |
import torch
|
5 |
from diffusers import StableDiffusionImg2ImgPipeline
|
@@ -19,7 +20,7 @@ class EndpointHandler:
|
|
19 |
global pipe
|
20 |
if pipe is None:
|
21 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
22 |
-
"karthikAI/InstantID-i2i",
|
23 |
revision="main",
|
24 |
torch_dtype=torch.float16,
|
25 |
safety_checker=None
|
@@ -28,7 +29,7 @@ class EndpointHandler:
|
|
28 |
|
29 |
def inference(self, model_inputs: dict) -> dict:
|
30 |
"""
|
31 |
-
Run a single img2img inference.
|
32 |
|
33 |
Expects a JSON payload with:
|
34 |
- "inputs": base64-encoded input image
|
@@ -38,36 +39,47 @@ class EndpointHandler:
|
|
38 |
"guidance_scale": float,
|
39 |
"num_inference_steps": int,
|
40 |
}
|
41 |
-
Returns
|
42 |
- "generated_image_base64": base64-encoded PNG
|
|
|
|
|
|
|
43 |
"""
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
72 |
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
+
import traceback
|
4 |
from PIL import Image
|
5 |
import torch
|
6 |
from diffusers import StableDiffusionImg2ImgPipeline
|
|
|
20 |
global pipe
|
21 |
if pipe is None:
|
22 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
23 |
+
"karthikAI/InstantID-i2i",
|
24 |
revision="main",
|
25 |
torch_dtype=torch.float16,
|
26 |
safety_checker=None
|
|
|
29 |
|
30 |
def inference(self, model_inputs: dict) -> dict:
|
31 |
"""
|
32 |
+
Run a single img2img inference with detailed error debugging.
|
33 |
|
34 |
Expects a JSON payload with:
|
35 |
- "inputs": base64-encoded input image
|
|
|
39 |
"guidance_scale": float,
|
40 |
"num_inference_steps": int,
|
41 |
}
|
42 |
+
Returns on success:
|
43 |
- "generated_image_base64": base64-encoded PNG
|
44 |
+
On failure:
|
45 |
+
- "error": error message
|
46 |
+
- "traceback": full Python traceback
|
47 |
"""
|
48 |
+
try:
|
49 |
+
# 1. Decode incoming image
|
50 |
+
b64_img = model_inputs.get("inputs")
|
51 |
+
if not b64_img:
|
52 |
+
raise ValueError("No image data provided under 'inputs'.")
|
53 |
+
image_bytes = base64.b64decode(b64_img)
|
54 |
+
init_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
55 |
|
56 |
+
# 2. Extract parameters
|
57 |
+
params = model_inputs.get("parameters", {})
|
58 |
+
prompt = params.get("prompt", "")
|
59 |
+
strength = float(params.get("strength", 0.75))
|
60 |
+
guidance_scale = float(params.get("guidance_scale", 7.5))
|
61 |
+
num_steps = int(params.get("num_inference_steps", 50))
|
62 |
|
63 |
+
# 3. Run the img2img pipeline
|
64 |
+
result = pipe(
|
65 |
+
prompt=prompt,
|
66 |
+
image=init_img,
|
67 |
+
strength=strength,
|
68 |
+
guidance_scale=guidance_scale,
|
69 |
+
num_inference_steps=num_steps,
|
70 |
+
)
|
71 |
+
out_img = result.images[0]
|
72 |
|
73 |
+
# 4. Encode and return image
|
74 |
+
buffer = io.BytesIO()
|
75 |
+
out_img.save(buffer, format="PNG")
|
76 |
+
generated_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
77 |
+
return {"generated_image_base64": generated_b64}
|
78 |
|
79 |
+
except Exception as e:
|
80 |
+
# Return detailed error info for debugging
|
81 |
+
tb = traceback.format_exc()
|
82 |
+
return {
|
83 |
+
"error": str(e),
|
84 |
+
"traceback": tb
|
85 |
+
}
|