English
Oysiyl commited on
Commit
61c08a9
·
verified ·
1 Parent(s): 06c9d2f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -12
handler.py CHANGED
@@ -22,10 +22,7 @@ class EndpointHandler():
22
  safety_checker=None)
23
  self.pipe.load_lora_weights("Oysiyl/sdxl-lora-android-google-toy", weights="pytorch_lora_weights.safetensors")
24
  self.pipe.enable_xformers_memory_efficient_attention()
25
- self.pipe = self.pipe.to(device)
26
- self.seed = 42
27
- # Define Generator with seed
28
- self.generator = torch.Generator(device="cpu").manual_seed(self.seed)
29
 
30
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
31
  """
@@ -38,13 +35,8 @@ class EndpointHandler():
38
  # Check if prompt is not provided
39
  if prompt is None:
40
  return {"error": "Please provide a prompt."}
41
-
42
- # Check if seed changed
43
- if seed is not None and seed != self.seed:
44
- print(f"changing seed from {self.seed} to {seed}")
45
- self.seed = seed
46
- self.generator = torch.Generator(device="cpu").manual_seed(self.seed)
47
-
48
 
49
  # hyperparamters
50
  num_inference_steps = data.pop("num_inference_steps", 50)
@@ -59,7 +51,8 @@ class EndpointHandler():
59
  guidance_scale=guidance_scale,
60
  temperature=temperature,
61
  num_images_per_prompt=1,
62
- generator=self.generator
 
63
  )
64
 
65
 
 
22
  safety_checker=None)
23
  self.pipe.load_lora_weights("Oysiyl/sdxl-lora-android-google-toy", weights="pytorch_lora_weights.safetensors")
24
  self.pipe.enable_xformers_memory_efficient_attention()
25
+ self.pipe.to(device)
 
 
 
26
 
27
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
28
  """
 
35
  # Check if prompt is not provided
36
  if prompt is None:
37
  return {"error": "Please provide a prompt."}
38
+
39
+ generator = torch.Generator(device="cpu").manual_seed(self.seed)
 
 
 
 
 
40
 
41
  # hyperparamters
42
  num_inference_steps = data.pop("num_inference_steps", 50)
 
51
  guidance_scale=guidance_scale,
52
  temperature=temperature,
53
  num_images_per_prompt=1,
54
+ seed=seed,
55
+ generator=generator
56
  )
57
 
58