"""Custom Hugging Face Inference Endpoints handler for SDXL inpainting."""

import base64
import io
import json
import logging

import torch
from PIL import Image
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLInpaintPipeline

logger = logging.getLogger(__name__)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The pipeline is loaded in float16 below, so refuse to start without a GPU.
if device.type != 'cuda':
    raise ValueError("Need to run on GPU")


class EndpointHandler:
    """Wrap an SDXL inpainting pipeline behind a dict-in / JSON-string-out call."""

    def __init__(self, path="mrcuddle/URPM-Inpaint-SDXL"):
        """Load the SDXL Inpainting model.

        Args:
            path: Model identifier or local path handed to
                ``StableDiffusionXLInpaintPipeline.from_pretrained``.
        """
        self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
            path, torch_dtype=torch.float16
        )
        # Swap the default scheduler for DPM-Solver++ while reusing its config.
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        self.pipeline = self.pipeline.to(device)

    def __call__(self, data: dict) -> str:
        """Custom call function for Hugging Face Inference Endpoints.

        Args:
            data: Request payload. Recognised keys (all are popped from
                ``data``, so the caller's dict is mutated):

                * ``inputs``: the prompt (``str`` or ``list``). Required.
                * ``image``: base64-encoded source image. Required.
                * ``mask_image``: base64-encoded mask. Required.
                * ``num_inference_steps``: defaults to ``25``.
                * ``guidance_scale``: defaults to ``7.5``.
                * ``negative_prompt``, ``height``, ``width``: default to
                  ``None`` and are forwarded to the pipeline unchanged.

        Returns:
            A JSON string: ``{"output": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` if anything raised. This method never
            propagates an exception to the caller.
        """
        try:
            # When "inputs" is absent the fallback is the payload dict itself;
            # that case is rejected explicitly further down.
            inputs = data.pop("inputs", data)
            encoded_image = data.pop("image", None)
            encoded_mask_image = data.pop("mask_image", None)

            num_inference_steps = data.pop("num_inference_steps", 25)
            guidance_scale = data.pop("guidance_scale", 7.5)
            negative_prompt = data.pop("negative_prompt", None)
            height = data.pop("height", None)
            width = data.pop("width", None)

            # Process images
            if not (encoded_image and encoded_mask_image):
                raise ValueError("Both image and mask_image are required")
            image = self.decode_base64_image(encoded_image)
            mask_image = self.decode_base64_image(encoded_mask_image)

            # Fail with a clear message instead of handing a dict (or other
            # non-text value) to the pipeline as the prompt.
            if not isinstance(inputs, (str, list)):
                raise ValueError(
                    "A text prompt is required under the 'inputs' key"
                )

            # Run inference
            output_image = self.pipeline(
                prompt=inputs,
                image=image,
                mask_image=mask_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
            ).images[0]

            return json.dumps({"output": self.encode_base64_image(output_image)})
        except Exception as e:
            # Top-level request boundary: report the failure to the client as
            # JSON, but keep the traceback in the server logs.
            logger.exception("Inpainting request failed")
            return json.dumps({"error": str(e)})

    def decode_base64_image(self, image_string):
        """Decode base64 encoded image.

        Args:
            image_string: Base64 text (or bytes) of an encoded image file.

        Returns:
            The decoded image as a PIL ``Image`` converted to RGB mode.
        """
        base64_image = base64.b64decode(image_string)
        buffer = io.BytesIO(base64_image)
        return Image.open(buffer).convert("RGB")

    def encode_base64_image(self, image):
        """Encode PIL image to base64.

        Args:
            image: PIL image to serialise.

        Returns:
            The image saved as PNG and base64-encoded, as a UTF-8 ``str``.
        """
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")


# Create an instance of EndpointHandler
handler = EndpointHandler()


def handle(data: dict):
    """Forward a request payload to the module-level ``handler`` instance."""
    return handler(data)