import base64
import io

from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline

# Select the device: use the GPU when available, otherwise fall back to CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# The pipeline is loaded lazily in EndpointHandler.init() and cached here.
pipe = None


class EndpointHandler:
    def __init__(self, model_dir: str):
        # Nothing to do here: the weights are pulled from the Hub in init(),
        # so the local model directory is not used.
        pass

    def init(self):
        # Load the img2img pipeline once and cache it in the module-level global.
        # Note: float16 is intended for GPU inference; on CPU you may need float32.
        global pipe
        if pipe is None:
            pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                "karthikAI/InstantID-i2i",
                revision="main",
                torch_dtype=torch.float16,
            ).to(torch_device)

    def inference(self, model_inputs: dict) -> dict:
        # Guard against a missing init() call from the hosting framework.
        if pipe is None:
            self.init()

        # Decode the base64-encoded input image.
        b64 = model_inputs.get("inputs")
        if b64 is None:
            return {"error": "No 'inputs' key with base64 image provided."}
        img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")

        # Optional text prompt guiding the image-to-image generation.
        prompt = model_inputs.get("parameters", {}).get("prompt", "")

        # Run the pipeline and take the first generated image.
        out = pipe(prompt=prompt, image=img)
        result_img = out.images[0]

        # Re-encode the result as a base64 PNG for the JSON response.
        buf = io.BytesIO()
        result_img.save(buf, format="PNG")
        b64_out = base64.b64encode(buf.getvalue()).decode()
        return {"generated_image_base64": b64_out}