dmingod commited on
Commit
2bed270
·
1 Parent(s): 823f4b4
Files changed (1) hide show
  1. handler.py +65 -0
handler.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ from diffusers import AutoPipelineForText2Image
4
+ import torch
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
+ self.pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
13
+ self.pipe.to("cuda")
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ """
17
+ data args:
18
+ inputs (:obj: `str`)
19
+ date (:obj: `str`)
20
+ Return:
21
+ A :obj:`list` | `dict`: will be serialized and returned
22
+ """
23
+ # get inputs
24
+ inputs = data.pop("inputs", data)
25
+ encoded_image = data.pop("image", None)
26
+ encoded_mask_image = data.pop("mask_image", None)
27
+
28
+ # hyperparamters
29
+ num_inference_steps = data.pop("num_inference_steps", 25)
30
+ guidance_scale = data.pop("guidance_scale", 7.5)
31
+ negative_prompt = data.pop("negative_prompt", None)
32
+ height = data.pop("height", None)
33
+ width = data.pop("width", None)
34
+
35
+ # process image
36
+ if encoded_image is not None and encoded_mask_image is not None:
37
+ image = self.decode_base64_image(encoded_image)
38
+ mask_image = self.decode_base64_image(encoded_mask_image)
39
+ else:
40
+ image = None
41
+ mask_image = None
42
+
43
+ # run inference pipeline
44
+ out = self.pipe(inputs,
45
+ image=image,
46
+ mask_image=mask_image,
47
+ num_inference_steps=num_inference_steps,
48
+ guidance_scale=guidance_scale,
49
+ num_images_per_prompt=1,
50
+ negative_prompt=negative_prompt,
51
+ height=height,
52
+ width=width
53
+ )
54
+
55
+ # return first generate PIL image
56
+ return out.images[0]
57
+
58
+ # helper to decode input image
59
+ def decode_base64_image(self, image_string):
60
+ base64_image = base64.b64decode(image_string)
61
+ buffer = BytesIO(base64_image)
62
+ image = Image.open(buffer)
63
+ return image
64
+
65
+