# myMagicClothing / garment_adapter / garment_ipadapter_faceid.py
# Author: yzy0713 — commit 6a05036 ("Add files")
import os
import pdb
from typing import List
import numpy as np
import torch
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from garment_seg.process import load_seg_model, generate_mask
from utils.utils import is_torch2_available, prepare_image, prepare_mask
import copy
from utils.resampler import PerceiverAttention, FeedForward
from insightface.utils import face_align
from insightface.app import FaceAnalysis
import cv2
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
# Pick the attention processor implementations once, at import time.
# The plain processors are used when PyTorch 2.x is unavailable or when the
# flag above forces them; otherwise the *2_0 variants are aliased to the same
# names so the rest of the module does not care which set is active.
if not is_torch2_available() or USE_DAFAULT_ATTN:
    from .attention_processor import AttnProcessor, IPAttnProcessor, REFAttnProcessor
else:
    from .attention_processor import AttnProcessor2_0 as AttnProcessor
    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
    from .attention_processor import REFAttnProcessor2_0 as REFAttnProcessor
class FacePerceiverResampler(torch.nn.Module):
    """Refine a set of latent tokens against a feature sequence.

    The features ``x`` are projected from ``embedding_dim`` to ``dim``. Each of the
    ``depth`` layers then applies a residual ``PerceiverAttention(x, latents)``
    followed by a residual ``FeedForward``. The result is projected to
    ``output_dim`` and layer-normalised.
    """

    def __init__(
        self,
        *,
        dim=768,
        depth=4,
        dim_head=64,
        heads=16,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
    ):
        super().__init__()
        # Attribute names and creation order are kept stable: state_dict keys
        # (proj_in / proj_out / norm_out / layers.<i>.<j>) are derived from them.
        self.proj_in = torch.nn.Linear(embedding_dim, dim)
        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)
        self.layers = torch.nn.ModuleList(
            torch.nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, latents, x):
        """Return the refined latents, with last dimension ``output_dim``.

        Args:
            latents: latent tokens whose last dimension is ``dim``.
            x: features whose last dimension is ``embedding_dim``.
        """
        context = self.proj_in(x)
        hidden = latents
        for cross_attn, feed_forward in self.layers:
            hidden = cross_attn(context, hidden) + hidden
            hidden = feed_forward(hidden) + hidden
        return self.norm_out(self.proj_out(hidden))
