English
Oysiyl commited on
Commit
d24be00
·
verified ·
1 Parent(s): dfd9829

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +67 -0
handler.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from diffusers import AutoPipelineForText2Image
3
+ import torch
4
+
5
+
6
+ import numpy as np
7
+
8
+ # set device
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+ if device.type != 'cuda':
11
+ raise ValueError("need to run on GPU")
12
+ # set mixed precision dtype
13
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
14
+
15
+
16
+ class EndpointHandler():
17
+ def __init__(self, path=""):
18
+ # Load StableDiffusionPipeline
19
+ self.stable_diffusion_id = "stabilityai/stable-diffusion-xl-base-1.0"
20
+ self.pipe = AutoPipelineForText2Image.from_pretrained(self.stable_diffusion_id,
21
+ torch_dtype=dtype,
22
+ safety_checker=None)
23
+ self.pipe.load_lora_weights("Oysiyl/sdxl-lora-android-google-toy", weights="pytorch_lora_weights.safetensors")
24
+ self.pipe.enable_xformers_memory_efficient_attention()
25
+ self.pipe = self.pipe.to(device)
26
+ self.seed = 42
27
+ # Define Generator with seed
28
+ self.generator = torch.Generator(device="cpu").manual_seed(self.seed)
29
+
30
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
31
+ """
32
+ :param data: A dictionary contains `inputs`.
33
+ :return: A dictionary with `image` field contains image in base64.
34
+ """
35
+ prompt = data.pop("inputs", None)
36
+ seed = data.pop("seed", 42)
37
+
38
+ # Check if prompt is not provided
39
+ if prompt is None:
40
+ return {"error": "Please provide a prompt."}
41
+
42
+ # Check if seed changed
43
+ if seed is not None and seed != self.seed:
44
+ print(f"changing seed from {self.seed} to {seed}")
45
+ self.seed = seed
46
+ self.generator = torch.Generator(device="cpu").manual_seed(self.seed)
47
+
48
+
49
+ # hyperparamters
50
+ num_inference_steps = data.pop("num_inference_steps", 50)
51
+ guidance_scale = data.pop("guidance_scale", 7.5)
52
+ temperature = data.pop("temperature", 1.0)
53
+
54
+
55
+ # run inference pipeline
56
+ out = self.pipe(
57
+ prompt=prompt,
58
+ num_inference_steps=num_inference_steps,
59
+ guidance_scale=guidance_scale,
60
+ temperature=temperature,
61
+ num_images_per_prompt=1,
62
+ generator=self.generator
63
+ )
64
+
65
+
66
+ # return first generate PIL image
67
+ return out.images[0]