KarthikAI committed
Commit 06a49e4 · verified · 1 Parent(s): 32b6c5f

Update handler.py

Files changed (1)
  1. handler.py +24 -65
handler.py CHANGED
@@ -1,85 +1,44 @@
 import base64
 import io
-import traceback
 from PIL import Image
 import torch
 from diffusers import StableDiffusionImg2ImgPipeline

 # Global pipeline instance
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = None

 class EndpointHandler:
     def __init__(self, model_dir: str):
-        # Determine device based on CUDA availability
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # model_dir is ignored; HF clones your repo here
+        pass

     def init(self):
-        """
-        Load the InstantID-enhanced Stable Diffusion img2img model once when the endpoint starts.
-        """
         global pipe
         if pipe is None:
+            # Load your InstantID img2img model
             pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                 "karthikAI/InstantID-i2i",
                 revision="main",
-                torch_dtype=torch.float16,
-                safety_checker=None
-            ).to(self.device)
-            pipe.enable_attention_slicing()
+                torch_dtype=torch.float16
+            ).to(torch_device)

     def inference(self, model_inputs: dict) -> dict:
-        """
-        Run a single img2img inference with detailed error debugging.
-
-        Expects a JSON payload with:
-        - "inputs": base64-encoded input image
-        - "parameters": {
-              "prompt": str,
-              "strength": float,
-              "guidance_scale": float,
-              "num_inference_steps": int,
-          }
-        Returns on success:
-        - "generated_image_base64": base64-encoded PNG
-        On failure:
-        - "error": error message
-        - "traceback": full Python traceback
-        """
-        try:
-            # 1. Decode incoming image
-            b64_img = model_inputs.get("inputs")
-            if not b64_img:
-                raise ValueError("No image data provided under 'inputs'.")
-            image_bytes = base64.b64decode(b64_img)
-            init_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-
-            # 2. Extract parameters
-            params = model_inputs.get("parameters", {})
-            prompt = params.get("prompt", "")
-            strength = float(params.get("strength", 0.75))
-            guidance_scale = float(params.get("guidance_scale", 7.5))
-            num_steps = int(params.get("num_inference_steps", 50))
-
-            # 3. Run the img2img pipeline
-            result = pipe(
-                prompt=prompt,
-                image=init_img,
-                strength=strength,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_steps,
-            )
-            out_img = result.images[0]
-
-            # 4. Encode and return image
-            buffer = io.BytesIO()
-            out_img.save(buffer, format="PNG")
-            generated_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-            return {"generated_image_base64": generated_b64}
-
-        except Exception as e:
-            # Return detailed error info for debugging
-            tb = traceback.format_exc()
-            return {
-                "error": str(e),
-                "traceback": tb
-            }
+        # 1) decode base64 image
+        b64 = model_inputs.get("inputs")
+        if b64 is None:
+            return {"error": "No 'inputs' key with base64 image provided."}
+        img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
+
+        # 2) extract prompt
+        prompt = model_inputs.get("parameters", {}).get("prompt", "")
+
+        # 3) minimal call: prompt + image only
+        out = pipe(prompt=prompt, image=img)
+        result_img = out.images[0]
+
+        # 4) encode output
+        buf = io.BytesIO()
+        result_img.save(buf, format="PNG")
+        b64_out = base64.b64encode(buf.getvalue()).decode()
+        return {"generated_image_base64": b64_out}