import torch import torch.nn as nn from transformers import CLIPTextModel from diffusers import ( StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, UNet2DConditionModel, DDIMScheduler, AutoencoderKL, ) from insightface.app import FaceAnalysis from adaface.arc2face_models import CLIPTextModelWrapper from adaface.util import get_arc2face_id_prompt_embs import re, os import sys sys.modules['ldm'] = sys.modules['adaface'] class AdaFaceWrapper(nn.Module): def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device, subject_string='z', num_vectors=16, num_inference_steps=50, negative_prompt=None, use_840k_vae=False, use_ds_text_encoder=False, is_training=False): ''' pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are removed from the pipeline to release RAM. ''' super().__init__() self.pipeline_name = pipeline_name self.base_model_path = base_model_path self.adaface_ckpt_path = adaface_ckpt_path self.use_840k_vae = use_840k_vae self.use_ds_text_encoder = use_ds_text_encoder self.subject_string = subject_string self.num_vectors = num_vectors self.num_inference_steps = num_inference_steps self.device = device self.is_training = is_training self.initialize_pipeline() self.extend_tokenizer_and_text_encoder() if negative_prompt is None: self.negative_prompt = \ "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \ "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \ "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \ "nude, naked, nsfw, topless, bare breasts" else: self.negative_prompt = negative_prompt def load_subj_basis_generator(self, adaface_ckpt_path): ckpt = torch.load(adaface_ckpt_path, map_location='cpu') string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"] if self.subject_string not in string_to_subj_basis_generator_dict: print(f"Subject '{self.subject_string}' not found in the embedding manager.") breakpoint() self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string] # In the original ckpt, num_out_layers is 16 for layerwise embeddings. # But we don't do layerwise embeddings here, so we set it to 1. self.subj_basis_generator.num_out_layers = 1 print(f"Loaded subject basis generator for '{self.subject_string}'.") print(repr(self.subj_basis_generator)) self.subj_basis_generator.to(self.device) if self.is_training: self.subj_basis_generator.train() else: self.subj_basis_generator.eval() def initialize_pipeline(self): self.load_subj_basis_generator(self.adaface_ckpt_path) # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings # in the UNet image space. arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained( 'models/arc2face', subfolder="encoder", torch_dtype=torch.float16 ) self.arc2face_text_encoder = arc2face_text_encoder.to(self.device) if self.use_840k_vae: # The 840000-step vae model is slightly better in face details than the original vae model. # https://huggingface.co/stabilityai/sd-vae-ft-mse-original vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16) else: vae = None if self.use_ds_text_encoder: # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder. # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16) else: text_encoder = None remove_unet = False if self.pipeline_name == "img2img": PipelineClass = StableDiffusionImg2ImgPipeline elif self.pipeline_name == "text2img": PipelineClass = StableDiffusionPipeline # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images. elif self.pipeline_name is None: PipelineClass = StableDiffusionPipeline remove_unet = True else: raise ValueError(f"Unknown pipeline name: {self.pipeline_name}") if os.path.isfile(self.base_model_path): pipeline = PipelineClass.from_single_file( self.base_model_path, torch_dtype=torch.float16 ) else: pipeline = PipelineClass.from_pretrained( self.base_model_path, torch_dtype=torch.float16, safety_checker=None ) print(f"Loaded pipeline from {self.base_model_path}.") if self.use_840k_vae: pipeline.vae = vae print("Replaced the VAE with the 840k-step VAE.") if self.use_ds_text_encoder: pipeline.text_encoder = text_encoder print("Replaced the text encoder with the DreamShaper text encoder.") if remove_unet: # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder. pipeline.unet = None pipeline.vae = None print("Removed UNet and VAE from the pipeline.") noise_scheduler = DDIMScheduler( num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1, ) pipeline.scheduler = noise_scheduler self.pipeline = pipeline.to(self.device) # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2. # Note there's a second "model" in the path. self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) self.face_app.prepare(ctx_id=0, det_size=(512, 512)) # Patch the missing tokenizer in the subj_basis_generator. if not hasattr(self.subj_basis_generator, 'clip_tokenizer'): self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer print("Patched the missing tokenizer in the subj_basis_generator.") def extend_tokenizer_and_text_encoder(self): if self.num_vectors < 1: raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}") tokenizer = self.pipeline.tokenizer # Add z0, z1, z2, ..., z15. self.placeholder_tokens = [] for i in range(0, self.num_vectors): self.placeholder_tokens.append(f"{self.subject_string}_{i}") self.placeholder_tokens_str = " ".join(self.placeholder_tokens) # Add the new tokens to the tokenizer. num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens) if num_added_tokens != self.num_vectors: raise ValueError( f"The tokenizer already contains the token {self.subject_string}. Please pass a different" " `subject_string` that is not already in the tokenizer.") print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.") # placeholder_token_ids: [49408, ..., 49423]. self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens) # print(self.placeholder_token_ids) # Resize the token embeddings as we are adding new special tokens to the tokenizer old_weight = self.pipeline.text_encoder.get_input_embeddings().weight self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer)) new_weight = self.pipeline.text_encoder.get_input_embeddings().weight print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.") # Extend pipeline.text_encoder with the adaface subject emeddings. # subj_embs: [16, 768]. def update_text_encoder_subj_embs(self, subj_embs): # Initialise the newly added placeholder token with the embeddings of the initializer token token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data with torch.no_grad(): for i, token_id in enumerate(self.placeholder_token_ids): token_embeds[token_id] = subj_embs[i] print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.") def update_prompt(self, prompt): # If the placeholder tokens are already in the prompt, then return the prompt as is. if self.placeholder_tokens_str in prompt: return prompt # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt. if re.search(r'\b' + self.subject_string + r'\b', prompt) is None: print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.") comp_prompt = self.placeholder_tokens_str + " " + prompt else: # Replace the subject string 'z' with the placeholder tokens. comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt) return comp_prompt # image_paths: a list of image paths. image_folder: the parent folder name. def generate_adaface_embeddings(self, image_paths, image_folder=None, pre_face_embs=None, gen_rand_face=False, out_id_embs_scale=1., noise_level=0, update_text_encoder=True): # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512). # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times. # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different. # The same applies to id_prompt_emb. # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space. # Here id_batch_size = 1, so # faceid_embeds: [1, 512]. NOT used later. # id_prompt_emb: [1, 16, 768]. # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings. # arc2face prompt template: "photo of a id person" # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings. face_image_count, faceid_embeds, id_prompt_emb \ = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder, extract_faceid_embeds=not gen_rand_face, pre_face_embs=pre_face_embs, # image_folder is passed only for logging purpose. # image_paths contains the paths of the images. image_folder=image_folder, image_paths=image_paths, images_np=None, id_batch_size=1, device=self.device, # input_max_length == 22: only keep the first 22 tokens, # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens. # The results are indistinguishable from input_max_length=77. input_max_length=22, noise_level=noise_level, return_core_id_embs=True, gen_neg_prompt=False, verbose=True) if face_image_count == 0: return None # adaface_subj_embs: [1, 1, 16, 768]. # adaface_prompt_embs: [1, 77, 768] (not used). adaface_subj_embs, adaface_prompt_embs = \ self.subj_basis_generator(id_prompt_emb, None, None, out_id_embs_scale=out_id_embs_scale, is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad') # adaface_subj_embs: [16, 768] adaface_subj_embs = adaface_subj_embs.squeeze() if update_text_encoder: self.update_text_encoder_subj_embs(adaface_subj_embs) return adaface_subj_embs def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False): if negative_prompt is None: negative_prompt = self.negative_prompt prompt = self.update_prompt(prompt) if verbose: print(f"Prompt: {prompt}") # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device). # So we manually move it to GPU here. self.pipeline.text_encoder.to(device) # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] prompt_embeds_, negative_prompt_embeds_ = \ self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt) return prompt_embeds_, negative_prompt_embeds_ # ref_img_strength is used only in the img2img pipeline. def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0, out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False): if negative_prompt is None: negative_prompt = self.negative_prompt # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose) # Repeat the prompt embeddings for all images in the batch. prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1) negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1) noise = noise.to(self.device).to(torch.float16) # noise: [BS, 4, 64, 64] # When the pipeline is text2img, strength is ignored. images = self.pipeline(image=noise, prompt_embeds=prompt_embeds_, negative_prompt_embeds=negative_prompt_embeds_, num_inference_steps=self.num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=1, strength=ref_img_strength, generator=generator).images # images: [BS, 3, 512, 512] return images