zhaozhilin committed on
Commit ecdc6c7 · 1 Parent(s): 51fca57
Files changed (2):
  1. pipeline.py +336 -659
  2. pipeline_lpw_stable_diffusion.py +1472 -0
pipeline.py CHANGED
@@ -1,31 +1,57 @@
1
- # source https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
  import inspect
3
  import re
4
- from typing import Any, Callable, Dict, List, Optional, Union
5
 
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
  from packaging import version
10
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
11
-
12
- from diffusers import DiffusionPipeline
13
- from diffusers.configuration_utils import FrozenDict
14
- from diffusers.image_processor import VaeImageProcessor
15
- from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
16
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
- from diffusers.schedulers import KarrasDiffusionSchedulers
19
- from diffusers.utils import (
20
- PIL_INTERPOLATION,
21
- deprecate,
22
- is_accelerate_available,
23
- is_accelerate_version,
24
- logging,
25
- )
26
- from diffusers.utils.torch_utils import randn_tensor
27
-
28
-
29
  # ------------------------------------------------------------------------------
30
 
31
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -136,7 +162,7 @@ def parse_prompt_attention(text):
136
  return res
137
 
138
 
139
- def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
140
  r"""
141
  Tokenize a list of prompts and return their tokens together with the weight of each token.
142
 
@@ -151,8 +177,8 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
151
  text_weight = []
152
  for word, weight in texts_and_weights:
153
  # tokenize and discard the starting and the ending token
154
- token = pipe.tokenizer(word).input_ids[1:-1]
155
- text_token += token
156
  # copy the weight by length of token
157
  text_weight += [weight] * len(token)
158
  # stop if the text is too long (longer than truncation limit)
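For reference, a minimal, self-contained sketch of the weight-copy step shown above. The stub tokenizer is hypothetical (it stands in for `pipe.tokenizer`); the point is only that each fragment's weight is repeated once per sub-word token.

def stub_tokenize(word):
    # hypothetical stand-in for pipe.tokenizer(word).input_ids[1:-1]
    return list(range(max(1, len(word) // 4)))

texts_and_weights = [("a photo of", 1.0), ("red car", 1.2)]  # e.g. output of parse_prompt_attention("a photo of (red car:1.2)")
text_token, text_weight = [], []
for word, weight in texts_and_weights:
    token = stub_tokenize(word)
    text_token += token
    text_weight += [weight] * len(token)  # one weight entry per token of this fragment
assert len(text_token) == len(text_weight)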
@@ -197,8 +223,8 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
197
 
198
 
199
  def get_unweighted_text_embeddings(
200
- pipe: DiffusionPipeline,
201
- text_input: torch.Tensor,
202
  chunk_length: int,
203
  no_boseos_middle: Optional[bool] = True,
204
  ):
@@ -211,12 +237,13 @@ def get_unweighted_text_embeddings(
211
  text_embeddings = []
212
  for i in range(max_embeddings_multiples):
213
  # extract the i-th chunk
214
- text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
215
 
216
  # cover the head and the tail by the starting and the ending tokens
217
  text_input_chunk[:, 0] = text_input[0, 0]
218
  text_input_chunk[:, -1] = text_input[0, -1]
219
- text_embedding = pipe.text_encoder(text_input_chunk)[0]
 
220
 
221
  if no_boseos_middle:
222
  if i == 0:
@@ -230,20 +257,21 @@ def get_unweighted_text_embeddings(
230
  text_embedding = text_embedding[:, 1:-1]
231
 
232
  text_embeddings.append(text_embedding)
233
- text_embeddings = torch.concat(text_embeddings, axis=1)
234
  else:
235
- text_embeddings = pipe.text_encoder(text_input)[0]
236
  return text_embeddings
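The chunking above handles prompts longer than CLIP's 77-token window by encoding slices of `chunk_length - 2` tokens and re-attaching BOS/EOS to each slice. A quick sketch of the slice arithmetic (numbers are illustrative):

chunk_length = 77                               # pipe.tokenizer.model_max_length for CLIP
usable = chunk_length - 2                       # content tokens per chunk, BOS/EOS excluded
padded_len = 2 + 2 * usable                     # a prompt padded to two chunks (152 ids)
max_embeddings_multiples = (padded_len - 2) // usable
for i in range(max_embeddings_multiples):
    start = i * usable
    stop = (i + 1) * usable + 2                 # +2 leaves room to overwrite BOS/EOS in the slice
    print(f"chunk {i}: token ids [{start}:{stop}]")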
237
 
238
 
239
  def get_weighted_text_embeddings(
240
- pipe: DiffusionPipeline,
241
  prompt: Union[str, List[str]],
242
  uncond_prompt: Optional[Union[str, List[str]]] = None,
243
- max_embeddings_multiples: Optional[int] = 3,
244
  no_boseos_middle: Optional[bool] = False,
245
  skip_parsing: Optional[bool] = False,
246
  skip_weighting: Optional[bool] = False,
 
247
  ):
248
  r"""
249
  Prompts can be assigned with local weights using brackets. For example,
@@ -253,14 +281,14 @@ def get_weighted_text_embeddings(
253
  Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
254
 
255
  Args:
256
- pipe (`DiffusionPipeline`):
257
  Pipe to provide access to the tokenizer and the text encoder.
258
  prompt (`str` or `List[str]`):
259
  The prompt or prompts to guide the image generation.
260
  uncond_prompt (`str` or `List[str]`):
261
  The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
262
  is provided, the embeddings of prompt and uncond_prompt are concatenated.
263
- max_embeddings_multiples (`int`, *optional*, defaults to `3`):
264
  The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
265
  no_boseos_middle (`bool`, *optional*, defaults to `False`):
266
  If the length of the text tokens is a multiple of the text encoder capacity, whether to reserve the starting and
@@ -282,7 +310,8 @@ def get_weighted_text_embeddings(
282
  uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
283
  else:
284
  prompt_tokens = [
285
- token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
 
286
  ]
287
  prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
288
  if uncond_prompt is not None:
@@ -290,7 +319,12 @@ def get_weighted_text_embeddings(
290
  uncond_prompt = [uncond_prompt]
291
  uncond_tokens = [
292
  token[1:-1]
293
- for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
294
  ]
295
  uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
296
 
@@ -320,7 +354,7 @@ def get_weighted_text_embeddings(
320
  no_boseos_middle=no_boseos_middle,
321
  chunk_length=pipe.tokenizer.model_max_length,
322
  )
323
- prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
324
  if uncond_prompt is not None:
325
  uncond_tokens, uncond_weights = pad_tokens_and_weights(
326
  uncond_tokens,
@@ -332,7 +366,7 @@ def get_weighted_text_embeddings(
332
  no_boseos_middle=no_boseos_middle,
333
  chunk_length=pipe.tokenizer.model_max_length,
334
  )
335
- uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
336
 
337
  # get the embeddings
338
  text_embeddings = get_unweighted_text_embeddings(
@@ -341,7 +375,7 @@ def get_weighted_text_embeddings(
341
  pipe.tokenizer.model_max_length,
342
  no_boseos_middle=no_boseos_middle,
343
  )
344
- prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
345
  if uncond_prompt is not None:
346
  uncond_embeddings = get_unweighted_text_embeddings(
347
  pipe,
@@ -349,308 +383,121 @@ def get_weighted_text_embeddings(
349
  pipe.tokenizer.model_max_length,
350
  no_boseos_middle=no_boseos_middle,
351
  )
352
- uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
353
 
354
  # assign weights to the prompts and normalize in the sense of mean
355
  # TODO: should we normalize by chunk or in a whole (current implementation)?
356
  if (not skip_parsing) and (not skip_weighting):
357
- previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
358
- text_embeddings *= prompt_weights.unsqueeze(-1)
359
- current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
360
- text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
361
  if uncond_prompt is not None:
362
- previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
363
- uncond_embeddings *= uncond_weights.unsqueeze(-1)
364
- current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
365
- uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
366
367
  if uncond_prompt is not None:
368
  return text_embeddings, uncond_embeddings
369
- return text_embeddings, None
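The weighting block above multiplies every token embedding by its parsed weight and then rescales the result so the overall mean of the embedding is unchanged; a numpy sketch of that normalization (shapes are illustrative):

import numpy as np

embeddings = np.random.randn(1, 77, 768).astype(np.float32)  # (batch, tokens, dim)
weights = np.ones((1, 77), dtype=np.float32)
weights[0, 5:10] = 1.2                                        # an emphasized span, e.g. "(red car:1.2)"

previous_mean = embeddings.mean(axis=(-2, -1))                # per-sample mean before weighting
embeddings = embeddings * weights[..., None]                  # apply per-token weights
current_mean = embeddings.mean(axis=(-2, -1))
embeddings = embeddings * (previous_mean / current_mean)[:, None, None]  # restore the original mean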
 
370
 
371
 
372
- def preprocess_image(image, batch_size):
373
  w, h = image.size
374
- w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
375
  image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
376
  image = np.array(image).astype(np.float32) / 255.0
377
- image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
378
- image = torch.from_numpy(image)
379
  return 2.0 * image - 1.0
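A quick illustration of what `preprocess_image` does to sizes and value ranges (the input image here is a placeholder): dimensions are floored to a multiple of 8 and pixels are mapped from [0, 255] to [-1, 1].

import numpy as np
import PIL.Image

img = PIL.Image.new("RGB", (513, 769))          # placeholder image with awkward dimensions
w, h = (x - x % 8 for x in img.size)            # floor to a multiple of 8 -> (512, 768)
arr = np.array(img.resize((w, h))).astype(np.float32) / 255.0
arr = 2.0 * arr - 1.0                           # values now in [-1, 1], shape (768, 512, 3)
print(arr.shape, arr.min(), arr.max())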
380
 
381
 
382
- def preprocess_mask(mask, batch_size, scale_factor=8):
383
- if not isinstance(mask, torch.FloatTensor):
384
- mask = mask.convert("L")
385
- w, h = mask.size
386
- w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
387
- mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
388
- mask = np.array(mask).astype(np.float32) / 255.0
389
- mask = np.tile(mask, (4, 1, 1))
390
- mask = np.vstack([mask[None]] * batch_size)
391
- mask = 1 - mask # repaint white, keep black
392
- mask = torch.from_numpy(mask)
393
- return mask
394
-
395
- else:
396
- valid_mask_channel_sizes = [1, 3]
397
- # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
398
- if mask.shape[3] in valid_mask_channel_sizes:
399
- mask = mask.permute(0, 3, 1, 2)
400
- elif mask.shape[1] not in valid_mask_channel_sizes:
401
- raise ValueError(
402
- f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
403
- f" but received mask of shape {tuple(mask.shape)}"
404
- )
405
- # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
406
- mask = mask.mean(dim=1, keepdim=True)
407
- h, w = mask.shape[-2:]
408
- h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
409
- mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
410
- return mask
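For the PIL branch above, the mask ends up at latent resolution (1/8 of the pixel size), tiled across the 4 latent channels, and inverted so that white regions are the ones that get repainted. A small numpy sketch of the resulting shape (sizes are illustrative):

import numpy as np

h, w, scale_factor, batch_size = 512, 512, 8, 1
mask = np.ones((h // scale_factor, w // scale_factor), dtype=np.float32)  # all-white mask after resize/normalize
mask = np.tile(mask, (4, 1, 1))                 # one copy per latent channel
mask = np.vstack([mask[None]] * batch_size)     # -> (batch, 4, 64, 64)
mask = 1 - mask                                 # white (1.0) -> 0.0: those latents are repainted
print(mask.shape)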
411
 
412
 
413
- class StableDiffusionLongPromptWeightingPipeline(
414
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
415
- ):
416
  r"""
417
  Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
418
  weights in the prompt.
419
 
420
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
421
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
422
-
423
- Args:
424
- vae ([`AutoencoderKL`]):
425
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
426
- text_encoder ([`CLIPTextModel`]):
427
- Frozen text-encoder. Stable Diffusion uses the text portion of
428
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
429
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
430
- tokenizer (`CLIPTokenizer`):
431
- Tokenizer of class
432
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
433
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
434
- scheduler ([`SchedulerMixin`]):
435
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
436
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
437
- safety_checker ([`StableDiffusionSafetyChecker`]):
438
- Classification module that estimates whether generated images could be considered offensive or harmful.
439
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
440
- feature_extractor ([`CLIPImageProcessor`]):
441
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
442
  """
443
 
444
- _optional_components = ["safety_checker", "feature_extractor"]
445
-
446
- def __init__(
447
- self,
448
- vae: AutoencoderKL,
449
- text_encoder: CLIPTextModel,
450
- tokenizer: CLIPTokenizer,
451
- unet: UNet2DConditionModel,
452
- scheduler: KarrasDiffusionSchedulers,
453
- safety_checker: StableDiffusionSafetyChecker,
454
- feature_extractor: CLIPImageProcessor,
455
- requires_safety_checker: bool = True,
456
- ):
457
- super().__init__()
458
-
459
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
460
- deprecation_message = (
461
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
462
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
463
- "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
464
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
465
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
466
- " file"
467
- )
468
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
469
- new_config = dict(scheduler.config)
470
- new_config["steps_offset"] = 1
471
- scheduler._internal_dict = FrozenDict(new_config)
472
-
473
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
474
- deprecation_message = (
475
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
476
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
477
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
478
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
479
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
480
- )
481
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
482
- new_config = dict(scheduler.config)
483
- new_config["clip_sample"] = False
484
- scheduler._internal_dict = FrozenDict(new_config)
485
-
486
- if safety_checker is None and requires_safety_checker:
487
- logger.warning(
488
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
489
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
490
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
491
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
492
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
493
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
494
  )
 
495
 
496
- if safety_checker is not None and feature_extractor is None:
497
- raise ValueError(
498
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
499
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
500
- )
501
 
502
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
503
- version.parse(unet.config._diffusers_version).base_version
504
- ) < version.parse("0.9.0.dev0")
505
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
506
- if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
507
- deprecation_message = (
508
- "The configuration file of the unet has set the default `sample_size` to smaller than"
509
- " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
510
- " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
511
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
512
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
513
- " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
514
- " in the config might lead to incorrect results in future versions. If you have downloaded this"
515
- " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
516
- " the `unet/config.json` file"
517
  )
518
- deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
519
- new_config = dict(unet.config)
520
- new_config["sample_size"] = 64
521
- unet._internal_dict = FrozenDict(new_config)
522
- self.register_modules(
523
- vae=vae,
524
- text_encoder=text_encoder,
525
- tokenizer=tokenizer,
526
- unet=unet,
527
- scheduler=scheduler,
528
- safety_checker=safety_checker,
529
- feature_extractor=feature_extractor,
530
- )
531
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
532
 
533
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
534
- self.register_to_config(
535
- requires_safety_checker=requires_safety_checker,
536
- )
537
 
538
- def enable_vae_slicing(self):
539
- r"""
540
- Enable sliced VAE decoding.
541
-
542
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
543
- steps. This is useful to save some memory and allow larger batch sizes.
544
- """
545
- self.vae.enable_slicing()
546
-
547
- def disable_vae_slicing(self):
548
- r"""
549
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
550
- computing decoding in one step.
551
- """
552
- self.vae.disable_slicing()
553
-
554
- def enable_vae_tiling(self):
555
- r"""
556
- Enable tiled VAE decoding.
557
-
558
- When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
559
- several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
560
- """
561
- self.vae.enable_tiling()
562
-
563
- def disable_vae_tiling(self):
564
- r"""
565
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
566
- computing decoding in one step.
567
- """
568
- self.vae.disable_tiling()
569
-
570
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
571
- def enable_sequential_cpu_offload(self, gpu_id=0):
572
- r"""
573
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
574
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
575
- `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
576
- Note that offloading happens on a submodule basis. Memory savings are higher than with
577
- `enable_model_cpu_offload`, but performance is lower.
578
- """
579
- if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
580
- from accelerate import cpu_offload
581
- else:
582
- raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
583
-
584
- device = torch.device(f"cuda:{gpu_id}")
585
-
586
- if self.device.type != "cpu":
587
- self.to("cpu", silence_dtype_warnings=True)
588
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
589
-
590
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
591
- cpu_offload(cpu_offloaded_model, device)
592
-
593
- if self.safety_checker is not None:
594
- cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
595
-
596
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
597
- def enable_model_cpu_offload(self, gpu_id=0):
598
- r"""
599
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
600
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
601
- method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
602
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
603
- """
604
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
605
- from accelerate import cpu_offload_with_hook
606
- else:
607
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
608
-
609
- device = torch.device(f"cuda:{gpu_id}")
610
-
611
- if self.device.type != "cpu":
612
- self.to("cpu", silence_dtype_warnings=True)
613
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
614
-
615
- hook = None
616
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
617
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
618
-
619
- if self.safety_checker is not None:
620
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
621
-
622
- # We'll offload the last model manually.
623
- self.final_offload_hook = hook
624
-
625
- @property
626
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
627
- def _execution_device(self):
628
- r"""
629
- Returns the device on which the pipeline's models will be executed. After calling
630
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
631
- hooks.
632
- """
633
- if not hasattr(self.unet, "_hf_hook"):
634
- return self.device
635
- for module in self.unet.modules():
636
- if (
637
- hasattr(module, "_hf_hook")
638
- and hasattr(module._hf_hook, "execution_device")
639
- and module._hf_hook.execution_device is not None
640
- ):
641
- return torch.device(module._hf_hook.execution_device)
642
- return self.device
643
-
644
- def encode_prompt(
645
  self,
646
  prompt,
647
- device,
648
  num_images_per_prompt,
649
  do_classifier_free_guidance,
650
- negative_prompt=None,
651
- max_embeddings_multiples=3,
652
- prompt_embeds: Optional[torch.FloatTensor] = None,
653
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
654
  ):
655
  r"""
656
  Encodes the prompt into text encoder hidden states.
@@ -658,8 +505,6 @@ class StableDiffusionLongPromptWeightingPipeline(
658
  Args:
659
  prompt (`str` or `list(int)`):
660
  prompt to be encoded
661
- device: (`torch.device`):
662
- torch device
663
  num_images_per_prompt (`int`):
664
  number of images that should be generated per prompt
665
  do_classifier_free_guidance (`bool`):
@@ -670,71 +515,43 @@ class StableDiffusionLongPromptWeightingPipeline(
670
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
671
  The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
672
  """
673
- if prompt is not None and isinstance(prompt, str):
674
- batch_size = 1
675
- elif prompt is not None and isinstance(prompt, list):
676
- batch_size = len(prompt)
677
- else:
678
- batch_size = prompt_embeds.shape[0]
679
-
680
- if negative_prompt_embeds is None:
681
- if negative_prompt is None:
682
- negative_prompt = [""] * batch_size
683
- elif isinstance(negative_prompt, str):
684
- negative_prompt = [negative_prompt] * batch_size
685
- if batch_size != len(negative_prompt):
686
- raise ValueError(
687
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
688
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
689
- " the batch size of `prompt`."
690
- )
691
- if prompt_embeds is None or negative_prompt_embeds is None:
692
- if isinstance(self, TextualInversionLoaderMixin):
693
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
694
- if do_classifier_free_guidance and negative_prompt_embeds is None:
695
- negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
696
-
697
- prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
698
- pipe=self,
699
- prompt=prompt,
700
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
701
- max_embeddings_multiples=max_embeddings_multiples,
702
  )
703
- if prompt_embeds is None:
704
- prompt_embeds = prompt_embeds1
705
- if negative_prompt_embeds is None:
706
- negative_prompt_embeds = negative_prompt_embeds1
707
 
708
- bs_embed, seq_len, _ = prompt_embeds.shape
709
- # duplicate text embeddings for each generation per prompt, using mps friendly method
710
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
711
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
 
712
 
 
713
  if do_classifier_free_guidance:
714
- bs_embed, seq_len, _ = negative_prompt_embeds.shape
715
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
716
- negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
717
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
718
 
719
- return prompt_embeds
720
 
721
- def check_inputs(
722
- self,
723
- prompt,
724
- height,
725
- width,
726
- strength,
727
- callback_steps,
728
- negative_prompt=None,
729
- prompt_embeds=None,
730
- negative_prompt_embeds=None,
731
- ):
732
- if height % 8 != 0 or width % 8 != 0:
733
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
734
 
735
  if strength < 0 or strength > 1:
736
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
737
 
 
 
 
738
  if (callback_steps is None) or (
739
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
740
  ):
@@ -743,60 +560,46 @@ class StableDiffusionLongPromptWeightingPipeline(
743
  f" {type(callback_steps)}."
744
  )
745
 
746
- if prompt is not None and prompt_embeds is not None:
747
- raise ValueError(
748
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
749
- " only forward one of the two."
750
- )
751
- elif prompt is None and prompt_embeds is None:
752
- raise ValueError(
753
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
754
- )
755
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
756
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
757
-
758
- if negative_prompt is not None and negative_prompt_embeds is not None:
759
- raise ValueError(
760
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
761
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
762
- )
763
-
764
- if prompt_embeds is not None and negative_prompt_embeds is not None:
765
- if prompt_embeds.shape != negative_prompt_embeds.shape:
766
- raise ValueError(
767
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
768
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
769
- f" {negative_prompt_embeds.shape}."
770
- )
771
-
772
- def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
773
  if is_text2img:
774
- return self.scheduler.timesteps.to(device), num_inference_steps
775
  else:
776
  # get the original timestep using init_timestep
777
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
778
-
779
- t_start = max(num_inference_steps - init_timestep, 0)
780
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
781
 
 
 
782
  return timesteps, num_inference_steps - t_start
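For img2img and inpainting, `strength` decides how much of the noise schedule is actually run. For example, with 50 inference steps and `strength=0.8`, only the last 40 timesteps are used:

num_inference_steps, strength, order = 50, 0.8, 1          # order is 1 for most schedulers
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# timesteps = self.scheduler.timesteps[t_start * order:]  -> 40 denoising steps remain
print(init_timestep, t_start, num_inference_steps - t_start)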
783
 
784
- def run_safety_checker(self, image, device, dtype):
785
  if self.safety_checker is not None:
786
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
787
- image, has_nsfw_concept = self.safety_checker(
788
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
789
- )
790
  else:
791
  has_nsfw_concept = None
792
  return image, has_nsfw_concept
793
 
794
  def decode_latents(self, latents):
795
- latents = 1 / self.vae.config.scaling_factor * latents
796
- image = self.vae.decode(latents).sample
797
- image = (image / 2 + 0.5).clamp(0, 1)
798
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
799
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
 
 
800
  return image
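The decode step above undoes the VAE scaling factor (0.18215 for SD 1.x checkpoints) before calling `vae.decode`, then maps the output from [-1, 1] to [0, 1] and converts to NHWC for PIL. A sketch with the VAE call stubbed out:

import torch

scaling_factor = 0.18215                         # vae.config.scaling_factor for SD 1.x
latents = torch.randn(1, 4, 64, 64)
latents = latents / scaling_factor               # undo the encode-time scaling
image = torch.randn(1, 3, 512, 512)              # stand-in for self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)            # [-1, 1] -> [0, 1]
image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC for numpy_to_pil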
801
 
802
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -816,51 +619,36 @@ class StableDiffusionLongPromptWeightingPipeline(
816
  extra_step_kwargs["generator"] = generator
817
  return extra_step_kwargs
818
 
819
- def prepare_latents(
820
- self,
821
- image,
822
- timestep,
823
- num_images_per_prompt,
824
- batch_size,
825
- num_channels_latents,
826
- height,
827
- width,
828
- dtype,
829
- device,
830
- generator,
831
- latents=None,
832
- ):
833
  if image is None:
834
- batch_size = batch_size * num_images_per_prompt
835
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
836
- if isinstance(generator, list) and len(generator) != batch_size:
837
- raise ValueError(
838
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
839
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
840
- )
841
 
842
  if latents is None:
843
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
844
  else:
845
- latents = latents.to(device)
 
846
 
847
  # scale the initial noise by the standard deviation required by the scheduler
848
- latents = latents * self.scheduler.init_noise_sigma
849
  return latents, None, None
850
  else:
851
- image = image.to(device=self.device, dtype=dtype)
852
- init_latent_dist = self.vae.encode(image).latent_dist
853
- init_latents = init_latent_dist.sample(generator=generator)
854
- init_latents = self.vae.config.scaling_factor * init_latents
855
-
856
- # Expand init_latents for batch_size and num_images_per_prompt
857
- init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
858
  init_latents_orig = init_latents
 
859
 
860
  # add noise to latents using the timesteps
861
- noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
862
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
863
- latents = init_latents
 
864
  return latents, init_latents_orig, noise
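For text-to-image, the initial latents are Gaussian noise at 1/8 of the output resolution, scaled by the scheduler's `init_noise_sigma`; a sketch of the shape computation (values are illustrative):

import torch

batch_size, num_images_per_prompt = 1, 2
height, width, vae_scale_factor = 512, 512, 8
num_channels_latents = 4                          # unet.config.in_channels for SD 1.x

shape = (batch_size * num_images_per_prompt, num_channels_latents,
         height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape)                      # the pipeline uses randn_tensor with the passed generator
latents = latents * 1.0                           # scheduler.init_noise_sigma (1.0 for DDIM/PNDM)
print(latents.shape)                              # torch.Size([2, 4, 64, 64])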
865
 
866
  @torch.no_grad()
@@ -868,27 +656,24 @@ class StableDiffusionLongPromptWeightingPipeline(
868
  self,
869
  prompt: Union[str, List[str]],
870
  negative_prompt: Optional[Union[str, List[str]]] = None,
871
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
872
- mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
873
  height: int = 512,
874
  width: int = 512,
875
  num_inference_steps: int = 50,
876
  guidance_scale: float = 7.5,
877
  strength: float = 0.8,
878
  num_images_per_prompt: Optional[int] = 1,
879
- add_predicted_noise: Optional[bool] = False,
880
  eta: float = 0.0,
881
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
882
- latents: Optional[torch.FloatTensor] = None,
883
- prompt_embeds: Optional[torch.FloatTensor] = None,
884
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
885
  max_embeddings_multiples: Optional[int] = 3,
886
  output_type: Optional[str] = "pil",
887
  return_dict: bool = True,
888
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
889
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
890
  callback_steps: int = 1,
891
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
892
  ):
893
  r"""
894
  Function invoked when calling the pipeline for generation.
@@ -899,10 +684,10 @@ class StableDiffusionLongPromptWeightingPipeline(
899
  negative_prompt (`str` or `List[str]`, *optional*):
900
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
901
  if `guidance_scale` is less than `1`).
902
- image (`torch.FloatTensor` or `PIL.Image.Image`):
903
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
904
  process.
905
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
906
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
907
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
908
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
@@ -928,26 +713,16 @@ class StableDiffusionLongPromptWeightingPipeline(
928
  `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
929
  num_images_per_prompt (`int`, *optional*, defaults to 1):
930
  The number of images to generate per prompt.
931
- add_predicted_noise (`bool`, *optional*, defaults to True):
932
- Use predicted noise instead of random noise when constructing noisy versions of the original image in
933
- the reverse diffusion process
934
  eta (`float`, *optional*, defaults to 0.0):
935
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
936
  [`schedulers.DDIMScheduler`], will be ignored for others.
937
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
938
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
939
- to make generation deterministic.
940
- latents (`torch.FloatTensor`, *optional*):
941
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
942
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
943
  tensor will be generated by sampling using the supplied random `generator`.
944
- prompt_embeds (`torch.FloatTensor`, *optional*):
945
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
946
- provided, text embeddings will be generated from `prompt` input argument.
947
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
948
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
949
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
950
- argument.
951
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
952
  The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
953
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -958,17 +733,13 @@ class StableDiffusionLongPromptWeightingPipeline(
958
  plain tuple.
959
  callback (`Callable`, *optional*):
960
  A function that will be called every `callback_steps` steps during inference. The function will be
961
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
962
  is_cancelled_callback (`Callable`, *optional*):
963
  A function that will be called every `callback_steps` steps during inference. If the function returns
964
  `True`, the inference will be cancelled.
965
  callback_steps (`int`, *optional*, defaults to 1):
966
  The frequency at which the `callback` function will be called. If not specified, the callback will be
967
  called at every step.
968
- cross_attention_kwargs (`dict`, *optional*):
969
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
970
- `self.processor` in
971
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
972
 
973
  Returns:
974
  `None` if cancelled by `is_cancelled_callback`,
@@ -983,66 +754,55 @@ class StableDiffusionLongPromptWeightingPipeline(
983
  width = width or self.unet.config.sample_size * self.vae_scale_factor
984
 
985
  # 1. Check inputs. Raise error if not correct
986
- self.check_inputs(
987
- prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
988
- )
989
 
990
  # 2. Define call parameters
991
- if prompt is not None and isinstance(prompt, str):
992
- batch_size = 1
993
- elif prompt is not None and isinstance(prompt, list):
994
- batch_size = len(prompt)
995
- else:
996
- batch_size = prompt_embeds.shape[0]
997
-
998
- device = self._execution_device
999
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1000
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1001
  # corresponds to doing no classifier free guidance.
1002
  do_classifier_free_guidance = guidance_scale > 1.0
1003
 
1004
  # 3. Encode input prompt
1005
- prompt_embeds = self.encode_prompt(
1006
  prompt,
1007
- device,
1008
  num_images_per_prompt,
1009
  do_classifier_free_guidance,
1010
  negative_prompt,
1011
  max_embeddings_multiples,
1012
- prompt_embeds=prompt_embeds,
1013
- negative_prompt_embeds=negative_prompt_embeds,
1014
  )
1015
- dtype = prompt_embeds.dtype
1016
 
1017
  # 4. Preprocess image and mask
1018
  if isinstance(image, PIL.Image.Image):
1019
- image = preprocess_image(image, batch_size)
1020
  if image is not None:
1021
- image = image.to(device=self.device, dtype=dtype)
1022
  if isinstance(mask_image, PIL.Image.Image):
1023
- mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
1024
  if mask_image is not None:
1025
- mask = mask_image.to(device=self.device, dtype=dtype)
1026
- mask = torch.cat([mask] * num_images_per_prompt)
1027
  else:
1028
  mask = None
1029
 
1030
  # 5. set timesteps
1031
- self.scheduler.set_timesteps(num_inference_steps, device=device)
1032
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
1033
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1034
 
1035
  # 6. Prepare latent variables
1036
  latents, init_latents_orig, noise = self.prepare_latents(
1037
  image,
1038
  latent_timestep,
1039
- num_images_per_prompt,
1040
- batch_size,
1041
- self.unet.config.in_channels,
1042
  height,
1043
  width,
1044
  dtype,
1045
- device,
1046
  generator,
1047
  latents,
1048
  )
@@ -1051,71 +811,57 @@ class StableDiffusionLongPromptWeightingPipeline(
1051
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1052
 
1053
  # 8. Denoising loop
1054
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1055
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1056
- for i, t in enumerate(timesteps):
1057
- # expand the latents if we are doing classifier free guidance
1058
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1059
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1060
-
1061
- # predict the noise residual
1062
- noise_pred = self.unet(
1063
- latent_model_input,
1064
  t,
1065
- encoder_hidden_states=prompt_embeds,
1066
- cross_attention_kwargs=cross_attention_kwargs,
1067
- ).sample
1068
-
1069
- # perform guidance
1070
- if do_classifier_free_guidance:
1071
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1072
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1073
-
1074
- # compute the previous noisy sample x_t -> x_t-1
1075
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1076
-
1077
- if mask is not None:
1078
- # masking
1079
- if add_predicted_noise:
1080
- init_latents_proper = self.scheduler.add_noise(
1081
- init_latents_orig, noise_pred_uncond, torch.tensor([t])
1082
- )
1083
- else:
1084
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1085
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
1086
-
1087
- # call the callback, if provided
1088
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1089
- progress_bar.update()
1090
- if i % callback_steps == 0:
1091
- if callback is not None:
1092
- step_idx = i // getattr(self.scheduler, "order", 1)
1093
- callback(step_idx, t, latents)
1094
- if is_cancelled_callback is not None and is_cancelled_callback():
1095
- return None
1096
-
1097
- if output_type == "latent":
1098
- image = latents
1099
- has_nsfw_concept = None
1100
- elif output_type == "pil":
1101
- # 9. Post-processing
1102
- image = self.decode_latents(latents)
1103
 
1104
- # 10. Run safety checker
1105
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1106
 
1107
- # 11. Convert to PIL
1108
- image = self.numpy_to_pil(image)
1109
- else:
1110
- # 9. Post-processing
1111
- image = self.decode_latents(latents)
1112
 
1113
- # 10. Run safety checker
1114
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1115
 
1116
- # Offload last model to CPU
1117
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1118
- self.final_offload_hook.offload()
1119
 
1120
  if not return_dict:
1121
  return image, has_nsfw_concept
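Two operations carry the denoising loop above: classifier-free guidance over the stacked [uncond, text] batch, and, when inpainting, re-blending the known region into the latents at every step. A condensed sketch with stand-in tensors:

import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)                    # unet output for the [uncond, text] batch
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# inpainting: pin the region that should be kept to the re-noised original latents
latents = torch.randn(1, 4, 64, 64)
init_latents_proper = torch.randn(1, 4, 64, 64)           # scheduler.add_noise(init_latents_orig, noise, t)
mask = torch.zeros(1, 4, 64, 64)                          # after inversion: 0 = repaint, 1 = keep
latents = init_latents_proper * mask + latents * (1 - mask)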
@@ -1132,17 +878,14 @@ class StableDiffusionLongPromptWeightingPipeline(
1132
  guidance_scale: float = 7.5,
1133
  num_images_per_prompt: Optional[int] = 1,
1134
  eta: float = 0.0,
1135
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1136
- latents: Optional[torch.FloatTensor] = None,
1137
- prompt_embeds: Optional[torch.FloatTensor] = None,
1138
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1139
  max_embeddings_multiples: Optional[int] = 3,
1140
  output_type: Optional[str] = "pil",
1141
  return_dict: bool = True,
1142
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1143
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1144
  callback_steps: int = 1,
1145
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1146
  ):
1147
  r"""
1148
  Function for text-to-image generation.
@@ -1170,20 +913,13 @@ class StableDiffusionLongPromptWeightingPipeline(
1170
  eta (`float`, *optional*, defaults to 0.0):
1171
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1172
  [`schedulers.DDIMScheduler`], will be ignored for others.
1173
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1174
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1175
- to make generation deterministic.
1176
- latents (`torch.FloatTensor`, *optional*):
1177
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1178
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1179
  tensor will be generated by sampling using the supplied random `generator`.
1180
- prompt_embeds (`torch.FloatTensor`, *optional*):
1181
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1182
- provided, text embeddings will be generated from `prompt` input argument.
1183
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1184
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1185
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1186
- argument.
1187
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1188
  The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
1189
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1194,20 +930,11 @@ class StableDiffusionLongPromptWeightingPipeline(
1194
  plain tuple.
1195
  callback (`Callable`, *optional*):
1196
  A function that will be called every `callback_steps` steps during inference. The function will be
1197
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1198
- is_cancelled_callback (`Callable`, *optional*):
1199
- A function that will be called every `callback_steps` steps during inference. If the function returns
1200
- `True`, the inference will be cancelled.
1201
  callback_steps (`int`, *optional*, defaults to 1):
1202
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1203
  called at every step.
1204
- cross_attention_kwargs (`dict`, *optional*):
1205
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1206
- `self.processor` in
1207
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1208
-
1209
  Returns:
1210
- `None` if cancelled by `is_cancelled_callback`,
1211
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1212
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1213
  When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -1225,20 +952,17 @@ class StableDiffusionLongPromptWeightingPipeline(
1225
  eta=eta,
1226
  generator=generator,
1227
  latents=latents,
1228
- prompt_embeds=prompt_embeds,
1229
- negative_prompt_embeds=negative_prompt_embeds,
1230
  max_embeddings_multiples=max_embeddings_multiples,
1231
  output_type=output_type,
1232
  return_dict=return_dict,
1233
  callback=callback,
1234
- is_cancelled_callback=is_cancelled_callback,
1235
  callback_steps=callback_steps,
1236
- cross_attention_kwargs=cross_attention_kwargs,
1237
  )
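For context, this community pipeline is usually loaded through `DiffusionPipeline.from_pretrained` with the `custom_pipeline` argument. The sketch below assumes the upstream `lpw_stable_diffusion` community pipeline and a standard SD 1.5 checkpoint, so treat the identifiers as illustrative rather than taken from this commit.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",     # long-prompt-weighting community pipeline
    torch_dtype=torch.float16,
).to("cuda")

# parentheses raise a token's weight, square brackets lower it, (word:1.5) sets it explicitly
image = pipe.text2img(
    "a (masterpiece:1.3) photo of a red sports car at sunset, [blurry]",
    negative_prompt="lowres, bad anatomy",
    max_embeddings_multiples=3,                 # allow prompts up to 3x the 77-token window
    num_inference_steps=30,
).images[0]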
1238
 
1239
  def img2img(
1240
  self,
1241
- image: Union[torch.FloatTensor, PIL.Image.Image],
1242
  prompt: Union[str, List[str]],
1243
  negative_prompt: Optional[Union[str, List[str]]] = None,
1244
  strength: float = 0.8,
@@ -1246,22 +970,19 @@ class StableDiffusionLongPromptWeightingPipeline(
1246
  guidance_scale: Optional[float] = 7.5,
1247
  num_images_per_prompt: Optional[int] = 1,
1248
  eta: Optional[float] = 0.0,
1249
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1250
- prompt_embeds: Optional[torch.FloatTensor] = None,
1251
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1252
  max_embeddings_multiples: Optional[int] = 3,
1253
  output_type: Optional[str] = "pil",
1254
  return_dict: bool = True,
1255
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1256
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1257
  callback_steps: int = 1,
1258
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1259
  ):
1260
  r"""
1261
  Function for image-to-image generation.
1262
  Args:
1263
- image (`torch.FloatTensor` or `PIL.Image.Image`):
1264
- `Image`, or tensor representing an image batch, that will be used as the starting point for the
1265
  process.
1266
  prompt (`str` or `List[str]`):
1267
  The prompt or prompts to guide the image generation.
@@ -1288,16 +1009,9 @@ class StableDiffusionLongPromptWeightingPipeline(
1288
  eta (`float`, *optional*, defaults to 0.0):
1289
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1290
  [`schedulers.DDIMScheduler`], will be ignored for others.
1291
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1292
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1293
- to make generation deterministic.
1294
- prompt_embeds (`torch.FloatTensor`, *optional*):
1295
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1296
- provided, text embeddings will be generated from `prompt` input argument.
1297
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1298
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1299
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1300
- argument.
1301
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1302
  The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
1303
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1308,20 +1022,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1308
  plain tuple.
1309
  callback (`Callable`, *optional*):
1310
  A function that will be called every `callback_steps` steps during inference. The function will be
1311
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1312
- is_cancelled_callback (`Callable`, *optional*):
1313
- A function that will be called every `callback_steps` steps during inference. If the function returns
1314
- `True`, the inference will be cancelled.
1315
  callback_steps (`int`, *optional*, defaults to 1):
1316
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1317
  called at every step.
1318
- cross_attention_kwargs (`dict`, *optional*):
1319
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1320
- `self.processor` in
1321
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1322
-
1323
  Returns:
1324
- `None` if cancelled by `is_cancelled_callback`,
1325
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1326
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1327
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
@@ -1337,47 +1043,40 @@ class StableDiffusionLongPromptWeightingPipeline(
1337
  num_images_per_prompt=num_images_per_prompt,
1338
  eta=eta,
1339
  generator=generator,
1340
- prompt_embeds=prompt_embeds,
1341
- negative_prompt_embeds=negative_prompt_embeds,
1342
  max_embeddings_multiples=max_embeddings_multiples,
1343
  output_type=output_type,
1344
  return_dict=return_dict,
1345
  callback=callback,
1346
- is_cancelled_callback=is_cancelled_callback,
1347
  callback_steps=callback_steps,
1348
- cross_attention_kwargs=cross_attention_kwargs,
1349
  )
1350
 
1351
  def inpaint(
1352
  self,
1353
- image: Union[torch.FloatTensor, PIL.Image.Image],
1354
- mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1355
  prompt: Union[str, List[str]],
1356
  negative_prompt: Optional[Union[str, List[str]]] = None,
1357
  strength: float = 0.8,
1358
  num_inference_steps: Optional[int] = 50,
1359
  guidance_scale: Optional[float] = 7.5,
1360
  num_images_per_prompt: Optional[int] = 1,
1361
- add_predicted_noise: Optional[bool] = False,
1362
  eta: Optional[float] = 0.0,
1363
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1364
- prompt_embeds: Optional[torch.FloatTensor] = None,
1365
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1366
  max_embeddings_multiples: Optional[int] = 3,
1367
  output_type: Optional[str] = "pil",
1368
  return_dict: bool = True,
1369
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1370
- is_cancelled_callback: Optional[Callable[[], bool]] = None,
1371
  callback_steps: int = 1,
1372
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1373
  ):
1374
  r"""
1375
  Function for inpaint.
1376
  Args:
1377
- image (`torch.FloatTensor` or `PIL.Image.Image`):
1378
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
1379
  process. This is the image whose masked region will be inpainted.
1380
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1381
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1382
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1383
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
@@ -1403,22 +1102,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1403
  usually at the expense of lower image quality.
1404
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1405
  The number of images to generate per prompt.
1406
- add_predicted_noise (`bool`, *optional*, defaults to True):
1407
- Use predicted noise instead of random noise when constructing noisy versions of the original image in
1408
- the reverse diffusion process
1409
  eta (`float`, *optional*, defaults to 0.0):
1410
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1411
  [`schedulers.DDIMScheduler`], will be ignored for others.
1412
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1413
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1414
- to make generation deterministic.
1415
- prompt_embeds (`torch.FloatTensor`, *optional*):
1416
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1417
- provided, text embeddings will be generated from `prompt` input argument.
1418
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1419
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1420
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1421
- argument.
1422
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1423
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1424
  output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1429,20 +1118,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1429
  plain tuple.
1430
  callback (`Callable`, *optional*):
1431
  A function that will be called every `callback_steps` steps during inference. The function will be
1432
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1433
- is_cancelled_callback (`Callable`, *optional*):
1434
- A function that will be called every `callback_steps` steps during inference. If the function returns
1435
- `True`, the inference will be cancelled.
1436
  callback_steps (`int`, *optional*, defaults to 1):
1437
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1438
  called at every step.
1439
- cross_attention_kwargs (`dict`, *optional*):
1440
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1441
- `self.processor` in
1442
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1443
-
1444
  Returns:
1445
- `None` if cancelled by `is_cancelled_callback`,
1446
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1447
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1448
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
@@ -1457,16 +1138,12 @@ class StableDiffusionLongPromptWeightingPipeline(
1457
  guidance_scale=guidance_scale,
1458
  strength=strength,
1459
  num_images_per_prompt=num_images_per_prompt,
1460
- add_predicted_noise=add_predicted_noise,
1461
  eta=eta,
1462
  generator=generator,
1463
- prompt_embeds=prompt_embeds,
1464
- negative_prompt_embeds=negative_prompt_embeds,
1465
  max_embeddings_multiples=max_embeddings_multiples,
1466
  output_type=output_type,
1467
  return_dict=return_dict,
1468
  callback=callback,
1469
- is_cancelled_callback=is_cancelled_callback,
1470
  callback_steps=callback_steps,
1471
- cross_attention_kwargs=cross_attention_kwargs,
1472
  )
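Inpainting follows the same pattern, continuing the sketch above: the mask is white where content should be repainted and black where it should be kept (the file paths and sizes are placeholders).

import PIL.Image

init_image = PIL.Image.open("dog.png").convert("RGB").resize((512, 512))      # placeholder path
mask_image = PIL.Image.open("dog_mask.png").convert("L").resize((512, 512))   # white = repaint

result = pipe.inpaint(
    image=init_image,
    mask_image=mask_image,
    prompt="a (cute:1.2) corgi wearing a red scarf",
    strength=0.75,              # fraction of the schedule run over the masked region
    num_inference_steps=50,
).images[0]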
 
1
+ # source https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_onnx.py
2
  import inspect
3
  import re
4
+ from typing import Callable, List, Optional, Union
5
 
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
  from packaging import version
10
+ from transformers import CLIPImageProcessor, CLIPTokenizer
11
+
12
+ import diffusers
13
+ from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
14
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
15
+ from diffusers.utils import logging
16
+
17
+
18
+ try:
19
+ from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
20
+ except ImportError:
21
+ ORT_TO_NP_TYPE = {
22
+ "tensor(bool)": np.bool_,
23
+ "tensor(int8)": np.int8,
24
+ "tensor(uint8)": np.uint8,
25
+ "tensor(int16)": np.int16,
26
+ "tensor(uint16)": np.uint16,
27
+ "tensor(int32)": np.int32,
28
+ "tensor(uint32)": np.uint32,
29
+ "tensor(int64)": np.int64,
30
+ "tensor(uint64)": np.uint64,
31
+ "tensor(float16)": np.float16,
32
+ "tensor(float)": np.float32,
33
+ "tensor(double)": np.float64,
34
+ }
35
+
36
+ try:
37
+ from diffusers.utils import PIL_INTERPOLATION
38
+ except ImportError:
39
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
40
+ PIL_INTERPOLATION = {
41
+ "linear": PIL.Image.Resampling.BILINEAR,
42
+ "bilinear": PIL.Image.Resampling.BILINEAR,
43
+ "bicubic": PIL.Image.Resampling.BICUBIC,
44
+ "lanczos": PIL.Image.Resampling.LANCZOS,
45
+ "nearest": PIL.Image.Resampling.NEAREST,
46
+ }
47
+ else:
48
+ PIL_INTERPOLATION = {
49
+ "linear": PIL.Image.LINEAR,
50
+ "bilinear": PIL.Image.BILINEAR,
51
+ "bicubic": PIL.Image.BICUBIC,
52
+ "lanczos": PIL.Image.LANCZOS,
53
+ "nearest": PIL.Image.NEAREST,
54
+ }
55
  # ------------------------------------------------------------------------------
56
 
57
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
162
  return res
163
 
164
 
165
+ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
166
  r"""
167
  Tokenize a list of prompts and return its tokens with weights of each token.
168
 
 
177
  text_weight = []
178
  for word, weight in texts_and_weights:
179
  # tokenize and discard the starting and the ending token
180
+ token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
181
+ text_token += list(token)
182
  # copy the weight by length of token
183
  text_weight += [weight] * len(token)
184
  # stop if the text is too long (longer than truncation limit)
 
223
 
224
 
225
  def get_unweighted_text_embeddings(
226
+ pipe,
227
+ text_input: np.array,
228
  chunk_length: int,
229
  no_boseos_middle: Optional[bool] = True,
230
  ):
 
237
  text_embeddings = []
238
  for i in range(max_embeddings_multiples):
239
  # extract the i-th chunk
240
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
241
 
242
  # cover the head and the tail by the starting and the ending tokens
243
  text_input_chunk[:, 0] = text_input[0, 0]
244
  text_input_chunk[:, -1] = text_input[0, -1]
245
+
246
+ text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]
247
 
248
  if no_boseos_middle:
249
  if i == 0:
 
257
  text_embedding = text_embedding[:, 1:-1]
258
 
259
  text_embeddings.append(text_embedding)
260
+ text_embeddings = np.concatenate(text_embeddings, axis=1)
261
  else:
262
+ text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
263
  return text_embeddings
264
 
265
 
266
  def get_weighted_text_embeddings(
267
+ pipe,
268
  prompt: Union[str, List[str]],
269
  uncond_prompt: Optional[Union[str, List[str]]] = None,
270
+ max_embeddings_multiples: Optional[int] = 4,
271
  no_boseos_middle: Optional[bool] = False,
272
  skip_parsing: Optional[bool] = False,
273
  skip_weighting: Optional[bool] = False,
274
+ **kwargs,
275
  ):
276
  r"""
277
  Prompts can be assigned with local weights using brackets. For example,
 
281
  Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
282
 
283
  Args:
284
+ pipe (`OnnxStableDiffusionPipeline`):
285
  Pipe to provide access to the tokenizer and the text encoder.
286
  prompt (`str` or `List[str]`):
287
  The prompt or prompts to guide the image generation.
288
  uncond_prompt (`str` or `List[str]`):
289
  The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
290
  is provided, the embeddings of prompt and uncond_prompt are concatenated.
291
+ max_embeddings_multiples (`int`, *optional*, defaults to `4`):
292
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
293
  no_boseos_middle (`bool`, *optional*, defaults to `False`):
294
  If the token length is a multiple of the text encoder's capacity, whether to keep the starting and
 
310
  uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
311
  else:
312
  prompt_tokens = [
313
+ token[1:-1]
314
+ for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
315
  ]
316
  prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
317
  if uncond_prompt is not None:
 
319
  uncond_prompt = [uncond_prompt]
320
  uncond_tokens = [
321
  token[1:-1]
322
+ for token in pipe.tokenizer(
323
+ uncond_prompt,
324
+ max_length=max_length,
325
+ truncation=True,
326
+ return_tensors="np",
327
+ ).input_ids
328
  ]
329
  uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
330
 
 
354
  no_boseos_middle=no_boseos_middle,
355
  chunk_length=pipe.tokenizer.model_max_length,
356
  )
357
+ prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
358
  if uncond_prompt is not None:
359
  uncond_tokens, uncond_weights = pad_tokens_and_weights(
360
  uncond_tokens,
 
366
  no_boseos_middle=no_boseos_middle,
367
  chunk_length=pipe.tokenizer.model_max_length,
368
  )
369
+ uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
370
 
371
  # get the embeddings
372
  text_embeddings = get_unweighted_text_embeddings(
 
375
  pipe.tokenizer.model_max_length,
376
  no_boseos_middle=no_boseos_middle,
377
  )
378
+ prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
379
  if uncond_prompt is not None:
380
  uncond_embeddings = get_unweighted_text_embeddings(
381
  pipe,
 
383
  pipe.tokenizer.model_max_length,
384
  no_boseos_middle=no_boseos_middle,
385
  )
386
+ uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
387
 
388
  # assign weights to the prompts and normalize in the sense of mean
389
  # TODO: should we normalize by chunk or in a whole (current implementation)?
390
  if (not skip_parsing) and (not skip_weighting):
391
+ previous_mean = text_embeddings.mean(axis=(-2, -1))
392
+ text_embeddings *= prompt_weights[:, :, None]
393
+ text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
 
394
  if uncond_prompt is not None:
395
+ previous_mean = uncond_embeddings.mean(axis=(-2, -1))
396
+ uncond_embeddings *= uncond_weights[:, :, None]
397
+ uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]
 
398
 
399
+ # For classifier free guidance, we need to do two forward passes.
400
+ # Here we concatenate the unconditional and text embeddings into a single batch
401
+ # to avoid doing two forward passes
402
  if uncond_prompt is not None:
403
  return text_embeddings, uncond_embeddings
404
+
405
+ return text_embeddings
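# Editor's illustration (not part of the upstream community file): a minimal sketch of how the
# bracket syntax handled above is parsed into (text, weight) pairs by parse_prompt_attention,
# which is defined earlier in this module.
#
#   >>> parse_prompt_attention("a (red:1.3) ball on [grass]")
#   [['a ', 1.0], ['red', 1.3], [' ball on ', 1.0], ['grass', 0.909...]]
#
# Each weight is then broadcast over that fragment's tokens and applied to the per-token
# embeddings, after which the embeddings are rescaled to preserve their original mean.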
406
 
407
 
408
+ def preprocess_image(image):
409
  w, h = image.size
410
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
411
  image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
412
  image = np.array(image).astype(np.float32) / 255.0
413
+ image = image[None].transpose(0, 3, 1, 2)
 
414
  return 2.0 * image - 1.0
415
 
416
 
417
+ def preprocess_mask(mask, scale_factor=8):
418
+ mask = mask.convert("L")
419
+ w, h = mask.size
420
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
421
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
422
+ mask = np.array(mask).astype(np.float32) / 255.0
423
+ mask = np.tile(mask, (4, 1, 1))
424
+ mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; transpose(0, 1, 2, 3) is an identity (no-op)
425
+ mask = 1 - mask # repaint white, keep black
426
+ return mask
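# Editor's illustration (hedged): expected shapes for the two helpers above, assuming a
# 512x512 input, the default scale_factor of 8, and Pillow available.
#
#   >>> img = PIL.Image.new("RGB", (512, 512))
#   >>> preprocess_image(img).shape        # (1, 3, 512, 512), values in [-1, 1]
#   >>> m = PIL.Image.new("L", (512, 512), 255)
#   >>> preprocess_mask(m, 8).shape        # (1, 4, 64, 64); white -> 0 (repaint), black -> 1 (keep)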
427
 
428
 
429
+ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
 
 
430
  r"""
431
  Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
432
  weighting in prompt.
433
 
434
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
435
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
436
  """
437
 
438
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
439
+
440
+ def __init__(
441
+ self,
442
+ vae_encoder: OnnxRuntimeModel,
443
+ vae_decoder: OnnxRuntimeModel,
444
+ text_encoder: OnnxRuntimeModel,
445
+ tokenizer: CLIPTokenizer,
446
+ unet: OnnxRuntimeModel,
447
+ scheduler: SchedulerMixin,
448
+ safety_checker: OnnxRuntimeModel,
449
+ feature_extractor: CLIPImageProcessor,
450
+ requires_safety_checker: bool = True,
451
+ ):
452
+ super().__init__(
453
+ vae_encoder=vae_encoder,
454
+ vae_decoder=vae_decoder,
455
+ text_encoder=text_encoder,
456
+ tokenizer=tokenizer,
457
+ unet=unet,
458
+ scheduler=scheduler,
459
+ safety_checker=safety_checker,
460
+ feature_extractor=feature_extractor,
461
+ requires_safety_checker=requires_safety_checker,
462
  )
463
+ self.__init__additional__()
464
 
465
+ else:
 
 
 
 
466
 
467
+ def __init__(
468
+ self,
469
+ vae_encoder: OnnxRuntimeModel,
470
+ vae_decoder: OnnxRuntimeModel,
471
+ text_encoder: OnnxRuntimeModel,
472
+ tokenizer: CLIPTokenizer,
473
+ unet: OnnxRuntimeModel,
474
+ scheduler: SchedulerMixin,
475
+ safety_checker: OnnxRuntimeModel,
476
+ feature_extractor: CLIPImageProcessor,
477
+ ):
478
+ super().__init__(
479
+ vae_encoder=vae_encoder,
480
+ vae_decoder=vae_decoder,
481
+ text_encoder=text_encoder,
482
+ tokenizer=tokenizer,
483
+ unet=unet,
484
+ scheduler=scheduler,
485
+ safety_checker=safety_checker,
486
+ feature_extractor=feature_extractor,
487
  )
488
+ self.__init__additional__()
489
 
490
+ def __init__additional__(self):
491
+ self.unet.config.in_channels = 4
492
+ self.vae_scale_factor = 8
 
493
 
494
+ def _encode_prompt(
495
  self,
496
  prompt,
 
497
  num_images_per_prompt,
498
  do_classifier_free_guidance,
499
+ negative_prompt,
500
+ max_embeddings_multiples,
 
 
501
  ):
502
  r"""
503
  Encodes the prompt into text encoder hidden states.
 
505
  Args:
506
  prompt (`str` or `list(int)`):
507
  prompt to be encoded
 
 
508
  num_images_per_prompt (`int`):
509
  number of images that should be generated per prompt
510
  do_classifier_free_guidance (`bool`):
 
515
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
516
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
517
  """
518
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
519
+
520
+ if negative_prompt is None:
521
+ negative_prompt = [""] * batch_size
522
+ elif isinstance(negative_prompt, str):
523
+ negative_prompt = [negative_prompt] * batch_size
524
+ if batch_size != len(negative_prompt):
525
+ raise ValueError(
526
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
527
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
528
+ " the batch size of `prompt`."
529
  )
 
 
 
 
530
 
531
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
532
+ pipe=self,
533
+ prompt=prompt,
534
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
535
+ max_embeddings_multiples=max_embeddings_multiples,
536
+ )
537
 
538
+ text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
539
  if do_classifier_free_guidance:
540
+ uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
541
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
 
542
 
543
+ return text_embeddings
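# Editor's note (illustrative, not upstream code): with classifier-free guidance the returned
# array stacks the negative and positive embeddings along axis 0, e.g. for a single prompt and
# num_images_per_prompt=2 the result has shape (2 * 1 * 2, sequence_length, hidden_size), where
# sequence_length grows in whole multiples of tokenizer.model_max_length for long prompts.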
544
 
545
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
546
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
547
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
548
 
549
  if strength < 0 or strength > 1:
550
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
551
 
552
+ if height % 8 != 0 or width % 8 != 0:
553
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
554
+
555
  if (callback_steps is None) or (
556
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
557
  ):
 
560
  f" {type(callback_steps)}."
561
  )
562
 
563
+ def get_timesteps(self, num_inference_steps, strength, is_text2img):
564
  if is_text2img:
565
+ return self.scheduler.timesteps, num_inference_steps
566
  else:
567
  # get the original timestep using init_timestep
568
+ offset = self.scheduler.config.get("steps_offset", 0)
569
+ init_timestep = int(num_inference_steps * strength) + offset
570
+ init_timestep = min(init_timestep, num_inference_steps)
 
571
 
572
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
573
+ timesteps = self.scheduler.timesteps[t_start:]
574
  return timesteps, num_inference_steps - t_start
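# Editor's worked example (assumes a scheduler configured with steps_offset=1): for
# num_inference_steps=50 and strength=0.8 in img2img mode,
#   init_timestep = int(50 * 0.8) + 1 = 41, t_start = max(50 - 41 + 1, 0) = 10,
# so the last 40 of the 50 scheduled timesteps are actually run.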
575
 
576
+ def run_safety_checker(self, image):
577
  if self.safety_checker is not None:
578
+ safety_checker_input = self.feature_extractor(
579
+ self.numpy_to_pil(image), return_tensors="np"
580
+ ).pixel_values.astype(image.dtype)
581
+ # Calling the safety_checker directly raises an error when the batch size is > 1, so run it per image
582
+ images, has_nsfw_concept = [], []
583
+ for i in range(image.shape[0]):
584
+ image_i, has_nsfw_concept_i = self.safety_checker(
585
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
586
+ )
587
+ images.append(image_i)
588
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
589
+ image = np.concatenate(images)
590
  else:
591
  has_nsfw_concept = None
592
  return image, has_nsfw_concept
593
 
594
  def decode_latents(self, latents):
595
+ latents = 1 / 0.18215 * latents
596
+ # image = self.vae_decoder(latent_sample=latents)[0]
597
+ # the half-precision VAE decoder gives corrupted results when the batch size is > 1, so decode one latent at a time
598
+ image = np.concatenate(
599
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
600
+ )
601
+ image = np.clip(image / 2 + 0.5, 0, 1)
602
+ image = image.transpose((0, 2, 3, 1))
603
  return image
604
 
605
  def prepare_extra_step_kwargs(self, generator, eta):
 
619
  extra_step_kwargs["generator"] = generator
620
  return extra_step_kwargs
621
 
622
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
623
  if image is None:
624
+ shape = (
625
+ batch_size,
626
+ self.unet.config.in_channels,
627
+ height // self.vae_scale_factor,
628
+ width // self.vae_scale_factor,
629
+ )
 
630
 
631
  if latents is None:
632
+ latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
633
  else:
634
+ if latents.shape != shape:
635
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
636
 
637
  # scale the initial noise by the standard deviation required by the scheduler
638
+ latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
639
  return latents, None, None
640
  else:
641
+ init_latents = self.vae_encoder(sample=image)[0]
642
+ init_latents = 0.18215 * init_latents
643
+ init_latents = np.concatenate([init_latents] * batch_size, axis=0)
 
 
 
 
644
  init_latents_orig = init_latents
645
+ shape = init_latents.shape
646
 
647
  # add noise to latents using the timesteps
648
+ noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
649
+ latents = self.scheduler.add_noise(
650
+ torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
651
+ ).numpy()
652
  return latents, init_latents_orig, noise
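# Editor's note (illustrative): in the text2img branch (image is None) and for a 512x512 output,
# the initial latents have shape (batch_size, 4, 512 // 8, 512 // 8) == (batch_size, 4, 64, 64)
# and are scaled by scheduler.init_noise_sigma before the denoising loop.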
653
 
654
  @torch.no_grad()
 
656
  self,
657
  prompt: Union[str, List[str]],
658
  negative_prompt: Optional[Union[str, List[str]]] = None,
659
+ image: Union[np.ndarray, PIL.Image.Image] = None,
660
+ mask_image: Union[np.ndarray, PIL.Image.Image] = None,
661
  height: int = 512,
662
  width: int = 512,
663
  num_inference_steps: int = 50,
664
  guidance_scale: float = 7.5,
665
  strength: float = 0.8,
666
  num_images_per_prompt: Optional[int] = 1,
 
667
  eta: float = 0.0,
668
+ generator: Optional[torch.Generator] = None,
669
+ latents: Optional[np.ndarray] = None,
 
 
670
  max_embeddings_multiples: Optional[int] = 3,
671
  output_type: Optional[str] = "pil",
672
  return_dict: bool = True,
673
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
674
  is_cancelled_callback: Optional[Callable[[], bool]] = None,
675
  callback_steps: int = 1,
676
+ **kwargs,
677
  ):
678
  r"""
679
  Function invoked when calling the pipeline for generation.
 
684
  negative_prompt (`str` or `List[str]`, *optional*):
685
  The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
686
  if `guidance_scale` is less than `1`).
687
+ image (`np.ndarray` or `PIL.Image.Image`):
688
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
689
  process.
690
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
691
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
692
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
693
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 
713
  `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
714
  num_images_per_prompt (`int`, *optional*, defaults to 1):
715
  The number of images to generate per prompt.
 
 
 
716
  eta (`float`, *optional*, defaults to 0.0):
717
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
718
  [`schedulers.DDIMScheduler`], will be ignored for others.
719
+ generator (`torch.Generator`, *optional*):
720
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
721
+ deterministic.
722
+ latents (`np.ndarray`, *optional*):
723
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
724
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
725
  tensor will be generated by sampling using the supplied random `generator`.
 
 
 
 
 
 
 
726
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
727
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
728
  output_type (`str`, *optional*, defaults to `"pil"`):
 
733
  plain tuple.
734
  callback (`Callable`, *optional*):
735
  A function that will be called every `callback_steps` steps during inference. The function will be
736
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
737
  is_cancelled_callback (`Callable`, *optional*):
738
  A function that will be called every `callback_steps` steps during inference. If the function returns
739
  `True`, the inference will be cancelled.
740
  callback_steps (`int`, *optional*, defaults to 1):
741
  The frequency at which the `callback` function will be called. If not specified, the callback will be
742
  called at every step.
 
 
 
 
743
 
744
  Returns:
745
  `None` if cancelled by `is_cancelled_callback`,
 
754
  width = width or self.unet.config.sample_size * self.vae_scale_factor
755
 
756
  # 1. Check inputs. Raise error if not correct
757
+ self.check_inputs(prompt, height, width, strength, callback_steps)
 
 
758
 
759
  # 2. Define call parameters
760
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
 
 
 
 
 
 
 
761
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
762
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
763
  # corresponds to doing no classifier free guidance.
764
  do_classifier_free_guidance = guidance_scale > 1.0
765
 
766
  # 3. Encode input prompt
767
+ text_embeddings = self._encode_prompt(
768
  prompt,
 
769
  num_images_per_prompt,
770
  do_classifier_free_guidance,
771
  negative_prompt,
772
  max_embeddings_multiples,
 
 
773
  )
774
+ dtype = text_embeddings.dtype
775
 
776
  # 4. Preprocess image and mask
777
  if isinstance(image, PIL.Image.Image):
778
+ image = preprocess_image(image)
779
  if image is not None:
780
+ image = image.astype(dtype)
781
  if isinstance(mask_image, PIL.Image.Image):
782
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
783
  if mask_image is not None:
784
+ mask = mask_image.astype(dtype)
785
+ mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
786
  else:
787
  mask = None
788
 
789
  # 5. set timesteps
790
+ self.scheduler.set_timesteps(num_inference_steps)
791
+ timestep_dtype = next(
792
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
793
+ )
794
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
795
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
796
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
797
 
798
  # 6. Prepare latent variables
799
  latents, init_latents_orig, noise = self.prepare_latents(
800
  image,
801
  latent_timestep,
802
+ batch_size * num_images_per_prompt,
 
 
803
  height,
804
  width,
805
  dtype,
 
806
  generator,
807
  latents,
808
  )
 
811
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
812
 
813
  # 8. Denoising loop
814
+ for i, t in enumerate(self.progress_bar(timesteps)):
815
+ # expand the latents if we are doing classifier free guidance
816
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
817
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
818
+ latent_model_input = latent_model_input.numpy()
819
+
820
+ # predict the noise residual
821
+ noise_pred = self.unet(
822
+ sample=latent_model_input,
823
+ timestep=np.array([t], dtype=timestep_dtype),
824
+ encoder_hidden_states=text_embeddings,
825
+ )
826
+ noise_pred = noise_pred[0]
827
+
828
+ # perform guidance
829
+ if do_classifier_free_guidance:
830
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
831
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
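# i.e. eps = eps_uncond + w * (eps_text - eps_uncond); with guidance_scale w == 1 this reduces
# to the purely text-conditioned prediction (editor's clarifying comment)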
832
+
833
+ # compute the previous noisy sample x_t -> x_t-1
834
+ scheduler_output = self.scheduler.step(
835
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
836
+ )
837
+ latents = scheduler_output.prev_sample.numpy()
838
+
839
+ if mask is not None:
840
+ # masking
841
+ init_latents_proper = self.scheduler.add_noise(
842
+ torch.from_numpy(init_latents_orig),
843
+ torch.from_numpy(noise),
844
  t,
845
+ ).numpy()
846
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
847
 
848
+ # call the callback, if provided
849
+ if i % callback_steps == 0:
850
+ if callback is not None:
851
+ step_idx = i // getattr(self.scheduler, "order", 1)
852
+ callback(step_idx, t, latents)
853
+ if is_cancelled_callback is not None and is_cancelled_callback():
854
+ return None
855
 
856
+ # 9. Post-processing
857
+ image = self.decode_latents(latents)
 
 
 
858
 
859
+ # 10. Run safety checker
860
+ image, has_nsfw_concept = self.run_safety_checker(image)
861
 
862
+ # 11. Convert to PIL
863
+ if output_type == "pil":
864
+ image = self.numpy_to_pil(image)
865
 
866
  if not return_dict:
867
  return image, has_nsfw_concept
 
878
  guidance_scale: float = 7.5,
879
  num_images_per_prompt: Optional[int] = 1,
880
  eta: float = 0.0,
881
+ generator: Optional[torch.Generator] = None,
882
+ latents: Optional[np.ndarray] = None,
 
 
883
  max_embeddings_multiples: Optional[int] = 3,
884
  output_type: Optional[str] = "pil",
885
  return_dict: bool = True,
886
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
 
887
  callback_steps: int = 1,
888
+ **kwargs,
889
  ):
890
  r"""
891
  Function for text-to-image generation.
 
913
  eta (`float`, *optional*, defaults to 0.0):
914
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
915
  [`schedulers.DDIMScheduler`], will be ignored for others.
916
+ generator (`torch.Generator`, *optional*):
917
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
918
+ deterministic.
919
+ latents (`np.ndarray`, *optional*):
920
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
921
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
922
  tensor will be generated by sampling using the supplied random `generator`.
 
 
 
 
 
 
 
923
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
924
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
925
  output_type (`str`, *optional*, defaults to `"pil"`):
 
930
  plain tuple.
931
  callback (`Callable`, *optional*):
932
  A function that will be called every `callback_steps` steps during inference. The function will be
933
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
 
 
 
934
  callback_steps (`int`, *optional*, defaults to 1):
935
  The frequency at which the `callback` function will be called. If not specified, the callback will be
936
  called at every step.
 
 
 
 
 
937
  Returns:
 
938
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
939
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
940
  When returning a tuple, the first element is a list with the generated images, and the second element is a
 
952
  eta=eta,
953
  generator=generator,
954
  latents=latents,
 
 
955
  max_embeddings_multiples=max_embeddings_multiples,
956
  output_type=output_type,
957
  return_dict=return_dict,
958
  callback=callback,
 
959
  callback_steps=callback_steps,
960
+ **kwargs,
961
  )
962
 
963
  def img2img(
964
  self,
965
+ image: Union[np.ndarray, PIL.Image.Image],
966
  prompt: Union[str, List[str]],
967
  negative_prompt: Optional[Union[str, List[str]]] = None,
968
  strength: float = 0.8,
 
970
  guidance_scale: Optional[float] = 7.5,
971
  num_images_per_prompt: Optional[int] = 1,
972
  eta: Optional[float] = 0.0,
973
+ generator: Optional[torch.Generator] = None,
 
 
974
  max_embeddings_multiples: Optional[int] = 3,
975
  output_type: Optional[str] = "pil",
976
  return_dict: bool = True,
977
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
 
978
  callback_steps: int = 1,
979
+ **kwargs,
980
  ):
981
  r"""
982
  Function for image-to-image generation.
983
  Args:
984
+ image (`np.ndarray` or `PIL.Image.Image`):
985
+ `Image`, or ndarray representing an image batch, that will be used as the starting point for the
986
  process.
987
  prompt (`str` or `List[str]`):
988
  The prompt or prompts to guide the image generation.
 
1009
  eta (`float`, *optional*, defaults to 0.0):
1010
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1011
  [`schedulers.DDIMScheduler`], will be ignored for others.
1012
+ generator (`torch.Generator`, *optional*):
1013
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1014
+ deterministic.
 
 
 
 
 
 
 
1015
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1016
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1017
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1022
  plain tuple.
1023
  callback (`Callable`, *optional*):
1024
  A function that will be called every `callback_steps` steps during inference. The function will be
1025
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
 
 
 
1026
  callback_steps (`int`, *optional*, defaults to 1):
1027
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1028
  called at every step.
 
 
 
 
 
1029
  Returns:
1030
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1031
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1032
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1033
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 
1043
  num_images_per_prompt=num_images_per_prompt,
1044
  eta=eta,
1045
  generator=generator,
 
 
1046
  max_embeddings_multiples=max_embeddings_multiples,
1047
  output_type=output_type,
1048
  return_dict=return_dict,
1049
  callback=callback,
 
1050
  callback_steps=callback_steps,
1051
+ **kwargs,
1052
  )
1053
 
1054
  def inpaint(
1055
  self,
1056
+ image: Union[np.ndarray, PIL.Image.Image],
1057
+ mask_image: Union[np.ndarray, PIL.Image.Image],
1058
  prompt: Union[str, List[str]],
1059
  negative_prompt: Optional[Union[str, List[str]]] = None,
1060
  strength: float = 0.8,
1061
  num_inference_steps: Optional[int] = 50,
1062
  guidance_scale: Optional[float] = 7.5,
1063
  num_images_per_prompt: Optional[int] = 1,
 
1064
  eta: Optional[float] = 0.0,
1065
+ generator: Optional[torch.Generator] = None,
 
 
1066
  max_embeddings_multiples: Optional[int] = 3,
1067
  output_type: Optional[str] = "pil",
1068
  return_dict: bool = True,
1069
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
 
1070
  callback_steps: int = 1,
1071
+ **kwargs,
1072
  ):
1073
  r"""
1074
  Function for inpaint.
1075
  Args:
1076
+ image (`np.ndarray` or `PIL.Image.Image`):
1077
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
1078
  process. This is the image whose masked region will be inpainted.
1079
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
1080
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1081
  replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1082
  PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 
1102
  usually at the expense of lower image quality.
1103
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1104
  The number of images to generate per prompt.
 
 
 
1105
  eta (`float`, *optional*, defaults to 0.0):
1106
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1107
  [`schedulers.DDIMScheduler`], will be ignored for others.
1108
+ generator (`torch.Generator`, *optional*):
1109
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1110
+ deterministic.
 
 
 
 
 
 
 
1111
  max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1112
  The max multiple length of prompt embeddings compared to the max output length of text encoder.
1113
  output_type (`str`, *optional*, defaults to `"pil"`):
 
1118
  plain tuple.
1119
  callback (`Callable`, *optional*):
1120
  A function that will be called every `callback_steps` steps during inference. The function will be
1121
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
 
 
 
1122
  callback_steps (`int`, *optional*, defaults to 1):
1123
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1124
  called at every step.
 
 
 
 
 
1125
  Returns:
1126
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1127
  [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1128
  When returning a tuple, the first element is a list with the generated images, and the second element is a
1129
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 
1138
  guidance_scale=guidance_scale,
1139
  strength=strength,
1140
  num_images_per_prompt=num_images_per_prompt,
 
1141
  eta=eta,
1142
  generator=generator,
 
 
1143
  max_embeddings_multiples=max_embeddings_multiples,
1144
  output_type=output_type,
1145
  return_dict=return_dict,
1146
  callback=callback,
 
1147
  callback_steps=callback_steps,
1148
+ **kwargs,
1149
  )
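# Editor's usage sketch for a long-prompt-weighting ONNX pipeline like the one above. The model
# id, revision and execution provider below are assumptions for illustration, not taken from
# this repository:
#
#   from diffusers import DiffusionPipeline
#   pipe = DiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5",
#       revision="onnx",
#       provider="CPUExecutionProvider",
#       custom_pipeline="lpw_stable_diffusion_onnx",
#   )
#   image = pipe.text2img("a (sunlit:1.2) forest, masterpiece", max_embeddings_multiples=3).images[0]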
pipeline_lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1472 @@
1
+ # source https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ import inspect
3
+ import re
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ from packaging import version
10
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
11
+
12
+ from diffusers import DiffusionPipeline
13
+ from diffusers.configuration_utils import FrozenDict
14
+ from diffusers.image_processor import VaeImageProcessor
15
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.schedulers import KarrasDiffusionSchedulers
19
+ from diffusers.utils import (
20
+ PIL_INTERPOLATION,
21
+ deprecate,
22
+ is_accelerate_available,
23
+ is_accelerate_version,
24
+ logging,
25
+ )
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
+ # ------------------------------------------------------------------------------
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ re_attention = re.compile(
34
+ r"""
35
+ \\\(|
36
+ \\\)|
37
+ \\\[|
38
+ \\]|
39
+ \\\\|
40
+ \\|
41
+ \(|
42
+ \[|
43
+ :([+-]?[.\d]+)\)|
44
+ \)|
45
+ ]|
46
+ [^\\()\[\]:]+|
47
+ :
48
+ """,
49
+ re.X,
50
+ )
51
+
52
+
53
+ def parse_prompt_attention(text):
54
+ """
55
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
56
+ Accepted tokens are:
57
+ (abc) - increases attention to abc by a multiplier of 1.1
58
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
59
+ [abc] - decreases attention to abc by a multiplier of 1.1
60
+ \\( - literal character '('
61
+ \\[ - literal character '['
62
+ \\) - literal character ')'
63
+ \\] - literal character ']'
64
+ \\ - literal character '\'
65
+ anything else - just text
66
+ >>> parse_prompt_attention('normal text')
67
+ [['normal text', 1.0]]
68
+ >>> parse_prompt_attention('an (important) word')
69
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
70
+ >>> parse_prompt_attention('(unbalanced')
71
+ [['unbalanced', 1.1]]
72
+ >>> parse_prompt_attention('\\(literal\\]')
73
+ [['(literal]', 1.0]]
74
+ >>> parse_prompt_attention('(unnecessary)(parens)')
75
+ [['unnecessaryparens', 1.1]]
76
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
77
+ [['a ', 1.0],
78
+ ['house', 1.5730000000000004],
79
+ [' ', 1.1],
80
+ ['on', 1.0],
81
+ [' a ', 1.1],
82
+ ['hill', 0.55],
83
+ [', sun, ', 1.1],
84
+ ['sky', 1.4641000000000006],
85
+ ['.', 1.1]]
86
+ """
87
+
88
+ res = []
89
+ round_brackets = []
90
+ square_brackets = []
91
+
92
+ round_bracket_multiplier = 1.1
93
+ square_bracket_multiplier = 1 / 1.1
94
+
95
+ def multiply_range(start_position, multiplier):
96
+ for p in range(start_position, len(res)):
97
+ res[p][1] *= multiplier
98
+
99
+ for m in re_attention.finditer(text):
100
+ text = m.group(0)
101
+ weight = m.group(1)
102
+
103
+ if text.startswith("\\"):
104
+ res.append([text[1:], 1.0])
105
+ elif text == "(":
106
+ round_brackets.append(len(res))
107
+ elif text == "[":
108
+ square_brackets.append(len(res))
109
+ elif weight is not None and len(round_brackets) > 0:
110
+ multiply_range(round_brackets.pop(), float(weight))
111
+ elif text == ")" and len(round_brackets) > 0:
112
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
113
+ elif text == "]" and len(square_brackets) > 0:
114
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
115
+ else:
116
+ res.append([text, 1.0])
117
+
118
+ for pos in round_brackets:
119
+ multiply_range(pos, round_bracket_multiplier)
120
+
121
+ for pos in square_brackets:
122
+ multiply_range(pos, square_bracket_multiplier)
123
+
124
+ if len(res) == 0:
125
+ res = [["", 1.0]]
126
+
127
+ # merge runs of identical weights
128
+ i = 0
129
+ while i + 1 < len(res):
130
+ if res[i][1] == res[i + 1][1]:
131
+ res[i][0] += res[i + 1][0]
132
+ res.pop(i + 1)
133
+ else:
134
+ i += 1
135
+
136
+ return res
137
+
138
+
139
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
140
+ r"""
141
+ Tokenize a list of prompts and return its tokens with weights of each token.
142
+
143
+ No padding, starting or ending token is included.
144
+ """
145
+ tokens = []
146
+ weights = []
147
+ truncated = False
148
+ for text in prompt:
149
+ texts_and_weights = parse_prompt_attention(text)
150
+ text_token = []
151
+ text_weight = []
152
+ for word, weight in texts_and_weights:
153
+ # tokenize and discard the starting and the ending token
154
+ token = pipe.tokenizer(word).input_ids[1:-1]
155
+ text_token += token
156
+ # copy the weight by length of token
157
+ text_weight += [weight] * len(token)
158
+ # stop if the text is too long (longer than truncation limit)
159
+ if len(text_token) > max_length:
160
+ truncated = True
161
+ break
162
+ # truncate
163
+ if len(text_token) > max_length:
164
+ truncated = True
165
+ text_token = text_token[:max_length]
166
+ text_weight = text_weight[:max_length]
167
+ tokens.append(text_token)
168
+ weights.append(text_weight)
169
+ if truncated:
170
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
171
+ return tokens, weights
172
+
173
+
174
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
175
+ r"""
176
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
177
+ """
178
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
179
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
180
+ for i in range(len(tokens)):
181
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
182
+ if no_boseos_middle:
183
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
184
+ else:
185
+ w = []
186
+ if len(weights[i]) == 0:
187
+ w = [1.0] * weights_length
188
+ else:
189
+ for j in range(max_embeddings_multiples):
190
+ w.append(1.0) # weight for starting token in this chunk
191
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
192
+ w.append(1.0) # weight for ending token in this chunk
193
+ w += [1.0] * (weights_length - len(w))
194
+ weights[i] = w[:]
195
+
196
+ return tokens, weights
197
+
198
+
199
+ def get_unweighted_text_embeddings(
200
+ pipe: DiffusionPipeline,
201
+ text_input: torch.Tensor,
202
+ chunk_length: int,
203
+ no_boseos_middle: Optional[bool] = True,
204
+ ):
205
+ """
206
+ When the length of tokens is a multiple of the capacity of the text encoder,
207
+ it should be split into chunks and sent to the text encoder individually.
208
+ """
209
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
210
+ if max_embeddings_multiples > 1:
211
+ text_embeddings = []
212
+ for i in range(max_embeddings_multiples):
213
+ # extract the i-th chunk
214
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
215
+
216
+ # cover the head and the tail by the starting and the ending tokens
217
+ text_input_chunk[:, 0] = text_input[0, 0]
218
+ text_input_chunk[:, -1] = text_input[0, -1]
219
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
220
+
221
+ if no_boseos_middle:
222
+ if i == 0:
223
+ # discard the ending token
224
+ text_embedding = text_embedding[:, :-1]
225
+ elif i == max_embeddings_multiples - 1:
226
+ # discard the starting token
227
+ text_embedding = text_embedding[:, 1:]
228
+ else:
229
+ # discard both starting and ending tokens
230
+ text_embedding = text_embedding[:, 1:-1]
231
+
232
+ text_embeddings.append(text_embedding)
233
+ text_embeddings = torch.concat(text_embeddings, axis=1)
234
+ else:
235
+ text_embeddings = pipe.text_encoder(text_input)[0]
236
+ return text_embeddings
237
+
238
+
239
+ def get_weighted_text_embeddings(
240
+ pipe: DiffusionPipeline,
241
+ prompt: Union[str, List[str]],
242
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
243
+ max_embeddings_multiples: Optional[int] = 3,
244
+ no_boseos_middle: Optional[bool] = False,
245
+ skip_parsing: Optional[bool] = False,
246
+ skip_weighting: Optional[bool] = False,
247
+ ):
248
+ r"""
249
+ Prompts can be assigned with local weights using brackets. For example,
250
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
251
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
252
+
253
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
254
+
255
+ Args:
256
+ pipe (`DiffusionPipeline`):
257
+ Pipe to provide access to the tokenizer and the text encoder.
258
+ prompt (`str` or `List[str]`):
259
+ The prompt or prompts to guide the image generation.
260
+ uncond_prompt (`str` or `List[str]`):
261
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
262
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
263
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
264
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
265
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
266
+ If the token length is a multiple of the text encoder's capacity, whether to keep the starting and
267
+ ending token of each chunk in the middle.
268
+ skip_parsing (`bool`, *optional*, defaults to `False`):
269
+ Skip the parsing of brackets.
270
+ skip_weighting (`bool`, *optional*, defaults to `False`):
271
+ Skip the weighting. When the parsing is skipped, it is forced True.
272
+ """
273
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
274
+ if isinstance(prompt, str):
275
+ prompt = [prompt]
276
+
277
+ if not skip_parsing:
278
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
279
+ if uncond_prompt is not None:
280
+ if isinstance(uncond_prompt, str):
281
+ uncond_prompt = [uncond_prompt]
282
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
283
+ else:
284
+ prompt_tokens = [
285
+ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
286
+ ]
287
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
288
+ if uncond_prompt is not None:
289
+ if isinstance(uncond_prompt, str):
290
+ uncond_prompt = [uncond_prompt]
291
+ uncond_tokens = [
292
+ token[1:-1]
293
+ for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
294
+ ]
295
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
296
+
297
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
298
+ max_length = max([len(token) for token in prompt_tokens])
299
+ if uncond_prompt is not None:
300
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
301
+
302
+ max_embeddings_multiples = min(
303
+ max_embeddings_multiples,
304
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
305
+ )
306
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
307
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
308
+
309
+ # pad the length of tokens and weights
310
+ bos = pipe.tokenizer.bos_token_id
311
+ eos = pipe.tokenizer.eos_token_id
312
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
313
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
314
+ prompt_tokens,
315
+ prompt_weights,
316
+ max_length,
317
+ bos,
318
+ eos,
319
+ pad,
320
+ no_boseos_middle=no_boseos_middle,
321
+ chunk_length=pipe.tokenizer.model_max_length,
322
+ )
323
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
324
+ if uncond_prompt is not None:
325
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
326
+ uncond_tokens,
327
+ uncond_weights,
328
+ max_length,
329
+ bos,
330
+ eos,
331
+ pad,
332
+ no_boseos_middle=no_boseos_middle,
333
+ chunk_length=pipe.tokenizer.model_max_length,
334
+ )
335
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
336
+
337
+ # get the embeddings
338
+ text_embeddings = get_unweighted_text_embeddings(
339
+ pipe,
340
+ prompt_tokens,
341
+ pipe.tokenizer.model_max_length,
342
+ no_boseos_middle=no_boseos_middle,
343
+ )
344
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
345
+ if uncond_prompt is not None:
346
+ uncond_embeddings = get_unweighted_text_embeddings(
347
+ pipe,
348
+ uncond_tokens,
349
+ pipe.tokenizer.model_max_length,
350
+ no_boseos_middle=no_boseos_middle,
351
+ )
352
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
353
+
354
+ # assign weights to the prompts and normalize in the sense of mean
355
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
356
+ if (not skip_parsing) and (not skip_weighting):
357
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
358
+ text_embeddings *= prompt_weights.unsqueeze(-1)
359
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
360
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
361
+ if uncond_prompt is not None:
362
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
363
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
364
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
365
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
366
+
367
+ if uncond_prompt is not None:
368
+ return text_embeddings, uncond_embeddings
369
+ return text_embeddings, None
370
+
371
+
372
+ def preprocess_image(image, batch_size):
373
+ w, h = image.size
374
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
375
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
376
+ image = np.array(image).astype(np.float32) / 255.0
377
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
378
+ image = torch.from_numpy(image)
379
+ return 2.0 * image - 1.0
380
+
381
+
382
+ def preprocess_mask(mask, batch_size, scale_factor=8):
383
+ if not isinstance(mask, torch.FloatTensor):
384
+ mask = mask.convert("L")
385
+ w, h = mask.size
386
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
387
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
388
+ mask = np.array(mask).astype(np.float32) / 255.0
389
+ mask = np.tile(mask, (4, 1, 1))
390
+ mask = np.vstack([mask[None]] * batch_size)
391
+ mask = 1 - mask # repaint white, keep black
392
+ mask = torch.from_numpy(mask)
393
+ return mask
394
+
395
+ else:
396
+ valid_mask_channel_sizes = [1, 3]
397
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
398
+ if mask.shape[3] in valid_mask_channel_sizes:
399
+ mask = mask.permute(0, 3, 1, 2)
400
+ elif mask.shape[1] not in valid_mask_channel_sizes:
401
+ raise ValueError(
402
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
403
+ f" but received mask of shape {tuple(mask.shape)}"
404
+ )
405
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
406
+ mask = mask.mean(dim=1, keepdim=True)
407
+ h, w = mask.shape[-2:]
408
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
409
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
410
+ return mask
411
+
412
+
413
+ class StableDiffusionLongPromptWeightingPipeline(
414
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
415
+ ):
416
+ r"""
417
+ Pipeline for text-to-image generation using Stable Diffusion without a token-length limit, with support for parsing
418
+ weighting in the prompt.
419
+
420
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
421
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
422
+
423
+ Args:
424
+ vae ([`AutoencoderKL`]):
425
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
426
+ text_encoder ([`CLIPTextModel`]):
427
+ Frozen text-encoder. Stable Diffusion uses the text portion of
428
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
429
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
430
+ tokenizer (`CLIPTokenizer`):
431
+ Tokenizer of class
432
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
433
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
434
+ scheduler ([`SchedulerMixin`]):
435
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
436
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
437
+ safety_checker ([`StableDiffusionSafetyChecker`]):
438
+ Classification module that estimates whether generated images could be considered offensive or harmful.
439
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
440
+ feature_extractor ([`CLIPImageProcessor`]):
441
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
442
+ """
443
+
444
+ _optional_components = ["safety_checker", "feature_extractor"]
445
+
446
+ def __init__(
447
+ self,
448
+ vae: AutoencoderKL,
449
+ text_encoder: CLIPTextModel,
450
+ tokenizer: CLIPTokenizer,
451
+ unet: UNet2DConditionModel,
452
+ scheduler: KarrasDiffusionSchedulers,
453
+ safety_checker: StableDiffusionSafetyChecker,
454
+ feature_extractor: CLIPImageProcessor,
455
+ requires_safety_checker: bool = True,
456
+ ):
457
+ super().__init__()
458
+
459
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
460
+ deprecation_message = (
461
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
462
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
463
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
464
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
465
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
466
+ " file"
467
+ )
468
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
469
+ new_config = dict(scheduler.config)
470
+ new_config["steps_offset"] = 1
471
+ scheduler._internal_dict = FrozenDict(new_config)
472
+
473
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
474
+ deprecation_message = (
475
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
476
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
477
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
478
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
479
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
480
+ )
481
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
482
+ new_config = dict(scheduler.config)
483
+ new_config["clip_sample"] = False
484
+ scheduler._internal_dict = FrozenDict(new_config)
485
+
486
+ if safety_checker is None and requires_safety_checker:
487
+ logger.warning(
488
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
489
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
490
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
491
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
492
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
493
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
494
+ )
495
+
496
+ if safety_checker is not None and feature_extractor is None:
497
+ raise ValueError(
498
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
499
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
500
+ )
501
+
502
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
503
+ version.parse(unet.config._diffusers_version).base_version
504
+ ) < version.parse("0.9.0.dev0")
505
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
506
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
507
+ deprecation_message = (
508
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
509
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
510
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
511
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
512
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
513
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
514
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
515
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
516
+ " the `unet/config.json` file"
517
+ )
518
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
519
+ new_config = dict(unet.config)
520
+ new_config["sample_size"] = 64
521
+ unet._internal_dict = FrozenDict(new_config)
522
+ self.register_modules(
523
+ vae=vae,
524
+ text_encoder=text_encoder,
525
+ tokenizer=tokenizer,
526
+ unet=unet,
527
+ scheduler=scheduler,
528
+ safety_checker=safety_checker,
529
+ feature_extractor=feature_extractor,
530
+ )
531
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
532
+
533
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
534
+ self.register_to_config(
535
+ requires_safety_checker=requires_safety_checker,
536
+ )
537
+
538
+ def enable_vae_slicing(self):
539
+ r"""
540
+ Enable sliced VAE decoding.
541
+
542
+ When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several
543
+ steps. This is useful to save some memory and allow larger batch sizes.
544
+ """
545
+ self.vae.enable_slicing()
546
+
547
+ def disable_vae_slicing(self):
548
+ r"""
549
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
550
+ computing decoding in one step.
551
+ """
552
+ self.vae.disable_slicing()
553
+
554
+ def enable_vae_tiling(self):
555
+ r"""
556
+ Enable tiled VAE decoding.
557
+
558
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
559
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
560
+ """
561
+ self.vae.enable_tiling()
562
+
563
+ def disable_vae_tiling(self):
564
+ r"""
565
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
566
+ computing decoding in one step.
567
+ """
568
+ self.vae.disable_tiling()
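+ # Illustrative sketch (an assumption, `pipe` denotes a constructed pipeline instance): the two VAE
+ # memory helpers can be combined when decoding large batches or large resolutions, e.g.
+ #   pipe.enable_vae_slicing()   # decode the batch one image at a time
+ #   pipe.enable_vae_tiling()    # decode/encode large images tile by tile
+ #   images = pipe.text2img("a mountain landscape", width=1024, height=1024,
+ #                          num_images_per_prompt=4).images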
569
+
570
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
571
+ def enable_sequential_cpu_offload(self, gpu_id=0):
572
+ r"""
573
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
574
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
575
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
576
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
577
+ `enable_model_cpu_offload`, but performance is lower.
578
+ """
579
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
580
+ from accelerate import cpu_offload
581
+ else:
582
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
583
+
584
+ device = torch.device(f"cuda:{gpu_id}")
585
+
586
+ if self.device.type != "cpu":
587
+ self.to("cpu", silence_dtype_warnings=True)
588
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
589
+
590
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
591
+ cpu_offload(cpu_offloaded_model, device)
592
+
593
+ if self.safety_checker is not None:
594
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
595
+
596
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
597
+ def enable_model_cpu_offload(self, gpu_id=0):
598
+ r"""
599
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
600
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
601
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
602
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
603
+ """
604
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
605
+ from accelerate import cpu_offload_with_hook
606
+ else:
607
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
608
+
609
+ device = torch.device(f"cuda:{gpu_id}")
610
+
611
+ if self.device.type != "cpu":
612
+ self.to("cpu", silence_dtype_warnings=True)
613
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
614
+
615
+ hook = None
616
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
617
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
618
+
619
+ if self.safety_checker is not None:
620
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
621
+
622
+ # We'll offload the last model manually.
623
+ self.final_offload_hook = hook
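+ # Illustrative note (an assumption): only one offload strategy would be enabled per pipeline instance;
+ # `enable_model_cpu_offload()` keeps whole models on the GPU while they run (small slowdown), whereas
+ # `enable_sequential_cpu_offload()` offloads at submodule level for larger savings but slower inference:
+ #   pipe.enable_model_cpu_offload()            # moderate memory savings
+ #   # pipe.enable_sequential_cpu_offload()     # maximum memory savings, much slower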
624
+
625
+ @property
626
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
627
+ def _execution_device(self):
628
+ r"""
629
+ Returns the device on which the pipeline's models will be executed. After calling
630
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
631
+ hooks.
632
+ """
633
+ if not hasattr(self.unet, "_hf_hook"):
634
+ return self.device
635
+ for module in self.unet.modules():
636
+ if (
637
+ hasattr(module, "_hf_hook")
638
+ and hasattr(module._hf_hook, "execution_device")
639
+ and module._hf_hook.execution_device is not None
640
+ ):
641
+ return torch.device(module._hf_hook.execution_device)
642
+ return self.device
643
+
644
+ def encode_prompt(
645
+ self,
646
+ prompt,
647
+ device,
648
+ num_images_per_prompt,
649
+ do_classifier_free_guidance,
650
+ negative_prompt=None,
651
+ max_embeddings_multiples=3,
652
+ prompt_embeds: Optional[torch.FloatTensor] = None,
653
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
654
+ ):
655
+ r"""
656
+ Encodes the prompt into text encoder hidden states.
657
+
658
+ Args:
659
+ prompt (`str` or `List[str]`):
660
+ prompt to be encoded
661
+ device (`torch.device`):
662
+ torch device
663
+ num_images_per_prompt (`int`):
664
+ number of images that should be generated per prompt
665
+ do_classifier_free_guidance (`bool`):
666
+ whether to use classifier free guidance or not
667
+ negative_prompt (`str` or `List[str]`):
668
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
669
+ if `guidance_scale` is less than `1`).
670
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
671
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
672
+ """
673
+ if prompt is not None and isinstance(prompt, str):
674
+ batch_size = 1
675
+ elif prompt is not None and isinstance(prompt, list):
676
+ batch_size = len(prompt)
677
+ else:
678
+ batch_size = prompt_embeds.shape[0]
679
+
680
+ if negative_prompt_embeds is None:
681
+ if negative_prompt is None:
682
+ negative_prompt = [""] * batch_size
683
+ elif isinstance(negative_prompt, str):
684
+ negative_prompt = [negative_prompt] * batch_size
685
+ if batch_size != len(negative_prompt):
686
+ raise ValueError(
687
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
688
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
689
+ " the batch size of `prompt`."
690
+ )
691
+ if prompt_embeds is None or negative_prompt_embeds is None:
692
+ if isinstance(self, TextualInversionLoaderMixin):
693
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
694
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
695
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
696
+
697
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
698
+ pipe=self,
699
+ prompt=prompt,
700
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
701
+ max_embeddings_multiples=max_embeddings_multiples,
702
+ )
703
+ if prompt_embeds is None:
704
+ prompt_embeds = prompt_embeds1
705
+ if negative_prompt_embeds is None:
706
+ negative_prompt_embeds = negative_prompt_embeds1
707
+
708
+ bs_embed, seq_len, _ = prompt_embeds.shape
709
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
710
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
711
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
712
+
713
+ if do_classifier_free_guidance:
714
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
715
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
716
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
717
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
718
+
719
+ return prompt_embeds
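+ # Illustrative sketch (an assumption): because prompts are routed through the weighted-embedding helper,
+ # emphasis syntax such as "(word:1.2)" is expected to scale the matching token embeddings, and prompts
+ # longer than the CLIP window are split into up to `max_embeddings_multiples` chunks, e.g.
+ #   embeds = pipe.encode_prompt("a (masterpiece:1.2) portrait", device, 1, True,
+ #                               negative_prompt="lowres, blurry")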
720
+
721
+ def check_inputs(
722
+ self,
723
+ prompt,
724
+ height,
725
+ width,
726
+ strength,
727
+ callback_steps,
728
+ negative_prompt=None,
729
+ prompt_embeds=None,
730
+ negative_prompt_embeds=None,
731
+ ):
732
+ if height % 8 != 0 or width % 8 != 0:
733
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
734
+
735
+ if strength < 0 or strength > 1:
736
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
737
+
738
+ if (callback_steps is None) or (
739
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
740
+ ):
741
+ raise ValueError(
742
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
743
+ f" {type(callback_steps)}."
744
+ )
745
+
746
+ if prompt is not None and prompt_embeds is not None:
747
+ raise ValueError(
748
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
749
+ " only forward one of the two."
750
+ )
751
+ elif prompt is None and prompt_embeds is None:
752
+ raise ValueError(
753
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
754
+ )
755
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
756
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
757
+
758
+ if negative_prompt is not None and negative_prompt_embeds is not None:
759
+ raise ValueError(
760
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
761
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
762
+ )
763
+
764
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
765
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
766
+ raise ValueError(
767
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
768
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
769
+ f" {negative_prompt_embeds.shape}."
770
+ )
771
+
772
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
773
+ if is_text2img:
774
+ return self.scheduler.timesteps.to(device), num_inference_steps
775
+ else:
776
+ # get the original timestep using init_timestep
777
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
778
+
779
+ t_start = max(num_inference_steps - init_timestep, 0)
780
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
781
+
782
+ return timesteps, num_inference_steps - t_start
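+ # Worked example (illustrative, assuming a scheduler order of 1): for img2img with
+ # num_inference_steps=50 and strength=0.8, init_timestep = min(int(50 * 0.8), 50) = 40 and
+ # t_start = 50 - 40 = 10, so the final 40 timesteps of the schedule are used; for text2img the
+ # full schedule is returned unchanged.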
783
+
784
+ def run_safety_checker(self, image, device, dtype):
785
+ if self.safety_checker is not None:
786
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
787
+ image, has_nsfw_concept = self.safety_checker(
788
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
789
+ )
790
+ else:
791
+ has_nsfw_concept = None
792
+ return image, has_nsfw_concept
793
+
794
+ def decode_latents(self, latents):
795
+ latents = 1 / self.vae.config.scaling_factor * latents
796
+ image = self.vae.decode(latents).sample
797
+ image = (image / 2 + 0.5).clamp(0, 1)
798
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
799
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
800
+ return image
801
+
802
+ def prepare_extra_step_kwargs(self, generator, eta):
803
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
804
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
805
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
806
+ # and should be between [0, 1]
807
+
808
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
809
+ extra_step_kwargs = {}
810
+ if accepts_eta:
811
+ extra_step_kwargs["eta"] = eta
812
+
813
+ # check if the scheduler accepts generator
814
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
815
+ if accepts_generator:
816
+ extra_step_kwargs["generator"] = generator
817
+ return extra_step_kwargs
818
+
819
+ def prepare_latents(
820
+ self,
821
+ image,
822
+ timestep,
823
+ num_images_per_prompt,
824
+ batch_size,
825
+ num_channels_latents,
826
+ height,
827
+ width,
828
+ dtype,
829
+ device,
830
+ generator,
831
+ latents=None,
832
+ ):
833
+ if image is None:
834
+ batch_size = batch_size * num_images_per_prompt
835
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
836
+ if isinstance(generator, list) and len(generator) != batch_size:
837
+ raise ValueError(
838
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
839
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
840
+ )
841
+
842
+ if latents is None:
843
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
844
+ else:
845
+ latents = latents.to(device)
846
+
847
+ # scale the initial noise by the standard deviation required by the scheduler
848
+ latents = latents * self.scheduler.init_noise_sigma
849
+ return latents, None, None
850
+ else:
851
+ image = image.to(device=self.device, dtype=dtype)
852
+ init_latent_dist = self.vae.encode(image).latent_dist
853
+ init_latents = init_latent_dist.sample(generator=generator)
854
+ init_latents = self.vae.config.scaling_factor * init_latents
855
+
856
+ # Expand init_latents for batch_size and num_images_per_prompt
857
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
858
+ init_latents_orig = init_latents
859
+
860
+ # add noise to latents using the timesteps
861
+ noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
862
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
863
+ latents = init_latents
864
+ return latents, init_latents_orig, noise
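+ # Descriptive note: with `image=None` this returns pure Gaussian noise scaled by
+ # `scheduler.init_noise_sigma` (text2img); otherwise the image is VAE-encoded, scaled by the VAE
+ # scaling factor and noised to `timestep`, and the clean latents plus the noise are returned as well
+ # so the denoising loop can re-blend the unmasked region when a mask is given.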
865
+
866
+ @torch.no_grad()
867
+ def __call__(
868
+ self,
869
+ prompt: Union[str, List[str]],
870
+ negative_prompt: Optional[Union[str, List[str]]] = None,
871
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
872
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
873
+ height: int = 512,
874
+ width: int = 512,
875
+ num_inference_steps: int = 50,
876
+ guidance_scale: float = 7.5,
877
+ strength: float = 0.8,
878
+ num_images_per_prompt: Optional[int] = 1,
879
+ add_predicted_noise: Optional[bool] = False,
880
+ eta: float = 0.0,
881
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
882
+ latents: Optional[torch.FloatTensor] = None,
883
+ prompt_embeds: Optional[torch.FloatTensor] = None,
884
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
885
+ max_embeddings_multiples: Optional[int] = 3,
886
+ output_type: Optional[str] = "pil",
887
+ return_dict: bool = True,
888
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
889
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
890
+ callback_steps: int = 1,
891
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
892
+ ):
893
+ r"""
894
+ Function invoked when calling the pipeline for generation.
895
+
896
+ Args:
897
+ prompt (`str` or `List[str]`):
898
+ The prompt or prompts to guide the image generation.
899
+ negative_prompt (`str` or `List[str]`, *optional*):
900
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
901
+ if `guidance_scale` is less than `1`).
902
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
903
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
904
+ process.
905
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
906
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
907
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
908
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
909
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
910
+ height (`int`, *optional*, defaults to 512):
911
+ The height in pixels of the generated image.
912
+ width (`int`, *optional*, defaults to 512):
913
+ The width in pixels of the generated image.
914
+ num_inference_steps (`int`, *optional*, defaults to 50):
915
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
916
+ expense of slower inference.
917
+ guidance_scale (`float`, *optional*, defaults to 7.5):
918
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
919
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
920
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
921
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
922
+ usually at the expense of lower image quality.
923
+ strength (`float`, *optional*, defaults to 0.8):
924
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
925
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
926
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
927
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
928
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
929
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
930
+ The number of images to generate per prompt.
931
+ add_predicted_noise (`bool`, *optional*, defaults to False):
932
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
933
+ the reverse diffusion process
934
+ eta (`float`, *optional*, defaults to 0.0):
935
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
936
+ [`schedulers.DDIMScheduler`], will be ignored for others.
937
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
938
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
939
+ to make generation deterministic.
940
+ latents (`torch.FloatTensor`, *optional*):
941
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
942
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
943
+ tensor will be generated by sampling using the supplied random `generator`.
944
+ prompt_embeds (`torch.FloatTensor`, *optional*):
945
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
946
+ provided, text embeddings will be generated from `prompt` input argument.
947
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
948
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
949
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
950
+ argument.
951
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
952
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
953
+ output_type (`str`, *optional*, defaults to `"pil"`):
954
+ The output format of the generated image. Choose between
955
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
956
+ return_dict (`bool`, *optional*, defaults to `True`):
957
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
958
+ plain tuple.
959
+ callback (`Callable`, *optional*):
960
+ A function that will be called every `callback_steps` steps during inference. The function will be
961
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
962
+ is_cancelled_callback (`Callable`, *optional*):
963
+ A function that will be called every `callback_steps` steps during inference. If the function returns
964
+ `True`, the inference will be cancelled.
965
+ callback_steps (`int`, *optional*, defaults to 1):
966
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
967
+ called at every step.
968
+ cross_attention_kwargs (`dict`, *optional*):
969
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
970
+ `self.processor` in
971
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
972
+
973
+ Returns:
974
+ `None` if cancelled by `is_cancelled_callback`,
975
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
976
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
977
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
978
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
979
+ (nsfw) content, according to the `safety_checker`.
980
+ """
981
+ # 0. Default height and width to unet
982
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
983
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
984
+
985
+ # 1. Check inputs. Raise error if not correct
986
+ self.check_inputs(
987
+ prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
988
+ )
989
+
990
+ # 2. Define call parameters
991
+ if prompt is not None and isinstance(prompt, str):
992
+ batch_size = 1
993
+ elif prompt is not None and isinstance(prompt, list):
994
+ batch_size = len(prompt)
995
+ else:
996
+ batch_size = prompt_embeds.shape[0]
997
+
998
+ device = self._execution_device
999
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1000
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1001
+ # corresponds to doing no classifier free guidance.
1002
+ do_classifier_free_guidance = guidance_scale > 1.0
1003
+
1004
+ # 3. Encode input prompt
1005
+ prompt_embeds = self.encode_prompt(
1006
+ prompt,
1007
+ device,
1008
+ num_images_per_prompt,
1009
+ do_classifier_free_guidance,
1010
+ negative_prompt,
1011
+ max_embeddings_multiples,
1012
+ prompt_embeds=prompt_embeds,
1013
+ negative_prompt_embeds=negative_prompt_embeds,
1014
+ )
1015
+ dtype = prompt_embeds.dtype
1016
+
1017
+ # 4. Preprocess image and mask
1018
+ if isinstance(image, PIL.Image.Image):
1019
+ image = preprocess_image(image, batch_size)
1020
+ if image is not None:
1021
+ image = image.to(device=self.device, dtype=dtype)
1022
+ if isinstance(mask_image, PIL.Image.Image):
1023
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
1024
+ if mask_image is not None:
1025
+ mask = mask_image.to(device=self.device, dtype=dtype)
1026
+ mask = torch.cat([mask] * num_images_per_prompt)
1027
+ else:
1028
+ mask = None
1029
+
1030
+ # 5. set timesteps
1031
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1032
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
1033
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1034
+
1035
+ # 6. Prepare latent variables
1036
+ latents, init_latents_orig, noise = self.prepare_latents(
1037
+ image,
1038
+ latent_timestep,
1039
+ num_images_per_prompt,
1040
+ batch_size,
1041
+ self.unet.config.in_channels,
1042
+ height,
1043
+ width,
1044
+ dtype,
1045
+ device,
1046
+ generator,
1047
+ latents,
1048
+ )
1049
+
1050
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1051
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1052
+
1053
+ # 8. Denoising loop
1054
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1055
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1056
+ for i, t in enumerate(timesteps):
1057
+ # expand the latents if we are doing classifier free guidance
1058
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1059
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1060
+
1061
+ # predict the noise residual
1062
+ noise_pred = self.unet(
1063
+ latent_model_input,
1064
+ t,
1065
+ encoder_hidden_states=prompt_embeds,
1066
+ cross_attention_kwargs=cross_attention_kwargs,
1067
+ ).sample
1068
+
1069
+ # perform guidance
1070
+ if do_classifier_free_guidance:
1071
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1072
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1073
+
1074
+ # compute the previous noisy sample x_t -> x_t-1
1075
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1076
+
1077
+ if mask is not None:
1078
+ # masking
1079
+ if add_predicted_noise:
1080
+ init_latents_proper = self.scheduler.add_noise(
1081
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
1082
+ )
1083
+ else:
1084
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1085
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
1086
+
1087
+ # call the callback, if provided
1088
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1089
+ progress_bar.update()
1090
+ if i % callback_steps == 0:
1091
+ if callback is not None:
1092
+ step_idx = i // getattr(self.scheduler, "order", 1)
1093
+ callback(step_idx, t, latents)
1094
+ if is_cancelled_callback is not None and is_cancelled_callback():
1095
+ return None
1096
+
1097
+ if output_type == "latent":
1098
+ image = latents
1099
+ has_nsfw_concept = None
1100
+ elif output_type == "pil":
1101
+ # 9. Post-processing
1102
+ image = self.decode_latents(latents)
1103
+
1104
+ # 10. Run safety checker
1105
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1106
+
1107
+ # 11. Convert to PIL
1108
+ image = self.numpy_to_pil(image)
1109
+ else:
1110
+ # 9. Post-processing
1111
+ image = self.decode_latents(latents)
1112
+
1113
+ # 10. Run safety checker
1114
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1115
+
1116
+ # Offload last model to CPU
1117
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1118
+ self.final_offload_hook.offload()
1119
+
1120
+ if not return_dict:
1121
+ return image, has_nsfw_concept
1122
+
1123
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
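+ # Descriptive note: the single `__call__` covers all three modes; `image=None` runs text-to-image,
+ # `image` without `mask_image` runs image-to-image, and `image` together with `mask_image` runs
+ # inpainting. The `text2img`, `img2img` and `inpaint` helpers below simply forward to it.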
1124
+
1125
+ def text2img(
1126
+ self,
1127
+ prompt: Union[str, List[str]],
1128
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1129
+ height: int = 512,
1130
+ width: int = 512,
1131
+ num_inference_steps: int = 50,
1132
+ guidance_scale: float = 7.5,
1133
+ num_images_per_prompt: Optional[int] = 1,
1134
+ eta: float = 0.0,
1135
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1136
+ latents: Optional[torch.FloatTensor] = None,
1137
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1138
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1139
+ max_embeddings_multiples: Optional[int] = 3,
1140
+ output_type: Optional[str] = "pil",
1141
+ return_dict: bool = True,
1142
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1143
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1144
+ callback_steps: int = 1,
1145
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1146
+ ):
1147
+ r"""
1148
+ Function for text-to-image generation.
1149
+ Args:
1150
+ prompt (`str` or `List[str]`):
1151
+ The prompt or prompts to guide the image generation.
1152
+ negative_prompt (`str` or `List[str]`, *optional*):
1153
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1154
+ if `guidance_scale` is less than `1`).
1155
+ height (`int`, *optional*, defaults to 512):
1156
+ The height in pixels of the generated image.
1157
+ width (`int`, *optional*, defaults to 512):
1158
+ The width in pixels of the generated image.
1159
+ num_inference_steps (`int`, *optional*, defaults to 50):
1160
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1161
+ expense of slower inference.
1162
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1163
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1164
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1165
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1166
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1167
+ usually at the expense of lower image quality.
1168
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1169
+ The number of images to generate per prompt.
1170
+ eta (`float`, *optional*, defaults to 0.0):
1171
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1172
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1173
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1174
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1175
+ to make generation deterministic.
1176
+ latents (`torch.FloatTensor`, *optional*):
1177
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1178
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1179
+ tensor will be generated by sampling using the supplied random `generator`.
1180
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1181
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1182
+ provided, text embeddings will be generated from `prompt` input argument.
1183
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1184
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1185
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1186
+ argument.
1187
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1188
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1189
+ output_type (`str`, *optional*, defaults to `"pil"`):
1190
+ The output format of the generated image. Choose between
1191
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1192
+ return_dict (`bool`, *optional*, defaults to `True`):
1193
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1194
+ plain tuple.
1195
+ callback (`Callable`, *optional*):
1196
+ A function that will be called every `callback_steps` steps during inference. The function will be
1197
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1198
+ is_cancelled_callback (`Callable`, *optional*):
1199
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1200
+ `True`, the inference will be cancelled.
1201
+ callback_steps (`int`, *optional*, defaults to 1):
1202
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1203
+ called at every step.
1204
+ cross_attention_kwargs (`dict`, *optional*):
1205
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1206
+ `self.processor` in
1207
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1208
+
1209
+ Returns:
1210
+ `None` if cancelled by `is_cancelled_callback`,
1211
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1212
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1213
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1214
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1215
+ (nsfw) content, according to the `safety_checker`.
1216
+ """
1217
+ return self.__call__(
1218
+ prompt=prompt,
1219
+ negative_prompt=negative_prompt,
1220
+ height=height,
1221
+ width=width,
1222
+ num_inference_steps=num_inference_steps,
1223
+ guidance_scale=guidance_scale,
1224
+ num_images_per_prompt=num_images_per_prompt,
1225
+ eta=eta,
1226
+ generator=generator,
1227
+ latents=latents,
1228
+ prompt_embeds=prompt_embeds,
1229
+ negative_prompt_embeds=negative_prompt_embeds,
1230
+ max_embeddings_multiples=max_embeddings_multiples,
1231
+ output_type=output_type,
1232
+ return_dict=return_dict,
1233
+ callback=callback,
1234
+ is_cancelled_callback=is_cancelled_callback,
1235
+ callback_steps=callback_steps,
1236
+ cross_attention_kwargs=cross_attention_kwargs,
1237
+ )
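+ # Illustrative usage sketch (an assumption, `pipe` is a constructed pipeline instance):
+ #   image = pipe.text2img(
+ #       "a (photorealistic:1.2) portrait of an astronaut, best quality",
+ #       negative_prompt="lowres, bad anatomy",
+ #       width=512, height=512, num_inference_steps=30, guidance_scale=7.5,
+ #   ).images[0]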
1238
+
1239
+ def img2img(
1240
+ self,
1241
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1242
+ prompt: Union[str, List[str]],
1243
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1244
+ strength: float = 0.8,
1245
+ num_inference_steps: Optional[int] = 50,
1246
+ guidance_scale: Optional[float] = 7.5,
1247
+ num_images_per_prompt: Optional[int] = 1,
1248
+ eta: Optional[float] = 0.0,
1249
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1250
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1251
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1252
+ max_embeddings_multiples: Optional[int] = 3,
1253
+ output_type: Optional[str] = "pil",
1254
+ return_dict: bool = True,
1255
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1256
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1257
+ callback_steps: int = 1,
1258
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1259
+ ):
1260
+ r"""
1261
+ Function for image-to-image generation.
1262
+ Args:
1263
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1264
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1265
+ process.
1266
+ prompt (`str` or `List[str]`):
1267
+ The prompt or prompts to guide the image generation.
1268
+ negative_prompt (`str` or `List[str]`, *optional*):
1269
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1270
+ if `guidance_scale` is less than `1`).
1271
+ strength (`float`, *optional*, defaults to 0.8):
1272
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1273
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1274
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1275
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1276
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1277
+ num_inference_steps (`int`, *optional*, defaults to 50):
1278
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1279
+ expense of slower inference. This parameter will be modulated by `strength`.
1280
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1281
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1282
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1283
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1284
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1285
+ usually at the expense of lower image quality.
1286
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1287
+ The number of images to generate per prompt.
1288
+ eta (`float`, *optional*, defaults to 0.0):
1289
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1290
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1291
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1292
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1293
+ to make generation deterministic.
1294
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1295
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1296
+ provided, text embeddings will be generated from `prompt` input argument.
1297
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1298
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1299
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1300
+ argument.
1301
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1302
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1303
+ output_type (`str`, *optional*, defaults to `"pil"`):
1304
+ The output format of the generated image. Choose between
1305
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1306
+ return_dict (`bool`, *optional*, defaults to `True`):
1307
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1308
+ plain tuple.
1309
+ callback (`Callable`, *optional*):
1310
+ A function that will be called every `callback_steps` steps during inference. The function will be
1311
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1312
+ is_cancelled_callback (`Callable`, *optional*):
1313
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1314
+ `True`, the inference will be cancelled.
1315
+ callback_steps (`int`, *optional*, defaults to 1):
1316
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1317
+ called at every step.
1318
+ cross_attention_kwargs (`dict`, *optional*):
1319
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1320
+ `self.processor` in
1321
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1322
+
1323
+ Returns:
1324
+ `None` if cancelled by `is_cancelled_callback`,
1325
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1326
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1327
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1328
+ (nsfw) content, according to the `safety_checker`.
1329
+ """
1330
+ return self.__call__(
1331
+ prompt=prompt,
1332
+ negative_prompt=negative_prompt,
1333
+ image=image,
1334
+ num_inference_steps=num_inference_steps,
1335
+ guidance_scale=guidance_scale,
1336
+ strength=strength,
1337
+ num_images_per_prompt=num_images_per_prompt,
1338
+ eta=eta,
1339
+ generator=generator,
1340
+ prompt_embeds=prompt_embeds,
1341
+ negative_prompt_embeds=negative_prompt_embeds,
1342
+ max_embeddings_multiples=max_embeddings_multiples,
1343
+ output_type=output_type,
1344
+ return_dict=return_dict,
1345
+ callback=callback,
1346
+ is_cancelled_callback=is_cancelled_callback,
1347
+ callback_steps=callback_steps,
1348
+ cross_attention_kwargs=cross_attention_kwargs,
1349
+ )
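+ # Illustrative usage sketch (an assumption, `init_image` is a PIL.Image.Image already resized to the
+ # working resolution):
+ #   image = pipe.img2img(
+ #       image=init_image,
+ #       prompt="a watercolor painting of the same scene",
+ #       strength=0.6, num_inference_steps=50,
+ #   ).images[0]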
1350
+
1351
+ def inpaint(
1352
+ self,
1353
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1354
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1355
+ prompt: Union[str, List[str]],
1356
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1357
+ strength: float = 0.8,
1358
+ num_inference_steps: Optional[int] = 50,
1359
+ guidance_scale: Optional[float] = 7.5,
1360
+ num_images_per_prompt: Optional[int] = 1,
1361
+ add_predicted_noise: Optional[bool] = False,
1362
+ eta: Optional[float] = 0.0,
1363
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1364
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1365
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1366
+ max_embeddings_multiples: Optional[int] = 3,
1367
+ output_type: Optional[str] = "pil",
1368
+ return_dict: bool = True,
1369
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1370
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1371
+ callback_steps: int = 1,
1372
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1373
+ ):
1374
+ r"""
1375
+ Function for inpaint.
1376
+ Args:
1377
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1378
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1379
+ process. This is the image whose masked region will be inpainted.
1380
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1381
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1382
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1383
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1384
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1385
+ prompt (`str` or `List[str]`):
1386
+ The prompt or prompts to guide the image generation.
1387
+ negative_prompt (`str` or `List[str]`, *optional*):
1388
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1389
+ if `guidance_scale` is less than `1`).
1390
+ strength (`float`, *optional*, defaults to 0.8):
1391
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1392
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1393
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1394
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1395
+ num_inference_steps (`int`, *optional*, defaults to 50):
1396
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1397
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1398
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1399
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1400
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1401
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1402
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1403
+ usually at the expense of lower image quality.
1404
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1405
+ The number of images to generate per prompt.
1406
+ add_predicted_noise (`bool`, *optional*, defaults to False):
1407
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
1408
+ the reverse diffusion process
1409
+ eta (`float`, *optional*, defaults to 0.0):
1410
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1411
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1412
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1413
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1414
+ to make generation deterministic.
1415
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1416
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1417
+ provided, text embeddings will be generated from `prompt` input argument.
1418
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1419
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1420
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1421
+ argument.
1422
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1423
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1424
+ output_type (`str`, *optional*, defaults to `"pil"`):
1425
+ The output format of the generated image. Choose between
1426
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1427
+ return_dict (`bool`, *optional*, defaults to `True`):
1428
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1429
+ plain tuple.
1430
+ callback (`Callable`, *optional*):
1431
+ A function that will be called every `callback_steps` steps during inference. The function will be
1432
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1433
+ is_cancelled_callback (`Callable`, *optional*):
1434
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1435
+ `True`, the inference will be cancelled.
1436
+ callback_steps (`int`, *optional*, defaults to 1):
1437
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1438
+ called at every step.
1439
+ cross_attention_kwargs (`dict`, *optional*):
1440
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1441
+ `self.processor` in
1442
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1443
+
1444
+ Returns:
1445
+ `None` if cancelled by `is_cancelled_callback`,
1446
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1447
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1448
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1449
+ (nsfw) content, according to the `safety_checker`.
1450
+ """
1451
+ return self.__call__(
1452
+ prompt=prompt,
1453
+ negative_prompt=negative_prompt,
1454
+ image=image,
1455
+ mask_image=mask_image,
1456
+ num_inference_steps=num_inference_steps,
1457
+ guidance_scale=guidance_scale,
1458
+ strength=strength,
1459
+ num_images_per_prompt=num_images_per_prompt,
1460
+ add_predicted_noise=add_predicted_noise,
1461
+ eta=eta,
1462
+ generator=generator,
1463
+ prompt_embeds=prompt_embeds,
1464
+ negative_prompt_embeds=negative_prompt_embeds,
1465
+ max_embeddings_multiples=max_embeddings_multiples,
1466
+ output_type=output_type,
1467
+ return_dict=return_dict,
1468
+ callback=callback,
1469
+ is_cancelled_callback=is_cancelled_callback,
1470
+ callback_steps=callback_steps,
1471
+ cross_attention_kwargs=cross_attention_kwargs,
1472
+ )
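+ # Illustrative usage sketch (an assumption, `init_image` and `mask` are PIL images of the same size;
+ # white mask pixels are repainted):
+ #   image = pipe.inpaint(
+ #       image=init_image, mask_image=mask,
+ #       prompt="a red brick wall", strength=0.75,
+ #   ).images[0]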