KarthikAI committed on
Commit
32b6c5f
·
verified ·
1 Parent(s): 2120866

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +41 -29
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  import base64
2
  import io
 
3
  from PIL import Image
4
  import torch
5
  from diffusers import StableDiffusionImg2ImgPipeline
@@ -19,7 +20,7 @@ class EndpointHandler:
19
  global pipe
20
  if pipe is None:
21
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
22
- "karthikAI/InstantID-i2i", # Your HF repo with InstantID adapter
23
  revision="main",
24
  torch_dtype=torch.float16,
25
  safety_checker=None
@@ -28,7 +29,7 @@ class EndpointHandler:
28
 
29
def inference(self, model_inputs: dict) -> dict:
    """
    Run a single img2img inference.

    Expects a JSON payload with:
      - "inputs": base64-encoded input image
      - "parameters": optional dict with
          "prompt": str,
          "strength": float,
          "guidance_scale": float,
          "num_inference_steps": int,

    Returns on success:
      - "generated_image_base64": base64-encoded PNG
    On failure (any exception during decode/inference/encode):
      - "error": error message
      - "traceback": full Python traceback
    """
    # Local import keeps this fix self-contained; move alongside the other
    # top-of-file imports if preferred.
    import traceback

    try:
        # 1. Decode the incoming image
        b64_img = model_inputs.get("inputs")
        if not b64_img:
            raise ValueError("No image data provided under 'inputs'.")
        image_bytes = base64.b64decode(b64_img)
        init_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # 2. Extract parameters (defaults mirror common SD img2img settings)
        params = model_inputs.get("parameters", {})
        prompt = params.get("prompt", "")
        strength = float(params.get("strength", 0.75))
        guidance_scale = float(params.get("guidance_scale", 7.5))
        num_steps = int(params.get("num_inference_steps", 50))

        # 3. Run the img2img pipeline (module-level `pipe`, loaded lazily elsewhere)
        result = pipe(
            prompt=prompt,
            image=init_img,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
        )
        out_img = result.images[0]

        # 4. Encode the output image back to base64
        buffer = io.BytesIO()
        out_img.save(buffer, format="PNG")
        generated_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"generated_image_base64": generated_b64}

    except Exception as e:
        # Surface the failure to the caller instead of letting the exception
        # escape the endpoint as an opaque server error.
        return {"error": str(e), "traceback": traceback.format_exc()}
 
 
 
 
 
 
 
1
  import base64
2
  import io
3
+ import traceback
4
  from PIL import Image
5
  import torch
6
  from diffusers import StableDiffusionImg2ImgPipeline
 
20
  global pipe
21
  if pipe is None:
22
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
23
+ "karthikAI/InstantID-i2i",
24
  revision="main",
25
  torch_dtype=torch.float16,
26
  safety_checker=None
 
29
 
30
def inference(self, model_inputs: dict) -> dict:
    """
    Run a single img2img inference with detailed error debugging.

    Expects a JSON payload with:
      - "inputs": base64-encoded input image
      - "parameters": optional dict with
          "prompt": str,
          "strength": float,
          "guidance_scale": float,
          "num_inference_steps": int,
    Returns on success:
      - "generated_image_base64": base64-encoded PNG
    On failure:
      - "error": error message
      - "traceback": full Python traceback
    """
    try:
        # Decode the base64 payload into an RGB PIL image.
        encoded = model_inputs.get("inputs")
        if not encoded:
            raise ValueError("No image data provided under 'inputs'.")
        source_image = Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")

        # Pull generation options, falling back to sensible defaults.
        opts = model_inputs.get("parameters", {})
        pipe_kwargs = {
            "prompt": opts.get("prompt", ""),
            "strength": float(opts.get("strength", 0.75)),
            "guidance_scale": float(opts.get("guidance_scale", 7.5)),
            "num_inference_steps": int(opts.get("num_inference_steps", 50)),
        }

        # Run the img2img pipeline and keep the first generated frame.
        generated = pipe(image=source_image, **pipe_kwargs).images[0]

        # Serialize the result as a base64-encoded PNG.
        sink = io.BytesIO()
        generated.save(sink, format="PNG")
        payload = base64.b64encode(sink.getvalue()).decode("utf-8")
        return {"generated_image_base64": payload}

    except Exception as e:
        # Hand the caller full diagnostics rather than an opaque failure.
        return {"error": str(e), "traceback": traceback.format_exc()}