class MLPProjModel(torch.nn.Module):
    """Project a face-ID embedding into ``num_tokens`` cross-attention tokens.

    A two-layer MLP (GELU in between, hidden width ``2 * id_embeddings_dim``) maps
    the input to ``num_tokens * cross_attention_dim`` values. These are reshaped to
    ``(-1, num_tokens, cross_attention_dim)`` and layer-normalised.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        hidden_dim = id_embeddings_dim * 2
        # Keep the Sequential layout (proj.0 / proj.2) so checkpoint keys match.
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Return normalised tokens shaped ``(-1, num_tokens, cross_attention_dim)``."""
        flat = self.proj(id_embeds)
        tokens = flat.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class ProjPlusModel(torch.nn.Module):
    """Combine a face-ID embedding with CLIP image features into prompt tokens.

    The ID embedding goes through the same MLP + LayerNorm projection as
    ``MLPProjModel``. The resulting tokens are refined by a ``FacePerceiverResampler``
    that attends over ``clip_embeds``. With ``shortcut=True`` the refined tokens are
    added to the projected ones, scaled by ``scale``.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        hidden_dim = id_embeddings_dim * 2
        # Submodule names/order (proj, norm, perceiver_resampler) define checkpoint keys.
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)
        self.perceiver_resampler = FacePerceiverResampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=cross_attention_dim // 64,
            embedding_dim=clip_embeddings_dim,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )

    def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
        """Return tokens shaped ``(-1, num_tokens, cross_attention_dim)``.

        Args:
            id_embeds: face-ID embeddings (last dimension ``id_embeddings_dim``).
            clip_embeds: CLIP features (last dimension ``clip_embeddings_dim``).
            shortcut: if true, return ``projected + scale * refined``.
            scale: weight on the resampler output when ``shortcut`` is set.
        """
        projected = self.norm(
            self.proj(id_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        )
        refined = self.perceiver_resampler(projected, clip_embeds)
        return projected + scale * refined if shortcut else refined
class IPAdapterFaceID:
    """IP-Adapter FaceID wrapper with garment conditioning through a reference UNet.

    Setup done by this class:
    - The pipeline UNet's self-attention ("attn1") processors become
      ``REFAttnProcessor(type="write")``; its other processors become
      ``IPAttnProcessor``.
    - A deep copy of the UNet, loaded from ``ref_path``, becomes ``self.ref_unet``.
      Its "attn1" processors become ``REFAttnProcessor(type="read")``.
    - Both UNets receive ``self.attn_store`` through ``cross_attention_kwargs``
      in ``generate``.
    """

    def __init__(self, sd_pipe, ref_path, ip_ckpt, device, enable_cloth_guidance, num_tokens=4, n_cond=1, torch_dtype=torch.float16, set_seg_model=True):
        """Build the adapter around ``sd_pipe``.

        Args:
            sd_pipe: diffusers pipeline. It is moved to ``device`` and its UNet's
                attention processors are replaced.
            ref_path: safetensors file, loaded with ``strict=False`` into a deep copy
                of ``sd_pipe.unet`` (the reference UNet).
            ip_ckpt: IP-Adapter checkpoint, either ``.safetensors`` or a
                ``torch.load``-able file, with ``image_proj`` and ``ip_adapter``
                sections.
            device: device for every model built here.
            enable_cloth_guidance: forwarded to the attention processors via
                ``cross_attention_kwargs``. When true, ``generate`` also passes
                ``cloth_guidance_scale`` to the pipeline.
            num_tokens: tokens produced per face embedding.
            n_cond: multiplier on ``num_tokens`` for the IP attention processors
                (presumably the number of faces per sample — TODO confirm).
            torch_dtype: dtype of the projection model and IP attention processors.
            set_seg_model: load the garment segmentation net used when ``generate``
                gets no mask.
        """
        self.enable_cloth_guidance = enable_cloth_guidance
        self.device = device
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        self.n_cond = n_cond
        self.torch_dtype = torch_dtype
        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()
        # image proj model
        self.image_proj_model = self.init_proj()
        self.load_ip_adapter()
        self.set_insightface()
        ref_unet = copy.deepcopy(sd_pipe.unet)
        state_dict = {}
        with safe_open(ref_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        ref_unet.load_state_dict(state_dict, strict=False)
        self.ref_unet = ref_unet.to(self.device)
        self.set_ref_adapter()
        if set_seg_model:
            self.set_seg_model()
        self.attn_store = {}

    def set_insightface(self):
        """Create the insightface ``buffalo_l`` analyser (CUDA with CPU fallback)."""
        self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def set_seg_model(self, checkpoint_path='checkpoints/cloth_segm.pth'):
        """Load the garment segmentation network into ``self.seg_net``."""
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def init_proj(self):
        """Create the face-ID projection model sized for the pipeline UNet."""
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            id_embeddings_dim=512,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        return image_proj_model

    def set_ref_adapter(self):
        """Install read-side REF processors on the reference UNet's "attn1" layers."""
        attn_procs = {}
        for name in self.ref_unet.attn_processors.keys():
            if "attn1" in name:
                attn_procs[name] = REFAttnProcessor(name=name, type="read")
            else:
                attn_procs[name] = AttnProcessor()
        self.ref_unet.set_attn_processor(attn_procs)

    def set_ip_adapter(self):
        """Install write-side REF processors (self-attn) and IP processors (cross-attn)."""
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            # The hidden size of each block is inferred from its position in the UNet.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = REFAttnProcessor(name=name, type="write")
            else:
                # generate() appends num_tokens * n_cond image tokens after the text
                # tokens; the processor presumably uses this count to split them.
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens * self.n_cond,
                ).to(self.device, dtype=self.torch_dtype)
        unet.set_attn_processor(attn_procs)

    def load_ip_adapter(self):
        """Load the ``image_proj`` and ``ip_adapter`` weights from ``self.ip_ckpt``."""
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            # torch.load unpickles the file: only load trusted checkpoints.
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, faceid_embeds):
        """Return ``(cond, uncond)`` prompt tokens for the face-ID embeddings.

        The uncond tokens come from projecting an all-zeros embedding. A 3-D input
        ``(b, n, c)`` is flattened to ``(b * n, c)``. After projection, the tokens
        of the ``n`` faces are concatenated per batch item.
        """
        multi_face = False
        if faceid_embeds.dim() == 3:
            multi_face = True
            b, n, c = faceid_embeds.shape
            faceid_embeds = faceid_embeds.reshape(b * n, c)
        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
        image_prompt_embeds = self.image_proj_model(faceid_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
        if multi_face:
            c = image_prompt_embeds.size(-1)
            image_prompt_embeds = image_prompt_embeds.reshape(b, -1, c)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape(b, -1, c)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set the image-prompt scale on every IP attention processor."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @staticmethod
    def _build_prompts(prompt, a_prompt, negative_prompt, num_prompts):
        """Return ``(prompts, negative_prompts)`` lists with ``a_prompt`` appended.

        A string ``prompt`` is repeated ``num_prompts`` times. For a list, the
        suffix is appended to every element.
        """
        if prompt is None:
            prompt = "a photography of a model"
        if isinstance(prompt, list):
            prompt = [p + ", " + a_prompt for p in prompt]
        else:
            prompt = [prompt + ", " + a_prompt] * num_prompts
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts
        return prompt, negative_prompt

    @staticmethod
    def _repeat_for_samples(embeds, num_samples):
        """Repeat ``(b, seq, dim)`` embeddings to ``(b * num_samples, seq, dim)``."""
        bs_embed, seq_len, _ = embeds.shape
        return embeds.repeat(1, num_samples, 1).view(bs_embed * num_samples, seq_len, -1)

    def _prepare_cloth(self, cloth_image, cloth_mask, height, width):
        """Return ``(masked_cloth_tensor, mask_image)`` for the garment.

        When ``cloth_mask`` is None, the mask is predicted with ``self.seg_net``.
        The tensor is cast to the VAE's dtype so ``vae.encode`` accepts it.
        """
        cloth_mask_image = cloth_mask
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
        cloth = prepare_image(cloth_image, height, width)
        mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * mask).to(self.device, dtype=self.pipe.vae.dtype)
        return cloth, cloth_mask_image

    def generate(
        self,
        cloth_image,
        face_image,
        cloth_mask=None,
        prompt=None,
        a_prompt="best quality, high quality",
        negative_prompt=None,
        num_samples=4,
        seed=None,
        guidance_scale=3.,
        cloth_guidance_scale=3.,
        num_inference_steps=30,
        height=512,
        width=384,
        scale=1.0,
        **kwargs,
    ):
        """Generate images of ``face_image``'s identity wearing ``cloth_image``.

        Args:
            cloth_image: garment image.
            face_image: RGB image (converted to BGR for insightface).
            cloth_mask: optional garment mask image; predicted when None.
            prompt / a_prompt / negative_prompt: ``a_prompt`` is appended to each prompt.
            num_samples: images generated per prompt.
            seed: optional RNG seed.
            guidance_scale: text classifier-free guidance scale.
            cloth_guidance_scale: passed to the pipeline when cloth guidance is on.
            num_inference_steps, height, width: forwarded to the pipeline.
            scale: IP-Adapter scale set on the IP attention processors.
            **kwargs: extra pipeline arguments.

        Returns:
            ``(images, cloth_mask_image)``, or None when no face is detected.
        """
        faces = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
        if not faces:
            # Best-effort contract: no detected face -> None.
            return None
        faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)

        cloth, cloth_mask_image = self._prepare_cloth(cloth_image, cloth_mask, height, width)
        self.set_scale(scale)
        num_prompts = faceid_embeds.size(0)
        prompt, negative_prompt = self._build_prompts(prompt, a_prompt, negative_prompt, num_prompts)

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
        image_prompt_embeds = self._repeat_for_samples(image_prompt_embeds, num_samples)
        uncond_image_prompt_embeds = self._repeat_for_samples(uncond_image_prompt_embeds, num_samples)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=False)[0]
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            # Single reference pass at timestep 0; results are exchanged via self.attn_store.
            self.ref_unet(torch.cat([cloth_embeds] * num_samples), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        pipe_kwargs = dict(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
        )
        if self.enable_cloth_guidance:
            # Forwarded only when enabled (as IPAdapterFaceIDPlus does); presumably a
            # pipeline without cloth guidance does not accept this argument.
            pipe_kwargs["cloth_guidance_scale"] = cloth_guidance_scale
        images = self.pipe(**pipe_kwargs, **kwargs).images
        return images, cloth_mask_image
class IPAdapterFaceIDPlus:
    """IP-Adapter FaceID-Plus wrapper with garment conditioning.

    Prompt tokens are built from the face-ID embedding plus CLIP features of the
    aligned face crop. Garment conditioning goes through a reference UNet.

    Setup done by this class:
    - The pipeline UNet's "attn1" processors become ``REFAttnProcessor(type="write")``;
      its other processors become ``IPAttnProcessor``.
    - A deep copy of the UNet, loaded from ``ref_path``, gets
      ``REFAttnProcessor(type="read")`` on its "attn1" processors.
    - Both UNets share ``self.attn_store``.
    """

    def __init__(self, sd_pipe, ref_path, image_encoder_path, ip_ckpt, device, enable_cloth_guidance, num_tokens=4, torch_dtype=torch.float16, set_seg_model=True):
        """Build the adapter around ``sd_pipe``.

        Args:
            sd_pipe: diffusers pipeline. It is moved to ``device`` and its UNet's
                attention processors are replaced.
            ref_path: safetensors file, loaded with ``strict=False`` into a deep copy
                of ``sd_pipe.unet`` (the reference UNet).
            image_encoder_path: ``CLIPVisionModelWithProjection`` checkpoint.
            ip_ckpt: IP-Adapter checkpoint with ``image_proj`` / ``ip_adapter`` sections.
            device: device for every model built here.
            enable_cloth_guidance: forwarded via ``cross_attention_kwargs``. When true,
                ``generate`` also passes ``cloth_guidance_scale`` to the pipeline.
            num_tokens: tokens produced by the projection model.
            torch_dtype: dtype of the image encoder, projection model and IP processors.
            set_seg_model: load the garment segmentation net used when ``generate``
                gets no mask.
        """
        self.enable_cloth_guidance = enable_cloth_guidance
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        self.torch_dtype = torch_dtype
        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()
        # load image encoder (must exist before init_proj reads its hidden_size)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=self.torch_dtype
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()
        self.load_ip_adapter()
        self.set_insightface()
        ref_unet = copy.deepcopy(sd_pipe.unet)
        state_dict = {}
        with safe_open(ref_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        ref_unet.load_state_dict(state_dict, strict=False)
        self.ref_unet = ref_unet.to(self.device)
        self.set_ref_adapter()
        if set_seg_model:
            self.set_seg_model()
        self.attn_store = {}

    def set_insightface(self):
        """Create the insightface ``buffalo_l`` analyser (CUDA with CPU fallback)."""
        self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def set_seg_model(self, checkpoint_path='checkpoints/cloth_segm.pth'):
        """Load the garment segmentation network into ``self.seg_net``."""
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def init_proj(self):
        """Create the face-ID + CLIP projection model sized for the UNet and encoder."""
        image_proj_model = ProjPlusModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            id_embeddings_dim=512,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        return image_proj_model

    def set_ref_adapter(self):
        """Install read-side REF processors on the reference UNet's "attn1" layers."""
        attn_procs = {}
        for name in self.ref_unet.attn_processors.keys():
            if "attn1" in name:
                attn_procs[name] = REFAttnProcessor(name=name, type="read")
            else:
                attn_procs[name] = AttnProcessor()
        self.ref_unet.set_attn_processor(attn_procs)

    def set_ip_adapter(self):
        """Install write-side REF processors (self-attn) and IP processors (cross-attn)."""
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            # The hidden size of each block is inferred from its position in the UNet.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = REFAttnProcessor(name=name, type="write")
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens,
                ).to(self.device, dtype=self.torch_dtype)
        unet.set_attn_processor(attn_procs)

    def load_ip_adapter(self):
        """Load the ``image_proj`` and ``ip_adapter`` weights from ``self.ip_ckpt``."""
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            # torch.load unpickles the file: only load trusted checkpoints.
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
        """Return ``(cond, uncond)`` prompt tokens.

        The CLIP features are the penultimate hidden state of the image encoder.
        The uncond tokens use an all-zeros pixel tensor and an all-zeros ID embedding.
        """
        clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
        image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set the image-prompt scale on every IP attention processor."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @staticmethod
    def _build_prompts(prompt, a_prompt, negative_prompt, num_prompts):
        """Return ``(prompts, negative_prompts)`` lists with ``a_prompt`` appended.

        A string ``prompt`` is repeated ``num_prompts`` times. For a list, the
        suffix is appended to every element.
        """
        if prompt is None:
            prompt = "a photography of a model"
        if isinstance(prompt, list):
            prompt = [p + ", " + a_prompt for p in prompt]
        else:
            prompt = [prompt + ", " + a_prompt] * num_prompts
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts
        return prompt, negative_prompt

    @staticmethod
    def _repeat_for_samples(embeds, num_samples):
        """Repeat ``(b, seq, dim)`` embeddings to ``(b * num_samples, seq, dim)``."""
        bs_embed, seq_len, _ = embeds.shape
        return embeds.repeat(1, num_samples, 1).view(bs_embed * num_samples, seq_len, -1)

    def _prepare_cloth(self, cloth_image, cloth_mask, height, width):
        """Return ``(masked_cloth_tensor, mask_image)`` for the garment.

        When ``cloth_mask`` is None, the mask is predicted with ``self.seg_net``.
        The tensor is cast to the VAE's dtype so ``vae.encode`` accepts it.
        """
        cloth_mask_image = cloth_mask
        if cloth_mask_image is None:
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
        cloth = prepare_image(cloth_image, height, width)
        mask = prepare_mask(cloth_mask_image, height, width)
        cloth = (cloth * mask).to(self.device, dtype=self.pipe.vae.dtype)
        return cloth, cloth_mask_image

    def generate(
        self,
        cloth_image,
        face_image,
        cloth_mask=None,
        prompt=None,
        a_prompt="best quality, high quality",
        negative_prompt=None,
        num_samples=4,
        seed=None,
        guidance_scale=2.5,
        cloth_guidance_scale=2.5,
        num_inference_steps=20,
        height=512,
        width=384,
        scale=1.0,
        s_scale=1.,
        shortcut=False,
        **kwargs,
    ):
        """Generate images of ``face_image``'s identity wearing ``cloth_image``.

        Args:
            cloth_image: garment image.
            face_image: RGB image (converted to BGR for insightface).
            cloth_mask: optional garment mask image; predicted when None.
            prompt / a_prompt / negative_prompt: ``a_prompt`` is appended to each prompt.
            num_samples: images generated per prompt.
            seed: optional RNG seed.
            guidance_scale: text classifier-free guidance scale.
            cloth_guidance_scale: passed to the pipeline when cloth guidance is on.
            num_inference_steps, height, width: forwarded to the pipeline.
            scale: IP-Adapter scale set on the IP attention processors.
            s_scale, shortcut: forwarded to ``ProjPlusModel``.
            **kwargs: extra pipeline arguments.

        Returns:
            ``(images, cloth_mask_image)``, or None when no face is detected.
        """
        face_image = cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)
        faces = self.app.get(face_image)
        if not faces:
            # Best-effort contract: no detected face -> None.
            return None
        faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        # NOTE(review): the aligned 224px crop is still BGR when it reaches the CLIP
        # image processor; confirm this matches how the checkpoint was trained.
        face_image = face_align.norm_crop(face_image, landmark=faces[0].kps, image_size=224)

        cloth, cloth_mask_image = self._prepare_cloth(cloth_image, cloth_mask, height, width)
        self.set_scale(scale)
        num_prompts = faceid_embeds.size(0)
        prompt, negative_prompt = self._build_prompts(prompt, a_prompt, negative_prompt, num_prompts)

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
        image_prompt_embeds = self._repeat_for_samples(image_prompt_embeds, num_samples)
        uncond_image_prompt_embeds = self._repeat_for_samples(uncond_image_prompt_embeds, num_samples)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
            prompt_embeds_null = self.pipe.encode_prompt([""], device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=False)[0]
            cloth_embeds = self.pipe.vae.encode(cloth).latent_dist.mode() * self.pipe.vae.config.scaling_factor
            # Single reference pass at timestep 0; results are exchanged via self.attn_store.
            self.ref_unet(torch.cat([cloth_embeds] * num_samples), 0, prompt_embeds_null, cross_attention_kwargs={"attn_store": self.attn_store})

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        pipe_kwargs = dict(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            cross_attention_kwargs={"attn_store": self.attn_store, "do_classifier_free_guidance": guidance_scale > 1.0, "enable_cloth_guidance": self.enable_cloth_guidance},
        )
        if self.enable_cloth_guidance:
            # Forwarded only when enabled; presumably a pipeline without cloth
            # guidance does not accept this argument.
            pipe_kwargs["cloth_guidance_scale"] = cloth_guidance_scale
        images = self.pipe(**pipe_kwargs, **kwargs).images
        return images, cloth_mask_image
