import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2

def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
                        std_dim=-1, norm_dim=-1):
    """Add Gaussian noise to ts.

    If noise_std_is_relative, noise_std is interpreted as a fraction of the mean std of ts
    along std_dim. If keep_norm, the noised tensor is rescaled so that its norm along
    norm_dim matches that of the original tensor.
    """
    if noise_std_is_relative:
        ts_std_mean = ts.std(dim=std_dim).mean().detach()
        noise_std *= ts_std_mean

    noise = torch.randn_like(ts) * noise_std
    if keep_norm:
        orig_norm = ts.norm(dim=norm_dim, keepdim=True)
        ts = ts + noise
        new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
        ts = ts * orig_norm / (new_norm + 1e-8)
    else:
        ts = ts + noise

    return ts
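
# Illustrative usage of add_noise_to_tensor (the shape and noise level below are arbitrary
# examples, not values prescribed by this module):
#   embs = torch.randn(4, 512)
#   noisy_embs = add_noise_to_tensor(embs, 0.04, noise_std_is_relative=True, keep_norm=True)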


class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by alpha_ in the backward pass."""
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        ctx.save_for_backward(alpha_, debug)
        output = input_
        if debug:
            print(f"input: {input_.abs().mean().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):
        alpha_, debug = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if debug:
                print(f"grad_output2: {grad_output2.abs().mean().item()}")
        else:
            grad_output2 = None
        # No gradients for alpha_ and debug.
        return grad_output2, None, None


class GradientScaler(nn.Module):
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        A gradient scaling layer.
        This layer has no parameters; it simply scales the gradient in the backward pass.
        """
        super().__init__(*args, **kwargs)

        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        # Fall back to a tensor (not a Python bool), since ScaleGrad saves it with save_for_backward().
        _debug = self._debug if hasattr(self, '_debug') else torch.tensor(False)
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)


def gen_gradient_scaler(alpha, debug=False):
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    else:
        assert alpha == 0
        return torch.detach
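
# Illustrative usage of gen_gradient_scaler (the alpha value is an arbitrary example). The returned
# callable is an identity in the forward pass and scales gradients by alpha in the backward pass;
# alpha == 0 detaches the input entirely.
#   scaler = gen_gradient_scaler(0.1)
#   x = torch.randn(2, 3, requires_grad=True)
#   y = scaler(x).sum()
#   y.backward()    # x.grad is filled with 0.1 instead of 1.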


def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
                               input_max_length=77, return_full_and_core_embs=True):
    '''
    arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
    face_embs: (N, 512) normalized ArcFace embeddings.
    return_full_and_core_embs: if True, return both the full prompt embeddings and the core
                               embeddings (the 16 token positions 4:20, starting at the "id"
                               placeholder token). If False, return only the core embeddings.
    '''

    arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]

    input_ids = tokenizer(
            "photo of a id person",
            truncation=True,
            padding="max_length",
            max_length=input_max_length,
            return_tensors="pt",
        ).input_ids.to(face_embs.device)
    # One copy of the prompt per face embedding.
    input_ids = input_ids.repeat(len(face_embs), 1)

    face_embs_dtype = face_embs.dtype
    face_embs = face_embs.to(arc2face_text_encoder.dtype)
    # Zero-pad the 512-dim ArcFace embeddings to the text encoder's hidden size.
    face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)

    # Replace the embedding of the "id" placeholder token with the padded face embedding.
    token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
    token_embs[input_ids == arcface_token_id] = face_embs_padded

    prompt_embeds = arc2face_text_encoder(
        input_ids=input_ids,
        input_token_embs=token_embs,
        return_token_embs=False
    )[0]

    # Restore the original dtype of face_embs.
    prompt_embeds = prompt_embeds.to(face_embs_dtype)

    if return_full_and_core_embs:
        # Full prompt embeddings, plus the core embeddings at token positions 4:20.
        return prompt_embeds, prompt_embeds[:, 4:20]
    else:
        return prompt_embeds[:, 4:20]
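
# Illustrative usage (the tokenizer checkpoint and tensor shapes below are assumptions for this
# sketch; arc2face_text_encoder is the CLIPTextModelWrapper mentioned in the docstring, loaded elsewhere):
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   face_embs = F.normalize(torch.randn(2, 512), dim=-1)   # stand-in ArcFace embeddings
#   full_embs, core_embs = arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs)
#   # full_embs: (2, 77, 768); core_embs: (2, 16, 768)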


def get_b_core_e_embeddings(prompt_embeds, length=22):
    # Keep the first `length` token embeddings plus the final token embedding.
    b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
    return b_core_e_embs


def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
                                      return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
                                      input_max_length=77, zs_extra_words_scale=0.5):
    '''
    inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
                          It is NOT the original Arc2Face text encoder, but one retrained to do the inverse mapping.
    face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
    list_extra_words: [s_1, ..., s_BS], where each s_i is a string of at most 2 extra words
                      appended to the prompt of the i-th instance. May be None.
    return_emb_types: list of strings specifying which embeddings to return; each is one of
                      'full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'.
    pad_embeddings: reference embeddings used to overwrite padding positions in the 'full_*' variants.
    '''

    if list_extra_words is not None:
        if len(list_extra_words) != len(face_prompt_embs):
            if len(face_prompt_embs) > 1:
                print("Warning: list_extra_words has a different length than face_prompt_embs.")
                if len(list_extra_words) == 1:
                    list_extra_words = list_extra_words * len(face_prompt_embs)
                else:
                    breakpoint()
            else:
                # Only one face_prompt_emb; keep just the first extra-words string.
                list_extra_words = list_extra_words[:1]

        for extra_words in list_extra_words:
            assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
        # The 16 ", " commas are placeholder tokens at positions 4:20, overwritten with the core embeddings below.
        prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
    else:
        # The 16 ", " commas are placeholder tokens at positions 4:20, overwritten with the core embeddings below.
        prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]

    input_ids = clip_tokenizer(
            prompt_templates,
            truncation=True,
            padding="max_length",
            max_length=input_max_length,
            return_tensors="pt",
        ).input_ids.to(face_prompt_embs.device)

    face_prompt_embs_dtype = face_prompt_embs.dtype
    face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)

    # Inject the core ID embeddings at token positions 4:20.
    token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
    token_embs[:, 4:20] = face_prompt_embs

    prompt_embeds = inverse_text_encoder(
        input_ids=input_ids,
        input_token_embs=token_embs,
        hidden_state_layer_weights=hidden_state_layer_weights,
        return_token_embs=False
    )[0]

    # Restore the original dtype of face_prompt_embs.
    prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)

    # Core embeddings: the 16 ID token positions.
    core_prompt_embs = prompt_embeds[:, 4:20]
    if list_extra_words is not None:
        # Append the (scaled) embeddings of the extra words to the core embeddings.
        extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
        core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)

    return_prompts = []
    for emb_type in return_emb_types:
        if emb_type == 'full':
            return_prompts.append(prompt_embeds)
        elif emb_type == 'full_half_pad':
            prompt_embeds2 = prompt_embeds.clone()
            PADS = prompt_embeds2.shape[1] - 23
            if PADS >= 2:
                # Overwrite the first half of the padding positions (from index 22) with the reference padding embeddings.
                prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'full_pad':
            prompt_embeds2 = prompt_embeds.clone()
            # Overwrite positions 22:-1 with the reference padding embeddings.
            prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'core':
            return_prompts.append(core_prompt_embs)
        elif emb_type == 'full_zeroed_extra':
            prompt_embeds2 = prompt_embeds.clone()
            # Use the reference padding embeddings at positions 22:24 and zero out positions 24:-1.
            prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
            prompt_embeds2[:, 24:-1] = 0
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'b_core_e':
            # The first `length`=22 positions plus the final token embedding.
            b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
            return_prompts.append(b_core_e_embs)
        else:
            breakpoint()

    return return_prompts
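
# Illustrative usage (all tensors below are stand-ins; hidden_state_layer_weights is left at its default):
#   face_prompt_embs = torch.randn(2, 16, 768)      # core ID embeddings
#   pad_embeddings   = torch.randn(77, 768)         # reference padding embeddings
#   full, core = arc2face_inverse_face_prompt_embs(
#       clip_tokenizer, inverse_text_encoder, face_prompt_embs,
#       list_extra_words=None, return_emb_types=['full', 'core'],
#       pad_embeddings=pad_embeddings)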


def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
                                extract_faceid_embeds, pre_face_embs,
                                image_folder, image_paths, images_np,
                                id_batch_size, device,
                                input_max_length=77, noise_level=0.0,
                                return_core_id_embs=False,
                                gen_neg_prompt=False, verbose=False):
    '''
    If extract_faceid_embeds, detect faces in images_np (or in the images at image_paths) with
    face_app, average their ArcFace embeddings, and use the mean embedding as the ID embedding.
    Otherwise, use pre_face_embs (or random embeddings if pre_face_embs is None).
    The (optionally noised) ID embeddings are then mapped to Arc2Face prompt embeddings.
    '''
    face_image_count = 0

    if extract_faceid_embeds:
        faceid_embeds = []
        if image_paths is not None:
            images_np = []
            for image_path in image_paths:
                image_np = np.array(Image.open(image_path))
                images_np.append(image_np)

        for i, image_np in enumerate(images_np):
            image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
            # Drop the alpha channel if present.
            if image_obj.mode == 'RGBA':
                image_obj = image_obj.convert('RGB')

            image_np = np.array(image_obj)

            face_infos = face_app.get(image_np)
            if verbose and image_paths is not None:
                print(image_paths[i], len(face_infos))
            # Skip images in which no face is detected.
            if len(face_infos) == 0:
                continue
            # Keep only the face with the largest bounding-box area.
            face_info = sorted(face_infos, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
            # Each normed_embedding: (512,) -> (1, 512).
            faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
            face_image_count += 1

        if verbose:
            if image_folder is not None:
                print(f"Extracted ID embeddings from {face_image_count} images in {image_folder}")
            else:
                print(f"Extracted ID embeddings from {face_image_count} images")

        if len(faceid_embeds) == 0:
            print("No face detected. Using random face embeddings instead.")
            faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
        else:
            # Average the ID embeddings across all images of the subject.
            faceid_embeds = torch.cat(faceid_embeds, dim=0)
            faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
    else:
        # Do not extract ID embeddings from images; use pre_face_embs or random embeddings instead.
        if pre_face_embs is None:
            faceid_embeds = torch.randn(id_batch_size, 512)
        else:
            faceid_embeds = pre_face_embs
            if pre_face_embs.shape[0] == 1:
                faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)

        faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)

    if noise_level > 0:
        # Perturb the ID embeddings while preserving their norm.
        faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)

    faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)

    # Map the ID embeddings to Arc2Face prompt embeddings.
    with torch.no_grad():
        arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
            arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
                                       faceid_embeds, input_max_length=input_max_length,
                                       return_full_and_core_embs=True)
        if return_core_id_embs:
            arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb

    # If the ID embeddings were extracted (and averaged) from images, replicate them to the batch size.
    if extract_faceid_embeds:
        faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
        arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)

    if gen_neg_prompt:
        # Negative prompt embeddings correspond to an all-zero ID embedding.
        with torch.no_grad():
            arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
                arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
                                           torch.zeros_like(faceid_embeds),
                                           input_max_length=input_max_length,
                                           return_full_and_core_embs=True)
            if return_core_id_embs:
                arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb

        return face_image_count, faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
    else:
        return face_image_count, faceid_embeds, arc2face_pos_prompt_emb
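
# Illustrative usage (the insightface setup below is an assumption for this sketch; any
# FaceAnalysis-style detector whose results expose .normed_embedding should work):
#   from insightface.app import FaceAnalysis
#   face_app = FaceAnalysis(name='antelopev2', providers=['CPUExecutionProvider'])
#   face_app.prepare(ctx_id=0, det_size=(512, 512))
#   count, id_embs, pos_prompt_emb = get_arc2face_id_prompt_embs(
#       face_app, clip_tokenizer, arc2face_text_encoder,
#       extract_faceid_embeds=True, pre_face_embs=None,
#       image_folder=None, image_paths=['subject1.jpg', 'subject2.jpg'], images_np=None,
#       id_batch_size=4, device='cuda', noise_level=0.04)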