chansung commited on
Commit
44565ca
·
1 Parent(s): 390eb53

update custom handler

Browse files
Files changed (2) hide show
  1. __pycache__/handler.cpython-38.pyc +0 -0
  2. handler.py +18 -4
__pycache__/handler.cpython-38.pyc CHANGED
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
 
handler.py CHANGED
@@ -1,16 +1,30 @@
1
  from typing import Dict, List, Any
 
2
  import base64
 
3
  import keras_cv
4
 
5
  class EndpointHandler():
6
- def __init__(self, path=""):
7
- self.sd = keras_cv.models.StableDiffusionV2(img_width=512, img_height=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def __call__(self, data: Dict[str, Any]) -> str:
10
- # get inputs
11
  prompt = data.pop("inputs", data)
12
  batch_size = data.pop("batch_size", 1)
13
 
14
- # run normal prediction
15
  images = self.sd.text_to_image(prompt, batch_size=batch_size)
16
  return base64.b64encode(images.tobytes()).decode()
 
1
  from typing import Dict, List, Any
2
+ import sys
3
  import base64
4
+ import logging
5
  import keras_cv
6
 
7
  class EndpointHandler():
8
+ def __init__(self, path="", version="2"):
9
+ self.sd = self._instantiate_stable_diffusion(version)
10
+
11
+ if isinstance(self.sd, str):
12
+ sys.exit(self.sd)
13
+ else:
14
+ self.sd.text_to_image("test prompt", batch_size=1)
15
+ logging.warning(f"Stable Diffusion v{version} is fully loaded")
16
+
17
+ def _instantiate_stable_diffusion(self, version: str):
18
+ if version is "1.4":
19
+ return keras_cv.models.StableDiffusion(img_width=512, img_height=512)
20
+ elif version is "2":
21
+ return keras_cv.models.StableDiffusionV2(img_width=512, img_height=512)
22
+ else:
23
+ return f"v{version} is not supported"
24
 
25
  def __call__(self, data: Dict[str, Any]) -> str:
 
26
  prompt = data.pop("inputs", data)
27
  batch_size = data.pop("batch_size", 1)
28
 
 
29
  images = self.sd.text_to_image(prompt, batch_size=batch_size)
30
  return base64.b64encode(images.tobytes()).decode()