class IPAdapterFaceIDXL(IPAdapterFaceID):
    """SDXL variant of IPAdapterFaceID.

    ``generate`` takes precomputed face-ID embeddings instead of a face image. It
    does not run the garment / reference-UNet path of the base class.
    """

    def generate(
        self,
        faceid_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate SDXL images conditioned on face-ID embeddings.

        Args:
            faceid_embeds: face-ID embeddings accepted by ``get_image_embeds``. Dim 0
                is the number of prompts. Required: the ``None`` default is kept only
                for signature compatibility.
            prompt / negative_prompt: strings (repeated per prompt) or lists.
            scale: IP-Adapter scale set on the IP attention processors.
            num_samples: images generated per prompt.
            seed: optional RNG seed.
            num_inference_steps: forwarded to the pipeline.
            **kwargs: extra pipeline arguments.

        Returns:
            The pipeline output's ``images``.

        Raises:
            ValueError: if ``faceid_embeds`` is None.
        """
        if faceid_embeds is None:
            raise ValueError("faceid_embeds is required for IPAdapterFaceIDXL.generate")
        # NOTE(review): no cross_attention_kwargs (attn_store) are passed here, although
        # the UNet's self-attention processors were replaced with REF "write" processors
        # by set_ip_adapter; confirm those tolerate a missing attn_store.
        self.set_scale(scale)
        num_prompts = faceid_embeds.size(0)
        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
        # Duplicate each prompt's image tokens num_samples times.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images
        return images
class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
    """SDXL variant of IPAdapterFaceIDPlus.

    ``generate`` takes precomputed face-ID embeddings and a face crop for the CLIP
    encoder. It does not run the garment / reference-UNet path of the base class.
    """

    def generate(
        self,
        face_image=None,
        faceid_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        s_scale=1.0,
        shortcut=True,
        **kwargs,
    ):
        """Generate SDXL images conditioned on face-ID embeddings and a face crop.

        Args:
            face_image: face image fed to the CLIP image processor. Required.
            faceid_embeds: face-ID embeddings. Dim 0 is the number of prompts. Required.
            prompt / negative_prompt: strings (repeated per prompt) or lists.
            scale: IP-Adapter scale set on the IP attention processors.
            num_samples: images generated per prompt.
            seed: optional RNG seed.
            guidance_scale, num_inference_steps: forwarded to the pipeline.
            s_scale, shortcut: forwarded to ``ProjPlusModel``.
            **kwargs: extra pipeline arguments.

        Returns:
            The pipeline output's ``images``.

        Raises:
            ValueError: if ``faceid_embeds`` or ``face_image`` is None.
        """
        if faceid_embeds is None:
            raise ValueError("faceid_embeds is required for IPAdapterFaceIDPlusXL.generate")
        if face_image is None:
            raise ValueError("face_image is required for IPAdapterFaceIDPlusXL.generate")
        # NOTE(review): no cross_attention_kwargs (attn_store) are passed here, although
        # the UNet's self-attention processors were replaced with REF "write" processors
        # by set_ip_adapter; confirm those tolerate a missing attn_store.
        self.set_scale(scale)
        num_prompts = faceid_embeds.size(0)
        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
        # Duplicate each prompt's image tokens num_samples times.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            **kwargs,
        ).images
        return images