skytnt committed
Commit caf6587 · 1 Parent(s): 541aa3f
Files changed (1)
  1. pipeline.py +346 -668
pipeline.py CHANGED
@@ -1,30 +1,57 @@
1
  import inspect
2
  import re
3
- from typing import Any, Callable, Dict, List, Optional, Union
4
 
5
  import numpy as np
6
- import PIL
7
  import torch
8
- from packaging import version
9
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
-
11
- from diffusers import DiffusionPipeline
12
- from diffusers.configuration_utils import FrozenDict
13
- from diffusers.image_processor import VaeImageProcessor
14
- from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
15
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
16
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
17
- from diffusers.schedulers import KarrasDiffusionSchedulers
18
- from diffusers.utils import (
19
- PIL_INTERPOLATION,
20
- deprecate,
21
- is_accelerate_available,
22
- is_accelerate_version,
23
- logging,
24
- randn_tensor,
25
- )
26
-
27
 
28
  # ------------------------------------------------------------------------------
29
 
30
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -135,7 +162,7 @@ def parse_prompt_attention(text):
135
  return res
136
 
137
 
138
- def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
139
  r"""
140
  Tokenize a list of prompts and return its tokens with weights of each token.
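For orientation, a small usage sketch of the weighting syntax these helpers implement (illustrative only: it assumes this file is importable as `pipeline`, and the output shown in the comments is the typical long-prompt-weighting behaviour rather than something asserted by this commit):

```python
from pipeline import parse_prompt_attention  # this module; import path assumed

# "(text:1.3)" multiplies the weight of a span by 1.3, "[text]" divides it by ~1.1.
pairs = parse_prompt_attention("masterpiece, (best quality:1.3), [blurry]")
print(pairs)
# e.g. [['masterpiece, ', 1.0], ['best quality', 1.3], [', ', 1.0], ['blurry', 0.909...]]

# get_prompts_with_weights then tokenizes each span and repeats its weight once per
# sub-token, so every token id ends up paired with one scalar weight.
```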
141
 
@@ -150,8 +177,8 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
150
  text_weight = []
151
  for word, weight in texts_and_weights:
152
  # tokenize and discard the starting and the ending token
153
- token = pipe.tokenizer(word).input_ids[1:-1]
154
- text_token += token
155
  # copy the weight by length of token
156
  text_weight += [weight] * len(token)
157
  # stop if the text is too long (longer than truncation limit)
@@ -170,14 +197,14 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
170
  return tokens, weights
171
 
172
 
173
- def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
174
  r"""
175
  Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
176
  """
177
  max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
178
  weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
179
  for i in range(len(tokens)):
180
- tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
181
  if no_boseos_middle:
182
  weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
183
  else:
@@ -196,8 +223,8 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
196
 
197
 
198
  def get_unweighted_text_embeddings(
199
- pipe: DiffusionPipeline,
200
- text_input: torch.Tensor,
201
  chunk_length: int,
202
  no_boseos_middle: Optional[bool] = True,
203
  ):
@@ -210,12 +237,13 @@ def get_unweighted_text_embeddings(
210
  text_embeddings = []
211
  for i in range(max_embeddings_multiples):
212
  # extract the i-th chunk
213
- text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
214
 
215
  # cover the head and the tail by the starting and the ending tokens
216
  text_input_chunk[:, 0] = text_input[0, 0]
217
  text_input_chunk[:, -1] = text_input[0, -1]
218
- text_embedding = pipe.text_encoder(text_input_chunk)[0]
 
219
 
220
  if no_boseos_middle:
221
  if i == 0:
@@ -229,20 +257,21 @@ def get_unweighted_text_embeddings(
229
  text_embedding = text_embedding[:, 1:-1]
230
 
231
  text_embeddings.append(text_embedding)
232
- text_embeddings = torch.concat(text_embeddings, axis=1)
233
  else:
234
- text_embeddings = pipe.text_encoder(text_input)[0]
235
  return text_embeddings
236
 
237
 
238
  def get_weighted_text_embeddings(
239
- pipe: DiffusionPipeline,
240
  prompt: Union[str, List[str]],
241
  uncond_prompt: Optional[Union[str, List[str]]] = None,
242
- max_embeddings_multiples: Optional[int] = 3,
243
  no_boseos_middle: Optional[bool] = False,
244
  skip_parsing: Optional[bool] = False,
245
  skip_weighting: Optional[bool] = False,
 
246
  ):
247
  r"""
248
  Prompts can be assigned with local weights using brackets. For example,
@@ -252,14 +281,14 @@ def get_weighted_text_embeddings(
252
  Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
253
 
254
  Args:
255
- pipe (`DiffusionPipeline`):
256
  Pipe to provide access to the tokenizer and the text encoder.
257
  prompt (`str` or `List[str]`):
258
  The prompt or prompts to guide the image generation.
259
  uncond_prompt (`str` or `List[str]`):
260
  The unconditional prompt or prompts for guide the image generation. If unconditional prompt
261
  is provided, the embeddings of prompt and uncond_prompt are concatenated.
262
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
263
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
264
  no_boseos_middle (`bool`, *optional*, defaults to `False`):
265
  If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
@@ -281,7 +310,8 @@ def get_weighted_text_embeddings(
281
  uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
282
  else:
283
  prompt_tokens = [
284
- token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
 
285
  ]
286
  prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
287
  if uncond_prompt is not None:
@@ -289,7 +319,12 @@ def get_weighted_text_embeddings(
289
  uncond_prompt = [uncond_prompt]
290
  uncond_tokens = [
291
  token[1:-1]
292
- for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
293
  ]
294
  uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
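As a quick sanity check on the length arithmetic used throughout these helpers, a sketch that simply mirrors the formulas above (assuming the CLIP default `model_max_length = 77`):

```python
model_max_length = 77          # CLIP tokenizer default (assumed)
max_embeddings_multiples = 3   # default used by this pipeline

# max_length as computed in get_weighted_text_embeddings: each chunk re-uses
# BOS/EOS, so only (77 - 2) = 75 payload tokens fit per chunk.
max_length = (model_max_length - 2) * max_embeddings_multiples + 2
print(max_length)                                   # 227

# and the inverse, as computed in pad_tokens_and_weights:
print((max_length - 2) // (model_max_length - 2))   # 3
```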
295
 
@@ -308,18 +343,16 @@ def get_weighted_text_embeddings(
308
  # pad the length of tokens and weights
309
  bos = pipe.tokenizer.bos_token_id
310
  eos = pipe.tokenizer.eos_token_id
311
- pad = getattr(pipe.tokenizer, "pad_token_id", eos)
312
  prompt_tokens, prompt_weights = pad_tokens_and_weights(
313
  prompt_tokens,
314
  prompt_weights,
315
  max_length,
316
  bos,
317
  eos,
318
- pad,
319
  no_boseos_middle=no_boseos_middle,
320
  chunk_length=pipe.tokenizer.model_max_length,
321
  )
322
- prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
323
  if uncond_prompt is not None:
324
  uncond_tokens, uncond_weights = pad_tokens_and_weights(
325
  uncond_tokens,
@@ -327,11 +360,10 @@ def get_weighted_text_embeddings(
327
  max_length,
328
  bos,
329
  eos,
330
- pad,
331
  no_boseos_middle=no_boseos_middle,
332
  chunk_length=pipe.tokenizer.model_max_length,
333
  )
334
- uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
335
 
336
  # get the embeddings
337
  text_embeddings = get_unweighted_text_embeddings(
@@ -340,7 +372,7 @@ def get_weighted_text_embeddings(
340
  pipe.tokenizer.model_max_length,
341
  no_boseos_middle=no_boseos_middle,
342
  )
343
- prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
344
  if uncond_prompt is not None:
345
  uncond_embeddings = get_unweighted_text_embeddings(
346
  pipe,
@@ -348,308 +380,120 @@ def get_weighted_text_embeddings(
348
  pipe.tokenizer.model_max_length,
349
  no_boseos_middle=no_boseos_middle,
350
  )
351
- uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
352
 
353
  # assign weights to the prompts and normalize in the sense of mean
354
  # TODO: should we normalize by chunk or in a whole (current implementation)?
355
  if (not skip_parsing) and (not skip_weighting):
356
- previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
357
- text_embeddings *= prompt_weights.unsqueeze(-1)
358
- current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
359
- text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
360
  if uncond_prompt is not None:
361
- previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
362
- uncond_embeddings *= uncond_weights.unsqueeze(-1)
363
- current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
364
- uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
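The mean-preserving rescale applied above can be checked in isolation; this is a NumPy sketch of the same arithmetic, not the pipeline code itself:

```python
import numpy as np

emb = np.random.randn(1, 6, 768).astype(np.float32)           # (batch, tokens, dim)
w = np.array([[1.0, 1.3, 1.3, 1.0, 0.9, 1.0]], dtype=np.float32)

previous_mean = emb.mean(axis=(-2, -1))
emb = emb * w[..., None]                                       # weight each token
current_mean = emb.mean(axis=(-2, -1))
emb = emb * (previous_mean / current_mean)[..., None, None]    # restore the mean

assert np.allclose(emb.mean(axis=(-2, -1)), previous_mean, atol=1e-4)
```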
365
 
366
  if uncond_prompt is not None:
367
  return text_embeddings, uncond_embeddings
368
- return text_embeddings, None
 
369
 
370
 
371
- def preprocess_image(image, batch_size):
372
  w, h = image.size
373
- w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
374
  image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
375
  image = np.array(image).astype(np.float32) / 255.0
376
- image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
377
- image = torch.from_numpy(image)
378
  return 2.0 * image - 1.0
379
 
380
 
381
- def preprocess_mask(mask, batch_size, scale_factor=8):
382
- if not isinstance(mask, torch.FloatTensor):
383
- mask = mask.convert("L")
384
- w, h = mask.size
385
- w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
386
- mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
387
- mask = np.array(mask).astype(np.float32) / 255.0
388
- mask = np.tile(mask, (4, 1, 1))
389
- mask = np.vstack([mask[None]] * batch_size)
390
- mask = 1 - mask # repaint white, keep black
391
- mask = torch.from_numpy(mask)
392
- return mask
393
-
394
- else:
395
- valid_mask_channel_sizes = [1, 3]
396
- # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
397
- if mask.shape[3] in valid_mask_channel_sizes:
398
- mask = mask.permute(0, 3, 1, 2)
399
- elif mask.shape[1] not in valid_mask_channel_sizes:
400
- raise ValueError(
401
- f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
402
- f" but received mask of shape {tuple(mask.shape)}"
403
- )
404
- # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
405
- mask = mask.mean(dim=1, keepdim=True)
406
- h, w = mask.shape[-2:]
407
- h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
408
- mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
409
- return mask
410
 
411
 
412
- class StableDiffusionLongPromptWeightingPipeline(
413
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
414
- ):
415
  r"""
416
  Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
417
  weighting in prompt.
418
 
419
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
420
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
421
-
422
- Args:
423
- vae ([`AutoencoderKL`]):
424
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
425
- text_encoder ([`CLIPTextModel`]):
426
- Frozen text-encoder. Stable Diffusion uses the text portion of
427
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
428
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
429
- tokenizer (`CLIPTokenizer`):
430
- Tokenizer of class
431
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
432
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
433
- scheduler ([`SchedulerMixin`]):
434
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
435
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
436
- safety_checker ([`StableDiffusionSafetyChecker`]):
437
- Classification module that estimates whether generated images could be considered offensive or harmful.
438
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
439
- feature_extractor ([`CLIPImageProcessor`]):
440
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
441
  """
442
-
443
- _optional_components = ["safety_checker", "feature_extractor"]
444
-
445
- def __init__(
446
- self,
447
- vae: AutoencoderKL,
448
- text_encoder: CLIPTextModel,
449
- tokenizer: CLIPTokenizer,
450
- unet: UNet2DConditionModel,
451
- scheduler: KarrasDiffusionSchedulers,
452
- safety_checker: StableDiffusionSafetyChecker,
453
- feature_extractor: CLIPImageProcessor,
454
- requires_safety_checker: bool = True,
455
- ):
456
- super().__init__()
457
-
458
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
459
- deprecation_message = (
460
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
461
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
462
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
463
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
464
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
465
- " file"
466
- )
467
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
468
- new_config = dict(scheduler.config)
469
- new_config["steps_offset"] = 1
470
- scheduler._internal_dict = FrozenDict(new_config)
471
-
472
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
473
- deprecation_message = (
474
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
475
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
476
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
477
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
478
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
479
- )
480
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
481
- new_config = dict(scheduler.config)
482
- new_config["clip_sample"] = False
483
- scheduler._internal_dict = FrozenDict(new_config)
484
-
485
- if safety_checker is None and requires_safety_checker:
486
- logger.warning(
487
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
488
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
489
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
490
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
491
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
492
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
493
  )
 
494
 
495
- if safety_checker is not None and feature_extractor is None:
496
- raise ValueError(
497
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
498
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
499
- )
500
 
501
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
502
- version.parse(unet.config._diffusers_version).base_version
503
- ) < version.parse("0.9.0.dev0")
504
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
505
- if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
506
- deprecation_message = (
507
- "The configuration file of the unet has set the default `sample_size` to smaller than"
508
- " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
509
- " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
510
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
511
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
512
- " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
513
- " in the config might lead to incorrect results in future versions. If you have downloaded this"
514
- " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
515
- " the `unet/config.json` file"
516
  )
517
- deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
518
- new_config = dict(unet.config)
519
- new_config["sample_size"] = 64
520
- unet._internal_dict = FrozenDict(new_config)
521
- self.register_modules(
522
- vae=vae,
523
- text_encoder=text_encoder,
524
- tokenizer=tokenizer,
525
- unet=unet,
526
- scheduler=scheduler,
527
- safety_checker=safety_checker,
528
- feature_extractor=feature_extractor,
529
- )
530
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
531
 
532
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
533
- self.register_to_config(
534
- requires_safety_checker=requires_safety_checker,
535
- )
536
-
537
- def enable_vae_slicing(self):
538
- r"""
539
- Enable sliced VAE decoding.
540
-
541
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
542
- steps. This is useful to save some memory and allow larger batch sizes.
543
- """
544
- self.vae.enable_slicing()
545
-
546
- def disable_vae_slicing(self):
547
- r"""
548
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
549
- computing decoding in one step.
550
- """
551
- self.vae.disable_slicing()
552
-
553
- def enable_vae_tiling(self):
554
- r"""
555
- Enable tiled VAE decoding.
556
-
557
- When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
558
- several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
559
- """
560
- self.vae.enable_tiling()
561
-
562
- def disable_vae_tiling(self):
563
- r"""
564
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
565
- computing decoding in one step.
566
- """
567
- self.vae.disable_tiling()
568
-
569
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
570
- def enable_sequential_cpu_offload(self, gpu_id=0):
571
- r"""
572
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
573
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
574
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
575
- Note that offloading happens on a submodule basis. Memory savings are higher than with
576
- `enable_model_cpu_offload`, but performance is lower.
577
- """
578
- if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
579
- from accelerate import cpu_offload
580
- else:
581
- raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
582
-
583
- device = torch.device(f"cuda:{gpu_id}")
584
-
585
- if self.device.type != "cpu":
586
- self.to("cpu", silence_dtype_warnings=True)
587
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
588
-
589
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
590
- cpu_offload(cpu_offloaded_model, device)
591
-
592
- if self.safety_checker is not None:
593
- cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
594
-
595
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
596
- def enable_model_cpu_offload(self, gpu_id=0):
597
- r"""
598
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
599
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
600
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
601
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
602
- """
603
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
604
- from accelerate import cpu_offload_with_hook
605
- else:
606
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
607
-
608
- device = torch.device(f"cuda:{gpu_id}")
609
-
610
- if self.device.type != "cpu":
611
- self.to("cpu", silence_dtype_warnings=True)
612
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
613
-
614
- hook = None
615
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
616
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
617
-
618
- if self.safety_checker is not None:
619
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
620
-
621
- # We'll offload the last model manually.
622
- self.final_offload_hook = hook
623
-
624
- @property
625
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
626
- def _execution_device(self):
627
- r"""
628
- Returns the device on which the pipeline's models will be executed. After calling
629
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
630
- hooks.
631
- """
632
- if not hasattr(self.unet, "_hf_hook"):
633
- return self.device
634
- for module in self.unet.modules():
635
- if (
636
- hasattr(module, "_hf_hook")
637
- and hasattr(module._hf_hook, "execution_device")
638
- and module._hf_hook.execution_device is not None
639
- ):
640
- return torch.device(module._hf_hook.execution_device)
641
- return self.device
642
 
643
  def _encode_prompt(
644
  self,
645
  prompt,
646
- device,
647
  num_images_per_prompt,
648
  do_classifier_free_guidance,
649
- negative_prompt=None,
650
- max_embeddings_multiples=3,
651
- prompt_embeds: Optional[torch.FloatTensor] = None,
652
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
653
  ):
654
  r"""
655
  Encodes the prompt into text encoder hidden states.
@@ -657,8 +501,6 @@ class StableDiffusionLongPromptWeightingPipeline(
657
  Args:
658
  prompt (`str` or `list(int)`):
659
  prompt to be encoded
660
- device: (`torch.device`):
661
- torch device
662
  num_images_per_prompt (`int`):
663
  number of images that should be generated per prompt
664
  do_classifier_free_guidance (`bool`):
@@ -669,71 +511,43 @@ class StableDiffusionLongPromptWeightingPipeline(
669
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
670
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
671
  """
672
- if prompt is not None and isinstance(prompt, str):
673
- batch_size = 1
674
- elif prompt is not None and isinstance(prompt, list):
675
- batch_size = len(prompt)
676
- else:
677
- batch_size = prompt_embeds.shape[0]
678
-
679
- if negative_prompt_embeds is None:
680
- if negative_prompt is None:
681
- negative_prompt = [""] * batch_size
682
- elif isinstance(negative_prompt, str):
683
- negative_prompt = [negative_prompt] * batch_size
684
- if batch_size != len(negative_prompt):
685
- raise ValueError(
686
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
687
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
688
- " the batch size of `prompt`."
689
- )
690
- if prompt_embeds is None or negative_prompt_embeds is None:
691
- if isinstance(self, TextualInversionLoaderMixin):
692
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
693
- if do_classifier_free_guidance and negative_prompt_embeds is None:
694
- negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
695
-
696
- prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
697
- pipe=self,
698
- prompt=prompt,
699
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
700
- max_embeddings_multiples=max_embeddings_multiples,
701
  )
702
- if prompt_embeds is None:
703
- prompt_embeds = prompt_embeds1
704
- if negative_prompt_embeds is None:
705
- negative_prompt_embeds = negative_prompt_embeds1
706
 
707
- bs_embed, seq_len, _ = prompt_embeds.shape
708
- # duplicate text embeddings for each generation per prompt, using mps friendly method
709
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
710
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
 
711
 
 
712
  if do_classifier_free_guidance:
713
- bs_embed, seq_len, _ = negative_prompt_embeds.shape
714
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
715
- negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
716
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
717
 
718
- return prompt_embeds
719
 
720
- def check_inputs(
721
- self,
722
- prompt,
723
- height,
724
- width,
725
- strength,
726
- callback_steps,
727
- negative_prompt=None,
728
- prompt_embeds=None,
729
- negative_prompt_embeds=None,
730
- ):
731
- if height % 8 != 0 or width % 8 != 0:
732
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
733
 
734
  if strength < 0 or strength > 1:
735
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
736
 
 
 
 
737
  if (callback_steps is None) or (
738
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
739
  ):
@@ -742,60 +556,46 @@ class StableDiffusionLongPromptWeightingPipeline(
742
  f" {type(callback_steps)}."
743
  )
744
 
745
- if prompt is not None and prompt_embeds is not None:
746
- raise ValueError(
747
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
748
- " only forward one of the two."
749
- )
750
- elif prompt is None and prompt_embeds is None:
751
- raise ValueError(
752
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
753
- )
754
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
755
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
756
-
757
- if negative_prompt is not None and negative_prompt_embeds is not None:
758
- raise ValueError(
759
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
760
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
761
- )
762
-
763
- if prompt_embeds is not None and negative_prompt_embeds is not None:
764
- if prompt_embeds.shape != negative_prompt_embeds.shape:
765
- raise ValueError(
766
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
767
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
768
- f" {negative_prompt_embeds.shape}."
769
- )
770
-
771
- def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
772
  if is_text2img:
773
- return self.scheduler.timesteps.to(device), num_inference_steps
774
  else:
775
  # get the original timestep using init_timestep
776
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
777
-
778
- t_start = max(num_inference_steps - init_timestep, 0)
779
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
780
 
 
 
781
  return timesteps, num_inference_steps - t_start
782
 
783
- def run_safety_checker(self, image, device, dtype):
784
  if self.safety_checker is not None:
785
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
786
- image, has_nsfw_concept = self.safety_checker(
787
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
788
- )
789
  else:
790
  has_nsfw_concept = None
791
  return image, has_nsfw_concept
792
 
793
  def decode_latents(self, latents):
794
- latents = 1 / self.vae.config.scaling_factor * latents
795
- image = self.vae.decode(latents).sample
796
- image = (image / 2 + 0.5).clamp(0, 1)
797
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
798
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
799
  return image
800
 
801
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -815,51 +615,36 @@ class StableDiffusionLongPromptWeightingPipeline(
815
  extra_step_kwargs["generator"] = generator
816
  return extra_step_kwargs
817
 
818
- def prepare_latents(
819
- self,
820
- image,
821
- timestep,
822
- num_images_per_prompt,
823
- batch_size,
824
- num_channels_latents,
825
- height,
826
- width,
827
- dtype,
828
- device,
829
- generator,
830
- latents=None,
831
- ):
832
  if image is None:
833
- batch_size = batch_size * num_images_per_prompt
834
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
835
- if isinstance(generator, list) and len(generator) != batch_size:
836
- raise ValueError(
837
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
838
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
839
- )
840
 
841
  if latents is None:
842
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
843
  else:
844
- latents = latents.to(device)
 
845
 
846
  # scale the initial noise by the standard deviation required by the scheduler
847
- latents = latents * self.scheduler.init_noise_sigma
848
  return latents, None, None
849
  else:
850
- image = image.to(device=self.device, dtype=dtype)
851
- init_latent_dist = self.vae.encode(image).latent_dist
852
- init_latents = init_latent_dist.sample(generator=generator)
853
- init_latents = self.vae.config.scaling_factor * init_latents
854
-
855
- # Expand init_latents for batch_size and num_images_per_prompt
856
- init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
857
  init_latents_orig = init_latents
 
858
 
859
  # add noise to latents using the timesteps
860
- noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
861
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
862
- latents = init_latents
 
863
  return latents, init_latents_orig, noise
864
 
865
  @torch.no_grad()
@@ -867,27 +652,24 @@ class StableDiffusionLongPromptWeightingPipeline(
867
  self,
868
  prompt: Union[str, List[str]],
869
  negative_prompt: Optional[Union[str, List[str]]] = None,
870
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
871
- mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
872
  height: int = 512,
873
  width: int = 512,
874
  num_inference_steps: int = 50,
875
  guidance_scale: float = 7.5,
876
  strength: float = 0.8,
877
  num_images_per_prompt: Optional[int] = 1,
878
- add_predicted_noise: Optional[bool] = False,
879
  eta: float = 0.0,
880
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
881
- latents: Optional[torch.FloatTensor] = None,
882
- prompt_embeds: Optional[torch.FloatTensor] = None,
883
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
884
  max_embeddings_multiples: Optional[int] = 3,
885
  output_type: Optional[str] = "pil",
886
  return_dict: bool = True,
887
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
888
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
889
- callback_steps: int = 1,
890
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
891
  ):
892
  r"""
893
  Function invoked when calling the pipeline for generation.
@@ -898,10 +680,10 @@ class StableDiffusionLongPromptWeightingPipeline(
898
  negative_prompt (`str` or `List[str]`, *optional*):
899
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
900
  if `guidance_scale` is less than `1`).
901
- image (`torch.FloatTensor` or `PIL.Image.Image`):
902
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
903
  process.
904
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
905
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
906
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
907
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
@@ -927,26 +709,16 @@ class StableDiffusionLongPromptWeightingPipeline(
927
  `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
928
  num_images_per_prompt (`int`, *optional*, defaults to 1):
929
  The number of images to generate per prompt.
930
- add_predicted_noise (`bool`, *optional*, defaults to True):
931
- Use predicted noise instead of random noise when constructing noisy versions of the original image in
932
- the reverse diffusion process
933
  eta (`float`, *optional*, defaults to 0.0):
934
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
935
  [`schedulers.DDIMScheduler`], will be ignored for others.
936
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
937
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
938
- to make generation deterministic.
939
- latents (`torch.FloatTensor`, *optional*):
940
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
941
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
942
  tensor will ge generated by sampling using the supplied random `generator`.
943
- prompt_embeds (`torch.FloatTensor`, *optional*):
944
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
945
- provided, text embeddings will be generated from `prompt` input argument.
946
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
947
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
948
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
949
- argument.
950
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
951
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
952
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -957,17 +729,13 @@ class StableDiffusionLongPromptWeightingPipeline(
957
  plain tuple.
958
  callback (`Callable`, *optional*):
959
  A function that will be called every `callback_steps` steps during inference. The function will be
960
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
961
  is_cancelled_callback (`Callable`, *optional*):
962
  A function that will be called every `callback_steps` steps during inference. If the function returns
963
  `True`, the inference will be cancelled.
964
  callback_steps (`int`, *optional*, defaults to 1):
965
  The frequency at which the `callback` function will be called. If not specified, the callback will be
966
  called at every step.
967
- cross_attention_kwargs (`dict`, *optional*):
968
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
969
- `self.processor` in
970
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
971
 
972
  Returns:
973
  `None` if cancelled by `is_cancelled_callback`,
@@ -977,71 +745,64 @@ class StableDiffusionLongPromptWeightingPipeline(
977
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
978
  (nsfw) content, according to the `safety_checker`.
979
  """
980
  # 0. Default height and width to unet
981
  height = height or self.unet.config.sample_size * self.vae_scale_factor
982
  width = width or self.unet.config.sample_size * self.vae_scale_factor
983
 
984
  # 1. Check inputs. Raise error if not correct
985
- self.check_inputs(
986
- prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
987
- )
988
 
989
  # 2. Define call parameters
990
- if prompt is not None and isinstance(prompt, str):
991
- batch_size = 1
992
- elif prompt is not None and isinstance(prompt, list):
993
- batch_size = len(prompt)
994
- else:
995
- batch_size = prompt_embeds.shape[0]
996
-
997
- device = self._execution_device
998
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
999
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1000
  # corresponds to doing no classifier free guidance.
1001
  do_classifier_free_guidance = guidance_scale > 1.0
1002
 
1003
  # 3. Encode input prompt
1004
- prompt_embeds = self._encode_prompt(
1005
  prompt,
1006
- device,
1007
  num_images_per_prompt,
1008
  do_classifier_free_guidance,
1009
  negative_prompt,
1010
  max_embeddings_multiples,
1011
- prompt_embeds=prompt_embeds,
1012
- negative_prompt_embeds=negative_prompt_embeds,
1013
  )
1014
- dtype = prompt_embeds.dtype
1015
 
1016
  # 4. Preprocess image and mask
1017
  if isinstance(image, PIL.Image.Image):
1018
- image = preprocess_image(image, batch_size)
1019
  if image is not None:
1020
- image = image.to(device=self.device, dtype=dtype)
1021
  if isinstance(mask_image, PIL.Image.Image):
1022
- mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
1023
  if mask_image is not None:
1024
- mask = mask_image.to(device=self.device, dtype=dtype)
1025
- mask = torch.cat([mask] * num_images_per_prompt)
1026
  else:
1027
  mask = None
1028
 
1029
  # 5. set timesteps
1030
- self.scheduler.set_timesteps(num_inference_steps, device=device)
1031
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
1032
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1033
 
1034
  # 6. Prepare latent variables
1035
  latents, init_latents_orig, noise = self.prepare_latents(
1036
  image,
1037
  latent_timestep,
1038
- num_images_per_prompt,
1039
- batch_size,
1040
- self.unet.config.in_channels,
1041
  height,
1042
  width,
1043
  dtype,
1044
- device,
1045
  generator,
1046
  latents,
1047
  )
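For reference, the latents prepared here follow the usual Stable Diffusion shape convention; a sketch of the arithmetic only, assuming the standard 4-channel VAE with a scale factor of 8:

```python
batch_size, num_images_per_prompt = 1, 1
height, width = 512, 512
vae_scale_factor, latent_channels = 8, 4   # standard SD 1.x values (assumed)

shape = (
    batch_size * num_images_per_prompt,
    latent_channels,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
print(shape)  # (1, 4, 64, 64)
```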
@@ -1050,70 +811,56 @@ class StableDiffusionLongPromptWeightingPipeline(
1050
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1051
 
1052
  # 8. Denoising loop
1053
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1054
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1055
- for i, t in enumerate(timesteps):
1056
- # expand the latents if we are doing classifier free guidance
1057
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1058
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1059
-
1060
- # predict the noise residual
1061
- noise_pred = self.unet(
1062
- latent_model_input,
1063
  t,
1064
- encoder_hidden_states=prompt_embeds,
1065
- cross_attention_kwargs=cross_attention_kwargs,
1066
- ).sample
1067
-
1068
- # perform guidance
1069
- if do_classifier_free_guidance:
1070
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1071
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1072
-
1073
- # compute the previous noisy sample x_t -> x_t-1
1074
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1075
-
1076
- if mask is not None:
1077
- # masking
1078
- if add_predicted_noise:
1079
- init_latents_proper = self.scheduler.add_noise(
1080
- init_latents_orig, noise_pred_uncond, torch.tensor([t])
1081
- )
1082
- else:
1083
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1084
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
1085
-
1086
- # call the callback, if provided
1087
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1088
- progress_bar.update()
1089
- if i % callback_steps == 0:
1090
- if callback is not None:
1091
- callback(i, t, latents)
1092
- if is_cancelled_callback is not None and is_cancelled_callback():
1093
- return None
1094
-
1095
- if output_type == "latent":
1096
- image = latents
1097
- has_nsfw_concept = None
1098
- elif output_type == "pil":
1099
- # 9. Post-processing
1100
- image = self.decode_latents(latents)
1101
 
1102
- # 10. Run safety checker
1103
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1104
 
1105
- # 11. Convert to PIL
1106
- image = self.numpy_to_pil(image)
1107
- else:
1108
- # 9. Post-processing
1109
- image = self.decode_latents(latents)
1110
 
1111
- # 10. Run safety checker
1112
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1113
 
1114
- # Offload last model to CPU
1115
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1116
- self.final_offload_hook.offload()
1117
 
1118
  if not return_dict:
1119
  return image, has_nsfw_concept
@@ -1130,17 +877,14 @@ class StableDiffusionLongPromptWeightingPipeline(
1130
  guidance_scale: float = 7.5,
1131
  num_images_per_prompt: Optional[int] = 1,
1132
  eta: float = 0.0,
1133
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1134
- latents: Optional[torch.FloatTensor] = None,
1135
- prompt_embeds: Optional[torch.FloatTensor] = None,
1136
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1137
  max_embeddings_multiples: Optional[int] = 3,
1138
  output_type: Optional[str] = "pil",
1139
  return_dict: bool = True,
1140
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1141
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1142
- callback_steps: int = 1,
1143
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1144
  ):
1145
  r"""
1146
  Function for text-to-image generation.
@@ -1168,20 +912,13 @@ class StableDiffusionLongPromptWeightingPipeline(
1168
  eta (`float`, *optional*, defaults to 0.0):
1169
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1170
  [`schedulers.DDIMScheduler`], will be ignored for others.
1171
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1172
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1173
- to make generation deterministic.
1174
- latents (`torch.FloatTensor`, *optional*):
1175
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1176
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1177
  tensor will ge generated by sampling using the supplied random `generator`.
1178
- prompt_embeds (`torch.FloatTensor`, *optional*):
1179
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1180
- provided, text embeddings will be generated from `prompt` input argument.
1181
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1182
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1183
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1184
- argument.
1185
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1186
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1187
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1192,20 +929,11 @@ class StableDiffusionLongPromptWeightingPipeline(
1192
  plain tuple.
1193
  callback (`Callable`, *optional*):
1194
  A function that will be called every `callback_steps` steps during inference. The function will be
1195
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1196
- is_cancelled_callback (`Callable`, *optional*):
1197
- A function that will be called every `callback_steps` steps during inference. If the function returns
1198
- `True`, the inference will be cancelled.
1199
  callback_steps (`int`, *optional*, defaults to 1):
1200
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1201
  called at every step.
1202
- cross_attention_kwargs (`dict`, *optional*):
1203
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1204
- `self.processor` in
1205
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1206
-
1207
  Returns:
1208
- `None` if cancelled by `is_cancelled_callback`,
1209
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1210
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1211
  When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -1223,20 +951,17 @@ class StableDiffusionLongPromptWeightingPipeline(
1223
  eta=eta,
1224
  generator=generator,
1225
  latents=latents,
1226
- prompt_embeds=prompt_embeds,
1227
- negative_prompt_embeds=negative_prompt_embeds,
1228
  max_embeddings_multiples=max_embeddings_multiples,
1229
  output_type=output_type,
1230
  return_dict=return_dict,
1231
  callback=callback,
1232
- is_cancelled_callback=is_cancelled_callback,
1233
  callback_steps=callback_steps,
1234
- cross_attention_kwargs=cross_attention_kwargs,
1235
  )
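For orientation, a usage sketch of the `text2img` entry point defined above (illustrative only: it loads the upstream `lpw_stable_diffusion` community pipeline, which exposes the same `text2img`/`img2img`/`inpaint` methods; the checkpoint id and generation settings are placeholders, not part of this commit):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # placeholder checkpoint
    custom_pipeline="lpw_stable_diffusion",  # long-prompt-weighting community pipeline
    torch_dtype=torch.float16,
).to("cuda")

image = pipe.text2img(
    "masterpiece, (best quality:1.3), 1girl, white hair, looking at viewer",
    negative_prompt="(low quality, worst quality:1.4)",
    width=512,
    height=512,
    num_inference_steps=30,
    max_embeddings_multiples=3,
).images[0]
image.save("out.png")
```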
1236
 
1237
  def img2img(
1238
  self,
1239
- image: Union[torch.FloatTensor, PIL.Image.Image],
1240
  prompt: Union[str, List[str]],
1241
  negative_prompt: Optional[Union[str, List[str]]] = None,
1242
  strength: float = 0.8,
@@ -1244,22 +969,19 @@ class StableDiffusionLongPromptWeightingPipeline(
1244
  guidance_scale: Optional[float] = 7.5,
1245
  num_images_per_prompt: Optional[int] = 1,
1246
  eta: Optional[float] = 0.0,
1247
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1248
- prompt_embeds: Optional[torch.FloatTensor] = None,
1249
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1250
  max_embeddings_multiples: Optional[int] = 3,
1251
  output_type: Optional[str] = "pil",
1252
  return_dict: bool = True,
1253
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1254
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1255
- callback_steps: int = 1,
1256
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1257
  ):
1258
  r"""
1259
  Function for image-to-image generation.
1260
  Args:
1261
- image (`torch.FloatTensor` or `PIL.Image.Image`):
1262
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
1263
  process.
1264
  prompt (`str` or `List[str]`):
1265
  The prompt or prompts to guide the image generation.
@@ -1286,16 +1008,9 @@ class StableDiffusionLongPromptWeightingPipeline(
1286
  eta (`float`, *optional*, defaults to 0.0):
1287
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1288
  [`schedulers.DDIMScheduler`], will be ignored for others.
1289
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1290
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1291
- to make generation deterministic.
1292
- prompt_embeds (`torch.FloatTensor`, *optional*):
1293
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1294
- provided, text embeddings will be generated from `prompt` input argument.
1295
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1296
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1297
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1298
- argument.
1299
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1300
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1301
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1306,20 +1021,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1306
  plain tuple.
1307
  callback (`Callable`, *optional*):
1308
  A function that will be called every `callback_steps` steps during inference. The function will be
1309
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1310
- is_cancelled_callback (`Callable`, *optional*):
1311
- A function that will be called every `callback_steps` steps during inference. If the function returns
1312
- `True`, the inference will be cancelled.
1313
  callback_steps (`int`, *optional*, defaults to 1):
1314
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1315
  called at every step.
1316
- cross_attention_kwargs (`dict`, *optional*):
1317
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1318
- `self.processor` in
1319
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1320
-
1321
  Returns:
1322
- `None` if cancelled by `is_cancelled_callback`,
1323
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1324
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1325
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
@@ -1335,47 +1042,40 @@ class StableDiffusionLongPromptWeightingPipeline(
1335
  num_images_per_prompt=num_images_per_prompt,
1336
  eta=eta,
1337
  generator=generator,
1338
- prompt_embeds=prompt_embeds,
1339
- negative_prompt_embeds=negative_prompt_embeds,
1340
  max_embeddings_multiples=max_embeddings_multiples,
1341
  output_type=output_type,
1342
  return_dict=return_dict,
1343
  callback=callback,
1344
- is_cancelled_callback=is_cancelled_callback,
1345
  callback_steps=callback_steps,
1346
- cross_attention_kwargs=cross_attention_kwargs,
1347
  )
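Continuing the sketch above, the matching `img2img` call (again illustrative; `strength` controls how far the result may drift from the init image):

```python
from PIL import Image

init_image = Image.open("init.png").convert("RGB").resize((512, 512))

result = pipe.img2img(
    image=init_image,
    prompt="masterpiece, (best quality:1.3), 1girl, white hair",
    negative_prompt="(low quality, worst quality:1.4)",
    strength=0.75,            # 1.0 essentially ignores the init image
    num_inference_steps=30,
).images[0]
result.save("out_img2img.png")
```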
1348
 
1349
  def inpaint(
1350
  self,
1351
- image: Union[torch.FloatTensor, PIL.Image.Image],
1352
- mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1353
  prompt: Union[str, List[str]],
1354
  negative_prompt: Optional[Union[str, List[str]]] = None,
1355
  strength: float = 0.8,
1356
  num_inference_steps: Optional[int] = 50,
1357
  guidance_scale: Optional[float] = 7.5,
1358
  num_images_per_prompt: Optional[int] = 1,
1359
- add_predicted_noise: Optional[bool] = False,
1360
  eta: Optional[float] = 0.0,
1361
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1362
- prompt_embeds: Optional[torch.FloatTensor] = None,
1363
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1364
  max_embeddings_multiples: Optional[int] = 3,
1365
  output_type: Optional[str] = "pil",
1366
  return_dict: bool = True,
1367
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1368
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1369
- callback_steps: int = 1,
1370
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1371
  ):
1372
  r"""
1373
  Function for inpaint.
1374
  Args:
1375
- image (`torch.FloatTensor` or `PIL.Image.Image`):
1376
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
1377
  process. This is the image whose masked region will be inpainted.
1378
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1379
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1380
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1381
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
@@ -1401,22 +1101,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1401
  usually at the expense of lower image quality.
1402
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1403
  The number of images to generate per prompt.
1404
- add_predicted_noise (`bool`, *optional*, defaults to True):
1405
- Use predicted noise instead of random noise when constructing noisy versions of the original image in
1406
- the reverse diffusion process
1407
  eta (`float`, *optional*, defaults to 0.0):
1408
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1409
  [`schedulers.DDIMScheduler`], will be ignored for others.
1410
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1411
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1412
- to make generation deterministic.
1413
- prompt_embeds (`torch.FloatTensor`, *optional*):
1414
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1415
- provided, text embeddings will be generated from `prompt` input argument.
1416
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1417
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1418
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1419
- argument.
1420
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1421
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1422
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1427,20 +1117,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1427
  plain tuple.
1428
  callback (`Callable`, *optional*):
1429
  A function that will be called every `callback_steps` steps during inference. The function will be
1430
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1431
- is_cancelled_callback (`Callable`, *optional*):
1432
- A function that will be called every `callback_steps` steps during inference. If the function returns
1433
- `True`, the inference will be cancelled.
1434
  callback_steps (`int`, *optional*, defaults to 1):
1435
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1436
  called at every step.
1437
- cross_attention_kwargs (`dict`, *optional*):
1438
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1439
- `self.processor` in
1440
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1441
-
1442
  Returns:
1443
- `None` if cancelled by `is_cancelled_callback`,
1444
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1445
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1446
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
@@ -1455,16 +1137,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1455
  guidance_scale=guidance_scale,
1456
  strength=strength,
1457
  num_images_per_prompt=num_images_per_prompt,
1458
- add_predicted_noise=add_predicted_noise,
1459
  eta=eta,
1460
  generator=generator,
1461
- prompt_embeds=prompt_embeds,
1462
- negative_prompt_embeds=negative_prompt_embeds,
1463
  max_embeddings_multiples=max_embeddings_multiples,
1464
  output_type=output_type,
1465
  return_dict=return_dict,
1466
  callback=callback,
1467
- is_cancelled_callback=is_cancelled_callback,
1468
  callback_steps=callback_steps,
1469
- cross_attention_kwargs=cross_attention_kwargs,
1470
- )
 
1
  import inspect
2
  import re
3
+ from typing import Callable, List, Optional, Union
4
 
5
  import numpy as np
 
6
  import torch
7
 
8
+ import diffusers
9
+ import PIL
10
+ from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
11
+ from diffusers.onnx_utils import OnnxRuntimeModel
12
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
+ from diffusers.utils import deprecate, logging
14
+ from packaging import version
15
+ from transformers import CLIPFeatureExtractor, CLIPTokenizer
16
+
17
+
18
+ try:
19
+ from diffusers.onnx_utils import ORT_TO_NP_TYPE
20
+ except ImportError:
21
+ ORT_TO_NP_TYPE = {
22
+ "tensor(bool)": np.bool_,
23
+ "tensor(int8)": np.int8,
24
+ "tensor(uint8)": np.uint8,
25
+ "tensor(int16)": np.int16,
26
+ "tensor(uint16)": np.uint16,
27
+ "tensor(int32)": np.int32,
28
+ "tensor(uint32)": np.uint32,
29
+ "tensor(int64)": np.int64,
30
+ "tensor(uint64)": np.uint64,
31
+ "tensor(float16)": np.float16,
32
+ "tensor(float)": np.float32,
33
+ "tensor(double)": np.float64,
34
+ }
35
+
36
+ try:
37
+ from diffusers.utils import PIL_INTERPOLATION
38
+ except ImportError:
39
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
40
+ PIL_INTERPOLATION = {
41
+ "linear": PIL.Image.Resampling.BILINEAR,
42
+ "bilinear": PIL.Image.Resampling.BILINEAR,
43
+ "bicubic": PIL.Image.Resampling.BICUBIC,
44
+ "lanczos": PIL.Image.Resampling.LANCZOS,
45
+ "nearest": PIL.Image.Resampling.NEAREST,
46
+ }
47
+ else:
48
+ PIL_INTERPOLATION = {
49
+ "linear": PIL.Image.LINEAR,
50
+ "bilinear": PIL.Image.BILINEAR,
51
+ "bicubic": PIL.Image.BICUBIC,
52
+ "lanczos": PIL.Image.LANCZOS,
53
+ "nearest": PIL.Image.NEAREST,
54
+ }
55
  # ------------------------------------------------------------------------------
56
 
57
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
162
  return res
163
 
164
 
165
+ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
166
  r"""
167
  Tokenize a list of prompts and return its tokens with weights of each token.
168
 
 
177
  text_weight = []
178
  for word, weight in texts_and_weights:
179
  # tokenize and discard the starting and the ending token
180
+ token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
181
+ text_token += list(token)
182
  # copy the weight by length of token
183
  text_weight += [weight] * len(token)
184
  # stop if the text is too long (longer than truncation limit)
 
197
  return tokens, weights
198
 
199
 
200
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
201
  r"""
202
  Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
203
  """
204
  max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
205
  weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
206
  for i in range(len(tokens)):
207
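+ # pad to max_length with bos/eos; eos is used as the padding token in this ONNX variant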
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
208
  if no_boseos_middle:
209
  weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
210
  else:
 
223
 
224
 
225
  def get_unweighted_text_embeddings(
226
+ pipe,
227
+ text_input: np.ndarray,
228
  chunk_length: int,
229
  no_boseos_middle: Optional[bool] = True,
230
  ):
 
237
  text_embeddings = []
238
  for i in range(max_embeddings_multiples):
239
  # extract the i-th chunk
240
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
241
 
242
  # cover the head and the tail by the starting and the ending tokens
243
  text_input_chunk[:, 0] = text_input[0, 0]
244
  text_input_chunk[:, -1] = text_input[0, -1]
245
+
246
+ text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]
247
 
248
  if no_boseos_middle:
249
  if i == 0:
 
257
  text_embedding = text_embedding[:, 1:-1]
258
 
259
  text_embeddings.append(text_embedding)
260
+ text_embeddings = np.concatenate(text_embeddings, axis=1)
261
  else:
262
+ text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
263
  return text_embeddings
264
 
265
 
266
  def get_weighted_text_embeddings(
267
+ pipe,
268
  prompt: Union[str, List[str]],
269
  uncond_prompt: Optional[Union[str, List[str]]] = None,
270
+ max_embeddings_multiples: Optional[int] = 4,
271
  no_boseos_middle: Optional[bool] = False,
272
  skip_parsing: Optional[bool] = False,
273
  skip_weighting: Optional[bool] = False,
274
+ **kwargs,
275
  ):
276
  r"""
277
  Prompts can be assigned with local weights using brackets. For example,
 
281
  Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
282
 
283
  Args:
284
+ pipe (`OnnxStableDiffusionPipeline`):
285
  Pipe to provide access to the tokenizer and the text encoder.
286
  prompt (`str` or `List[str]`):
287
  The prompt or prompts to guide the image generation.
288
  uncond_prompt (`str` or `List[str]`):
289
  The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
290
  is provided, the embeddings of prompt and uncond_prompt are concatenated.
291
+ max_embeddings_multiples (`int`, *optional*, defaults to `4`):
292
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
293
  no_boseos_middle (`bool`, *optional*, defaults to `False`):
294
  If the length of the text tokens is a multiple of the capacity of the text encoder, whether to reserve the starting and
 
310
  uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
311
  else:
312
  prompt_tokens = [
313
+ token[1:-1]
314
+ for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
315
  ]
316
  prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
317
  if uncond_prompt is not None:
 
319
  uncond_prompt = [uncond_prompt]
320
  uncond_tokens = [
321
  token[1:-1]
322
+ for token in pipe.tokenizer(
323
+ uncond_prompt,
324
+ max_length=max_length,
325
+ truncation=True,
326
+ return_tensors="np",
327
+ ).input_ids
328
  ]
329
  uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
330
 
 
343
  # pad the length of tokens and weights
344
  bos = pipe.tokenizer.bos_token_id
345
  eos = pipe.tokenizer.eos_token_id
 
346
  prompt_tokens, prompt_weights = pad_tokens_and_weights(
347
  prompt_tokens,
348
  prompt_weights,
349
  max_length,
350
  bos,
351
  eos,
 
352
  no_boseos_middle=no_boseos_middle,
353
  chunk_length=pipe.tokenizer.model_max_length,
354
  )
355
+ prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
356
  if uncond_prompt is not None:
357
  uncond_tokens, uncond_weights = pad_tokens_and_weights(
358
  uncond_tokens,
 
360
  max_length,
361
  bos,
362
  eos,
 
363
  no_boseos_middle=no_boseos_middle,
364
  chunk_length=pipe.tokenizer.model_max_length,
365
  )
366
+ uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
367
 
368
  # get the embeddings
369
  text_embeddings = get_unweighted_text_embeddings(
 
372
  pipe.tokenizer.model_max_length,
373
  no_boseos_middle=no_boseos_middle,
374
  )
375
+ prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
376
  if uncond_prompt is not None:
377
  uncond_embeddings = get_unweighted_text_embeddings(
378
  pipe,
 
380
  pipe.tokenizer.model_max_length,
381
  no_boseos_middle=no_boseos_middle,
382
  )
383
+ uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
384
 
385
  # assign weights to the prompts and normalize in the sense of mean
386
  # TODO: should we normalize by chunk or in a whole (current implementation)?
387
  if (not skip_parsing) and (not skip_weighting):
388
+ previous_mean = text_embeddings.mean(axis=(-2, -1))
389
+ text_embeddings *= prompt_weights[:, :, None]
390
+ text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
 
391
  if uncond_prompt is not None:
392
+ previous_mean = uncond_embeddings.mean(axis=(-2, -1))
393
+ uncond_embeddings *= uncond_weights[:, :, None]
394
+ uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]
 
395
 
396
+ # For classifier free guidance, the caller concatenates the unconditional and
397
+ # text embeddings into a single batch so the UNet runs only one forward pass
398
+ # per denoising step; here we simply return both sets of embeddings.
399
  if uncond_prompt is not None:
400
  return text_embeddings, uncond_embeddings
401
+
402
+ return text_embeddings
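+
+ # Example (illustrative only, not part of this commit; `pipe` is any loaded ONNX Stable Diffusion pipeline):
+ #     cond, uncond = get_weighted_text_embeddings(pipe, "a (red:1.3) rose", uncond_prompt="low quality")
+ #     embeddings = np.concatenate([uncond, cond])  # batch layout used for classifier-free guidance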
403
 
404
 
405
+ def preprocess_image(image):
406
  w, h = image.size
407
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
408
  image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
409
  image = np.array(image).astype(np.float32) / 255.0
410
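+ # add a batch dimension and convert HWC -> CHW for the VAE encoder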
+ image = image[None].transpose(0, 3, 1, 2)
 
411
  return 2.0 * image - 1.0
412
 
413
 
414
+ def preprocess_mask(mask, scale_factor=8):
415
+ mask = mask.convert("L")
416
+ w, h = mask.size
417
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
418
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
419
+ mask = np.array(mask).astype(np.float32) / 255.0
420
+ mask = np.tile(mask, (4, 1, 1))
421
+ mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose is an identity permutation)
422
+ mask = 1 - mask # repaint white, keep black
423
+ return mask
 
424
 
425
 
426
+ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
 
 
427
  r"""
428
  Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
429
  weighting in the prompt.
430
 
431
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
432
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
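 
  Example (a rough sketch; the model id, ONNX revision, execution provider, and community-pipeline
  name below are illustrative assumptions rather than something this file defines):
 
      >>> from diffusers import DiffusionPipeline
      >>> pipe = DiffusionPipeline.from_pretrained(
      ...     "runwayml/stable-diffusion-v1-5",
      ...     revision="onnx",
      ...     provider="CPUExecutionProvider",
      ...     custom_pipeline="lpw_stable_diffusion_onnx",
      ... )
      >>> image = pipe.text2img("a (beautiful:1.2) photo of a cat", max_embeddings_multiples=3).images[0]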
433
  """
434
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
435
+
436
+ def __init__(
437
+ self,
438
+ vae_encoder: OnnxRuntimeModel,
439
+ vae_decoder: OnnxRuntimeModel,
440
+ text_encoder: OnnxRuntimeModel,
441
+ tokenizer: CLIPTokenizer,
442
+ unet: OnnxRuntimeModel,
443
+ scheduler: SchedulerMixin,
444
+ safety_checker: OnnxRuntimeModel,
445
+ feature_extractor: CLIPFeatureExtractor,
446
+ requires_safety_checker: bool = True,
447
+ ):
448
+ super().__init__(
449
+ vae_encoder=vae_encoder,
450
+ vae_decoder=vae_decoder,
451
+ text_encoder=text_encoder,
452
+ tokenizer=tokenizer,
453
+ unet=unet,
454
+ scheduler=scheduler,
455
+ safety_checker=safety_checker,
456
+ feature_extractor=feature_extractor,
457
+ requires_safety_checker=requires_safety_checker,
458
  )
459
+ self.__init__additional__()
460
 
461
+ else:
 
 
 
 
462
 
463
+ def __init__(
464
+ self,
465
+ vae_encoder: OnnxRuntimeModel,
466
+ vae_decoder: OnnxRuntimeModel,
467
+ text_encoder: OnnxRuntimeModel,
468
+ tokenizer: CLIPTokenizer,
469
+ unet: OnnxRuntimeModel,
470
+ scheduler: SchedulerMixin,
471
+ safety_checker: OnnxRuntimeModel,
472
+ feature_extractor: CLIPFeatureExtractor,
473
+ ):
474
+ super().__init__(
475
+ vae_encoder=vae_encoder,
476
+ vae_decoder=vae_decoder,
477
+ text_encoder=text_encoder,
478
+ tokenizer=tokenizer,
479
+ unet=unet,
480
+ scheduler=scheduler,
481
+ safety_checker=safety_checker,
482
+ feature_extractor=feature_extractor,
483
  )
484
+ self.__init__additional__()
485
 
486
+ def __init__additional__(self):
487
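+ # hard-coded Stable Diffusion v1 defaults: 4 latent channels and an 8x VAE downscale factor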
+ self.unet_in_channels = 4
488
+ self.vae_scale_factor = 8
489
 
490
  def _encode_prompt(
491
  self,
492
  prompt,
 
493
  num_images_per_prompt,
494
  do_classifier_free_guidance,
495
+ negative_prompt,
496
+ max_embeddings_multiples,
 
 
497
  ):
498
  r"""
499
  Encodes the prompt into text encoder hidden states.
 
501
  Args:
502
  prompt (`str` or `list(int)`):
503
  prompt to be encoded
 
 
504
  num_images_per_prompt (`int`):
505
  number of images that should be generated per prompt
506
  do_classifier_free_guidance (`bool`):
 
511
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
512
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
513
  """
514
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
515
+
516
+ if negative_prompt is None:
517
+ negative_prompt = [""] * batch_size
518
+ elif isinstance(negative_prompt, str):
519
+ negative_prompt = [negative_prompt] * batch_size
520
+ if batch_size != len(negative_prompt):
521
+ raise ValueError(
522
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
523
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
524
+ " the batch size of `prompt`."
525
  )
 
 
 
 
526
 
527
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
528
+ pipe=self,
529
+ prompt=prompt,
530
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
531
+ max_embeddings_multiples=max_embeddings_multiples,
532
+ )
533
 
534
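+ # np.repeat along axis 0: one copy of the prompt embeddings per requested image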
+ text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
535
  if do_classifier_free_guidance:
536
+ uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
537
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
 
538
 
539
+ return text_embeddings
540
 
541
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
542
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
543
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
545
  if strength < 0 or strength > 1:
546
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
547
 
548
+ if height % 8 != 0 or width % 8 != 0:
549
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
550
+
551
  if (callback_steps is None) or (
552
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
553
  ):
 
556
  f" {type(callback_steps)}."
557
  )
558
 
559
+ def get_timesteps(self, num_inference_steps, strength, is_text2img):
560
  if is_text2img:
561
+ return self.scheduler.timesteps, num_inference_steps
562
  else:
563
  # get the original timestep using init_timestep
564
+ offset = self.scheduler.config.get("steps_offset", 0)
565
+ init_timestep = int(num_inference_steps * strength) + offset
566
+ init_timestep = min(init_timestep, num_inference_steps)
 
567
 
568
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
569
+ timesteps = self.scheduler.timesteps[t_start:]
570
  return timesteps, num_inference_steps - t_start
571
 
572
+ def run_safety_checker(self, image):
573
  if self.safety_checker is not None:
574
+ safety_checker_input = self.feature_extractor(
575
+ self.numpy_to_pil(image), return_tensors="np"
576
+ ).pixel_values.astype(image.dtype)
577
+ # running the safety_checker on the whole batch raises an error when batch size > 1, so check images one at a time
578
+ images, has_nsfw_concept = [], []
579
+ for i in range(image.shape[0]):
580
+ image_i, has_nsfw_concept_i = self.safety_checker(
581
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
582
+ )
583
+ images.append(image_i)
584
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
585
+ image = np.concatenate(images)
586
  else:
587
  has_nsfw_concept = None
588
  return image, has_nsfw_concept
589
 
590
  def decode_latents(self, latents):
591
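+ # rescale latents by 1 / 0.18215 (the SD VAE scaling factor) before decoding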
+ latents = 1 / 0.18215 * latents
592
+ # image = self.vae_decoder(latent_sample=latents)[0]
593
+ # the half-precision VAE decoder can produce incorrect results when batch size > 1, so decode one latent at a time
594
+ image = np.concatenate(
595
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
596
+ )
597
+ image = np.clip(image / 2 + 0.5, 0, 1)
598
+ image = image.transpose((0, 2, 3, 1))
599
  return image
600
 
601
  def prepare_extra_step_kwargs(self, generator, eta):
 
615
  extra_step_kwargs["generator"] = generator
616
  return extra_step_kwargs
617
 
618
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
619
  if image is None:
620
+ shape = (
621
+ batch_size,
622
+ self.unet_in_channels,
623
+ height // self.vae_scale_factor,
624
+ width // self.vae_scale_factor,
625
+ )
 
626
 
627
  if latents is None:
628
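+ # draw the initial noise with torch so a torch.Generator keeps it deterministic, then convert to numpy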
+ latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
629
  else:
630
+ if latents.shape != shape:
631
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
632
 
633
  # scale the initial noise by the standard deviation required by the scheduler
634
+ latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
635
  return latents, None, None
636
  else:
637
+ init_latents = self.vae_encoder(sample=image)[0]
638
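+ # scale the encoded latents by the SD VAE scaling factor (0.18215)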
+ init_latents = 0.18215 * init_latents
639
+ init_latents = np.concatenate([init_latents] * batch_size, axis=0)
 
 
 
 
640
  init_latents_orig = init_latents
641
+ shape = init_latents.shape
642
 
643
  # add noise to latents using the timesteps
644
+ noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
645
+ latents = self.scheduler.add_noise(
646
+ torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
647
+ ).numpy()
648
  return latents, init_latents_orig, noise
649
 
650
  @torch.no_grad()
 
652
  self,
653
  prompt: Union[str, List[str]],
654
  negative_prompt: Optional[Union[str, List[str]]] = None,
655
+ image: Union[np.ndarray, PIL.Image.Image] = None,
656
+ mask_image: Union[np.ndarray, PIL.Image.Image] = None,
657
  height: int = 512,
658
  width: int = 512,
659
  num_inference_steps: int = 50,
660
  guidance_scale: float = 7.5,
661
  strength: float = 0.8,
662
  num_images_per_prompt: Optional[int] = 1,
 
663
  eta: float = 0.0,
664
+ generator: Optional[torch.Generator] = None,
665
+ latents: Optional[np.ndarray] = None,
 
 
666
  max_embeddings_multiples: Optional[int] = 3,
667
  output_type: Optional[str] = "pil",
668
  return_dict: bool = True,
669
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
670
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
671
+ callback_steps: Optional[int] = 1,
672
+ **kwargs,
673
  ):
674
  r"""
675
  Function invoked when calling the pipeline for generation.
 
680
  negative_prompt (`str` or `List[str]`, *optional*):
681
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
682
  if `guidance_scale` is less than `1`).
683
+ image (`np.ndarray` or `PIL.Image.Image`):
684
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
685
  process.
686
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
687
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
688
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
689
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 
709
  `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
710
  num_images_per_prompt (`int`, *optional*, defaults to 1):
711
  The number of images to generate per prompt.
 
 
 
712
  eta (`float`, *optional*, defaults to 0.0):
713
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
714
  [`schedulers.DDIMScheduler`], will be ignored for others.
715
+ generator (`torch.Generator`, *optional*):
716
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
717
+ deterministic.
718
+ latents (`np.ndarray`, *optional*):
719
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
720
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
721
  tensor will be generated by sampling using the supplied random `generator`.
 
 
 
 
 
 
 
722
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
723
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
724
  output_type (`str`, *optional*, defaults to `"pil"`):
 
729
  plain tuple.
730
  callback (`Callable`, *optional*):
731
  A function that will be called every `callback_steps` steps during inference. The function will be
732
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
733
  is_cancelled_callback (`Callable`, *optional*):
734
  A function that will be called every `callback_steps` steps during inference. If the function returns
735
  `True`, the inference will be cancelled.
736
  callback_steps (`int`, *optional*, defaults to 1):
737
  The frequency at which the `callback` function will be called. If not specified, the callback will be
738
  called at every step.
 
 
 
 
739
 
740
  Returns:
741
  `None` if cancelled by `is_cancelled_callback`,
 
745
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
746
  (nsfw) content, according to the `safety_checker`.
747
  """
748
+ message = "Please use `image` instead of `init_image`."
749
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
750
+ image = init_image or image
751
+
752
  # 0. Default height and width to unet
753
  height = height or self.unet.config.sample_size * self.vae_scale_factor
754
  width = width or self.unet.config.sample_size * self.vae_scale_factor
755
 
756
  # 1. Check inputs. Raise error if not correct
757
+ self.check_inputs(prompt, height, width, strength, callback_steps)
 
 
758
 
759
  # 2. Define call parameters
760
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
 
 
 
 
 
 
 
761
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
762
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
763
  # corresponds to doing no classifier free guidance.
764
  do_classifier_free_guidance = guidance_scale > 1.0
765
 
766
  # 3. Encode input prompt
767
+ text_embeddings = self._encode_prompt(
768
  prompt,
 
769
  num_images_per_prompt,
770
  do_classifier_free_guidance,
771
  negative_prompt,
772
  max_embeddings_multiples,
 
 
773
  )
774
+ dtype = text_embeddings.dtype
775
 
776
  # 4. Preprocess image and mask
777
  if isinstance(image, PIL.Image.Image):
778
+ image = preprocess_image(image)
779
  if image is not None:
780
+ image = image.astype(dtype)
781
  if isinstance(mask_image, PIL.Image.Image):
782
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
783
  if mask_image is not None:
784
+ mask = mask_image.astype(dtype)
785
+ mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
786
  else:
787
  mask = None
788
 
789
  # 5. set timesteps
790
+ self.scheduler.set_timesteps(num_inference_steps)
791
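+ # look up the ONNX dtype the UNet expects for its "timestep" input, defaulting to float32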
+ timestep_dtype = next(
792
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
793
+ )
794
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
795
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
796
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
797
 
798
  # 6. Prepare latent variables
799
  latents, init_latents_orig, noise = self.prepare_latents(
800
  image,
801
  latent_timestep,
802
+ batch_size * num_images_per_prompt,
 
 
803
  height,
804
  width,
805
  dtype,
 
806
  generator,
807
  latents,
808
  )
 
811
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
812
 
813
  # 8. Denoising loop
814
+ for i, t in enumerate(self.progress_bar(timesteps)):
815
+ # expand the latents if we are doing classifier free guidance
816
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
817
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
818
+ latent_model_input = latent_model_input.numpy()
819
+
820
+ # predict the noise residual
821
+ noise_pred = self.unet(
822
+ sample=latent_model_input,
823
+ timestep=np.array([t], dtype=timestep_dtype),
824
+ encoder_hidden_states=text_embeddings,
825
+ )
826
+ noise_pred = noise_pred[0]
827
+
828
+ # perform guidance
829
+ if do_classifier_free_guidance:
830
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
831
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
832
+
833
+ # compute the previous noisy sample x_t -> x_t-1
834
+ scheduler_output = self.scheduler.step(
835
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
836
+ )
837
+ latents = scheduler_output.prev_sample.numpy()
838
+
839
+ if mask is not None:
840
+ # inpainting: re-noise the original latents to the current timestep and keep them in the preserved region
841
+ init_latents_proper = self.scheduler.add_noise(
842
+ torch.from_numpy(init_latents_orig),
843
+ torch.from_numpy(noise),
844
  t,
845
+ ).numpy()
846
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
847
 
848
+ # call the callback, if provided
849
+ if i % callback_steps == 0:
850
+ if callback is not None:
851
+ callback(i, t, latents)
852
+ if is_cancelled_callback is not None and is_cancelled_callback():
853
+ return None
854
 
855
+ # 9. Post-processing
856
+ image = self.decode_latents(latents)
 
 
 
857
 
858
+ # 10. Run safety checker
859
+ image, has_nsfw_concept = self.run_safety_checker(image)
860
 
861
+ # 11. Convert to PIL
862
+ if output_type == "pil":
863
+ image = self.numpy_to_pil(image)
864
 
865
  if not return_dict:
866
  return image, has_nsfw_concept
 
877
  guidance_scale: float = 7.5,
878
  num_images_per_prompt: Optional[int] = 1,
879
  eta: float = 0.0,
880
+ generator: Optional[torch.Generator] = None,
881
+ latents: Optional[np.ndarray] = None,
 
 
882
  max_embeddings_multiples: Optional[int] = 3,
883
  output_type: Optional[str] = "pil",
884
  return_dict: bool = True,
885
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
886
+ callback_steps: Optional[int] = 1,
887
+ **kwargs,
 
888
  ):
889
  r"""
890
  Function for text-to-image generation.
 
912
  eta (`float`, *optional*, defaults to 0.0):
913
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
914
  [`schedulers.DDIMScheduler`], will be ignored for others.
915
+ generator (`torch.Generator`, *optional*):
916
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
917
+ deterministic.
918
+ latents (`np.ndarray`, *optional*):
919
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
920
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
921
  tensor will be generated by sampling using the supplied random `generator`.
 
 
 
 
 
 
 
922
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
923
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
924
  output_type (`str`, *optional*, defaults to `"pil"`):
 
929
  plain tuple.
930
  callback (`Callable`, *optional*):
931
  A function that will be called every `callback_steps` steps during inference. The function will be
932
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
 
 
 
933
  callback_steps (`int`, *optional*, defaults to 1):
934
  The frequency at which the `callback` function will be called. If not specified, the callback will be
935
  called at every step.
 
 
 
 
 
936
  Returns:
 
937
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
938
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
939
  When returning a tuple, the first element is a list with the generated images, and the second element is a
 
951
  eta=eta,
952
  generator=generator,
953
  latents=latents,
 
 
954
  max_embeddings_multiples=max_embeddings_multiples,
955
  output_type=output_type,
956
  return_dict=return_dict,
957
  callback=callback,
 
958
  callback_steps=callback_steps,
959
+ **kwargs,
960
  )
961
 
962
  def img2img(
963
  self,
964
+ image: Union[np.ndarray, PIL.Image.Image],
965
  prompt: Union[str, List[str]],
966
  negative_prompt: Optional[Union[str, List[str]]] = None,
967
  strength: float = 0.8,
 
969
  guidance_scale: Optional[float] = 7.5,
970
  num_images_per_prompt: Optional[int] = 1,
971
  eta: Optional[float] = 0.0,
972
+ generator: Optional[torch.Generator] = None,
 
 
973
  max_embeddings_multiples: Optional[int] = 3,
974
  output_type: Optional[str] = "pil",
975
  return_dict: bool = True,
976
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
977
+ callback_steps: Optional[int] = 1,
978
+ **kwargs,
 
979
  ):
980
  r"""
981
  Function for image-to-image generation.
982
  Args:
983
+ image (`np.ndarray` or `PIL.Image.Image`):
984
+ `Image`, or ndarray representing an image batch, that will be used as the starting point for the
985
  process.
986
  prompt (`str` or `List[str]`):
987
  The prompt or prompts to guide the image generation.
 
1008
  eta (`float`, *optional*, defaults to 0.0):
1009
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1010
  [`schedulers.DDIMScheduler`], will be ignored for others.
1011
+ generator (`torch.Generator`, *optional*):
1012
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1013
+ deterministic.
 
 
 
 
 
 
 
1014
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1015
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1016
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1021
  plain tuple.
1022
  callback (`Callable`, *optional*):
1023
  A function that will be called every `callback_steps` steps during inference. The function will be
1024
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
 
 
 
1025
  callback_steps (`int`, *optional*, defaults to 1):
1026
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1027
  called at every step.
 
 
 
 
 
1028
  Returns:
1029
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1030
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1031
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1032
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 
1042
  num_images_per_prompt=num_images_per_prompt,
1043
  eta=eta,
1044
  generator=generator,
 
 
1045
  max_embeddings_multiples=max_embeddings_multiples,
1046
  output_type=output_type,
1047
  return_dict=return_dict,
1048
  callback=callback,
 
1049
  callback_steps=callback_steps,
1050
+ **kwargs,
1051
  )
1052
 
1053
  def inpaint(
1054
  self,
1055
+ image: Union[np.ndarray, PIL.Image.Image],
1056
+ mask_image: Union[np.ndarray, PIL.Image.Image],
1057
  prompt: Union[str, List[str]],
1058
  negative_prompt: Optional[Union[str, List[str]]] = None,
1059
  strength: float = 0.8,
1060
  num_inference_steps: Optional[int] = 50,
1061
  guidance_scale: Optional[float] = 7.5,
1062
  num_images_per_prompt: Optional[int] = 1,
 
1063
  eta: Optional[float] = 0.0,
1064
+ generator: Optional[torch.Generator] = None,
 
 
1065
  max_embeddings_multiples: Optional[int] = 3,
1066
  output_type: Optional[str] = "pil",
1067
  return_dict: bool = True,
1068
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
1069
+ callback_steps: Optional[int] = 1,
1070
+ **kwargs,
 
1071
  ):
1072
  r"""
1073
  Function for inpainting.
1074
  Args:
1075
+ image (`np.ndarray` or `PIL.Image.Image`):
1076
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
1077
  process. This is the image whose masked region will be inpainted.
1078
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
1079
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1080
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1081
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 
1101
  usually at the expense of lower image quality.
1102
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1103
  The number of images to generate per prompt.
 
 
 
1104
  eta (`float`, *optional*, defaults to 0.0):
1105
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1106
  [`schedulers.DDIMScheduler`], will be ignored for others.
1107
+ generator (`torch.Generator`, *optional*):
1108
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1109
+ deterministic.
 
 
 
 
 
 
 
1110
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1111
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1112
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1117
  plain tuple.
1118
  callback (`Callable`, *optional*):
1119
  A function that will be called every `callback_steps` steps during inference. The function will be
1120
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
 
 
 
1121
  callback_steps (`int`, *optional*, defaults to 1):
1122
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1123
  called at every step.
 
 
 
 
 
1124
  Returns:
1125
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1126
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1127
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1128
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 
1137
  guidance_scale=guidance_scale,
1138
  strength=strength,
1139
  num_images_per_prompt=num_images_per_prompt,
 
1140
  eta=eta,
1141
  generator=generator,
 
 
1142
  max_embeddings_multiples=max_embeddings_multiples,
1143
  output_type=output_type,
1144
  return_dict=return_dict,
1145
  callback=callback,
 
1146
  callback_steps=callback_steps,
1147
+ **kwargs,
1148
+ )