import math
import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange
from transformers import CLIPVisionModel, CLIPTokenizer
from transformers.utils import ModelOutput
from dataclasses import dataclass
from typing import Optional, Tuple

from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
from adaface.arc2face_models import CLIPTextModelWrapper

# Alias 'ldm' to the 'adaface' package, presumably so that legacy checkpoints and
# code paths that reference the old 'ldm' module still resolve.
sys.modules['ldm'] = sys.modules['adaface']

def reshape_tensor(x, num_heads):
    bs, length, width = x.shape
    # (bs, length, width) -> (bs, length, num_heads, dim_head)
    x = x.view(bs, length, num_heads, -1)
    # (bs, length, num_heads, dim_head) -> (bs, num_heads, length, dim_head)
    x = x.transpose(1, 2)
    x = x.reshape(bs, num_heads, length, -1)
    return x

def FeedForward(dim, mult=4, p_dropout=0.1):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
        nn.Dropout(p_dropout),
    )

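# Illustrative usage sketch (not part of the original module): FeedForward is a
# pre-LayerNorm MLP that preserves the last dimension.
def _demo_feedforward():
    ff = FeedForward(dim=768, mult=4, p_dropout=0.1)
    x = torch.randn(2, 16, 768)
    out = ff(x)
    assert out.shape == (2, 16, 768)
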
class IP_MLPProjModel(nn.Module):
    # Maps a single face ID embedding to num_tokens cross-attention tokens.
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            nn.GELU(),
            nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x

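# Illustrative usage sketch (not part of the original module): project one 512-dim
# face ID embedding into 4 cross-attention tokens of width 768.
def _demo_ip_mlp_proj():
    proj = IP_MLPProjModel(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)
    id_embeds = torch.randn(2, 512)
    tokens = proj(id_embeds)
    assert tokens.shape == (2, 4, 768)
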
class LearnedSoftAggregate(nn.Module):
    def __init__(self, num_feat, group_dim, keepdim=False):
        super(LearnedSoftAggregate, self).__init__()
        self.group_dim = group_dim
        self.num_feat = num_feat
        self.feat2score = nn.Linear(num_feat, 1, bias=False)
        self.keepdim = keepdim

    def forward(self, x, score_basis=None):
        # If the group dim has a single element, aggregation is a no-op.
        if x.shape[self.group_dim] == 1:
            if self.keepdim:
                return x
            else:
                return x.squeeze(self.group_dim)

        # By default, the aggregation scores are computed from x itself.
        if score_basis is None:
            score_basis = x

        if self.num_feat == 1:
            mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
        else:
            mode_scores = self.feat2score(score_basis)
        attn_probs = mode_scores.softmax(dim=self.group_dim)
        x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
        return x_aggr

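# Illustrative usage sketch (not part of the original module): softly aggregate
# 4 candidate modes per sample into one vector, with the weights predicted from
# the features themselves.
def _demo_learned_soft_aggregate():
    agg = LearnedSoftAggregate(num_feat=768, group_dim=1, keepdim=False)
    x = torch.randn(2, 4, 768)   # (batch, num_modes, feat)
    y = agg(x)                   # softmax-weighted sum over the mode dim
    assert y.shape == (2, 768)
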
def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
                    num_output_vecs, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to (num_modes * lora_rank * output_dim) dims.
        nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
        # Reshape to (BS, num_modes, lora_rank, output_dim).
        Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # If there are multiple modes, aggregate them with learned soft weights;
        # otherwise just drop the singleton mode dim.
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
        # Permute to (BS, output_dim, lora_rank) so the next Linear expands the rank dim.
        Rearrange('b q d -> b d q'),
        nn.Linear(lora_rank, num_output_vecs, bias=False),
        Rearrange('b d q -> b q d'),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

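# Illustrative usage sketch (not part of the original module), with assumed dims:
# expand one 512-dim embedding into 16 vectors of width 768 through a rank-64 bottleneck.
def _demo_lora_expand_embs():
    expand = LoRA_ExpandEmbs(input_dim=512, lora_rank=64, output_dim=768,
                             num_modes=1, num_output_vecs=16)
    x = torch.randn(2, 512)
    y = expand(x)
    assert y.shape == (2, 16, 768)
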
def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
        Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

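# Illustrative usage sketch (not part of the original module): ExpandEmbs with the
# dims used for the subject branch below (DINO 384 -> 77 vectors of width 768).
def _demo_expand_embs():
    expand = ExpandEmbs(input_dim=384, output_dim=768, expansion_ratio=77)
    x = torch.randn(2, 384)
    y = expand(x)
    assert y.shape == (2, 77, 768)
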
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
    if output_dim == -1:
        output_dim = input_dim

    return nn.Sequential(
        nn.Linear(input_dim, output_dim * num_modes, bias=False),
        Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate the num_modes projections of each token into a single vector.
        LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
            else Rearrange('b n () d -> b n d'),
        nn.Dropout(p_dropout),
    )

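# Illustrative usage sketch (not part of the original module): project each of 16
# tokens through 4 modes and softly aggregate the modes back into one 768-dim vector.
def _demo_multimode_projection():
    mmp = MultimodeProjection(input_dim=768, num_modes=4)
    x = torch.randn(2, 16, 768)
    y = mmp(x)
    assert y.shape == (2, 16, 768)
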
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # (BS, lora_rank, output_dim) -> (BS, output_dim, lora_rank)
        Rearrange('b q d -> b d q'),
        nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
        # -> (BS, num_modes, hira_rank, output_dim)
        Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
    )

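# Illustrative usage sketch (not part of the original module), with assumed dims:
# convert 64 low-rank vectors into 16 higher-rank vectors via learned mode aggregation.
def _demo_lora2hira():
    l2h = Lora2Hira(lora_rank=64, hira_rank=16, output_dim=768, num_modes=4)
    x = torch.randn(2, 64, 768)   # (batch, lora_rank, output_dim)
    y = l2h(x)
    assert y.shape == (2, 16, 768)
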
class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.num_heads = num_heads
        inner_dim = dim_head * num_heads

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latent_queries):
        """
        Args:
            x (torch.Tensor): image features of shape (b, n1, D).
            latent_queries (torch.Tensor): latent query features of shape (b, n2, D).
        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latent_queries = self.norm2(latent_queries)

        b, l, _ = latent_queries.shape

        q = self.to_q(latent_queries)
        # Keys/values attend over both the image features and the latent queries.
        kv_input = torch.cat((x, latent_queries), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.num_heads)
        k = reshape_tensor(k, self.num_heads)
        v = reshape_tensor(v, self.num_heads)

        # Split the scale between q and k for better numerical stability.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = attn @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)

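# Illustrative usage sketch (not part of the original module): 16 latent queries
# attend over 257 image tokens; the output keeps the latent-query length.
def _demo_perceiver_attention():
    attn = PerceiverAttention(dim=768, dim_head=64, num_heads=8)
    x = torch.randn(2, 257, 768)              # e.g. CLIP patch features
    latent_queries = torch.randn(2, 16, 768)
    out = attn(x, latent_queries)
    assert out.shape == (2, 16, 768)
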
class CrossAttention(nn.Module):
    # Cross-attention from latent queries x to a context sequence.
    # If q_aware_to_v, each of the num_q // v_repeat query groups gets its own
    # low-rank (q_aware_to_v_lora_rank) projection of the context to values.
    def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
                 identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
                 q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
                 identity_to_out=False, out_has_skip=False):
        super().__init__()
        dim_head = input_dim // num_heads
        inner_dim = dim_head * num_heads

        self.num_heads = num_heads
        self.q_aware_to_v = q_aware_to_v
        self.v_has_skip = v_has_skip
        self.to_q = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_q else nn.Identity()
        self.to_k = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_k else nn.Identity()

        self.v_repeat = v_repeat
        self.num_q_group = num_q_group = num_q // v_repeat

        if q_aware_to_v:
            # Each query group has its own low-rank value projection of the context.
            all_q_mid = num_q_group * q_aware_to_v_lora_rank
            self.to_v = nn.Sequential(
                # (b, n, input_dim) -> (b, n, all_q_mid)
                nn.Linear(input_dim, all_q_mid, bias=False),
                nn.LayerNorm(all_q_mid, elementwise_affine=True),
                # (b, n, all_q_mid) -> (b, all_q_mid, n)
                Rearrange('b n q -> b q n', q=all_q_mid),
                # A grouped 1x1 conv expands each group's lora_rank channels to input_dim channels.
                nn.Conv1d(
                    in_channels=all_q_mid,
                    out_channels=num_q_group * input_dim,
                    kernel_size=1,
                    groups=num_q_group,
                    bias=False,
                ),
                # (b, num_q_group * input_dim, n) -> (b, num_q_group, n, input_dim)
                Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
                nn.LayerNorm(input_dim, elementwise_affine=True),
            )
        else:
            self.to_v = nn.Sequential(
                nn.Linear(input_dim, inner_dim, bias=False),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            ) if not identity_to_v else nn.Identity()

        if identity_to_out:
            assert not out_has_skip, "identity_to_out=True requires out_has_skip=False."
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(input_dim, input_dim, bias=False),
                nn.Dropout(p_dropout),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            )

        self.out_has_skip = out_has_skip
        self.attn_drop = nn.Dropout(p_dropout)

    def forward(self, x, context=None, attn_mat=None, return_attn=False):
        h = self.num_heads

        if context is None:
            context = x

        if attn_mat is None:
            q = self.to_q(x)
            k = self.to_k(context)
            q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))

        if self.q_aware_to_v:
            # v: (b, num_q_group, n, input_dim)
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context.unsqueeze(1)
        else:
            # v: (b, n, inner_dim)
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context

        if self.q_aware_to_v:
            # (b, num_q_group, n, h*d) -> (b*h, num_q_group, n, d)
            v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
            # Repeat the query groups v_repeat times to cover all num_q queries.
            v = v.repeat(1, self.v_repeat, 1, 1)
        else:
            v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)

        if attn_mat is None:
            # Split the scale between q and k for better numerical stability.
            scale = q.size(-1) ** -0.25
            sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
            attn = sim.softmax(dim=-1)
            attn = self.attn_drop(attn)
        else:
            attn = attn_mat

        if self.q_aware_to_v:
            # Each query i attends over its own value set v[:, i].
            out = einsum('b i j, b i j d -> b i d', attn, v)
        else:
            out = einsum('b i j, b j d -> b i d', attn, v)

        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        if self.out_has_skip:
            out = self.to_out(out) + out
        else:
            out = self.to_out(out)

        if return_attn:
            return out, attn
        else:
            return out

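# Illustrative usage sketch (not part of the original module), similar to the
# prompt_translator configuration used in SubjBasisGenerator below (q_aware_to_v
# disabled, identity output projection).
def _demo_cross_attention():
    translator = CrossAttention(input_dim=768, num_heads=6, p_dropout=0.05,
                                q_aware_to_v=False, v_has_skip=True,
                                identity_to_out=True, out_has_skip=False)
    latent_queries = torch.randn(2, 64, 768)
    id_embs = torch.randn(2, 77, 768)
    out = translator(latent_queries, context=id_embs)
    assert out.shape == (2, 64, 768)
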
class SubjBasisGenerator(nn.Module):
    def __init__(
        self,
        num_heads=6,
        # Number of input ID vectors: 77 (CLIP prompt length) for the subject branch,
        # 257 (CLIP ViT patch tokens + [CLS]) for the background branch.
        num_id_vecs={ 'subj': 77, 'bg': 257 },
        num_out_embs_per_layer=4,
        num_out_layers=16,
        image_embedding_dim=768,
        dino_embedding_dim=384,
        output_dim=768,
        placeholder_is_bg: bool = False,
        prompt2token_proj_grad_scale: float = 0.4,
        zs_extra_words_scale: float = 0.5,
        learnable_hidden_state_weights_scheme: str = 'per-layer',
        bg_prompt_translator_has_to_out_proj: bool = False,
    ):
        super().__init__()

        self.placeholder_is_bg = placeholder_is_bg
        self.num_out_layers = num_out_layers
        self.num_out_embs_per_layer = num_out_embs_per_layer
        self.num_out_embs = num_out_layers * num_out_embs_per_layer
        self.output_dim = output_dim

        self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
        self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
        self.pos_embs_ln = nn.LayerNorm(output_dim)
        self.zs_extra_words_scale = zs_extra_words_scale
        self.output_scale = output_dim ** -0.5
        self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        if not self.placeholder_is_bg:
            # Subject branch.
            # obj_proj_in expands raw (non-face) DINO embeddings to num_id_vecs vectors.
            self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)

            # prompt2token_proj maps Arc2Face ID embeddings back to CLIP prompt token embeddings.
            self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
            self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
            self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
            print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")

            # A grad scale of 0 is equivalent to freezing prompt2token_proj.
            if prompt2token_proj_grad_scale == 0:
                self.freeze_prompt2token_proj()

            self.prompt2token_proj_attention_multiplier = -1
            self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
            self.pad_embeddings = None
            self.bg_proj_in = None
        else:
            # Background branch.
            self.obj_proj_in = None
            self.prompt2token_proj = None
            print("Bg prompt2token_proj is set to None.")

            self.bg_proj_in = nn.Sequential(
                nn.Linear(image_embedding_dim, output_dim, bias=False),
                nn.LayerNorm(output_dim),
            )

            self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
            self.latent_queries_ln = nn.LayerNorm(output_dim)

            self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
            identity_to_v = False
            v_has_skip = not identity_to_v
            identity_to_out = not bg_prompt_translator_has_to_out_proj
            out_has_skip = not identity_to_out

            # prompt_translator maps the latent queries to the output embeddings,
            # cross-attending to the background CLIP features.
            self.prompt_translator = \
                CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
                               identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
                               q_aware_to_v=False, v_has_skip=v_has_skip,
                               num_q=0,  # num_q is only used when q_aware_to_v=True.
                               identity_to_out=identity_to_out,
                               out_has_skip=out_has_skip)
            '''
            prompt_translator: CLIPEncoder
            # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
            # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
            (0): CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                    (k_proj): Linear(in_features=768, out_features=768, bias=True)
                    (v_proj): Linear(in_features=768, out_features=768, bias=True)
                    (q_proj): Linear(in_features=768, out_features=768, bias=True)
                    (out_proj): Linear(in_features=768, out_features=768, bias=True)
                )
                (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                    (activation_fn): QuickGELUActivation()
                    (fc1): Linear(in_features=768, out_features=3072, bias=True)
                    (fc2): Linear(in_features=3072, out_features=768, bias=True)
                )
                (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            )
            '''

        print(repr(self))

    def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
                is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):

        if not self.placeholder_is_bg:
            BS = arc2face_id_embs.shape[0]
        else:
            # Background branch: the batch size is taken from the CLIP image features.
            BS = clip_features.shape[0]

        adaface_prompt_embs = None
        # Create the tokenizer lazily if this instance was loaded without one.
        if not hasattr(self, 'clip_tokenizer'):
            self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        if not self.placeholder_is_bg:
            if is_face:
                assert arc2face_id_embs is not None

                hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)

                # Select which prompt embedding variants to return.
                if is_training:
                    return_emb_types = ['full_pad', 'core']
                else:
                    # Default inference type: 'full_half_pad'.
                    return_emb_types = [adaface_prompt_embs_inf_type, 'core']

                if self.pad_embeddings is None:
                    self.generate_pad_embeddings()
                else:
                    self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)

                with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
                    # Map the Arc2Face ID embeddings back to CLIP prompt embeddings
                    # and the core subject ID embeddings.
                    adaface_prompt_embs, core_id_embs = \
                        arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
                                                          self.prompt2token_proj,
                                                          arc2face_id_embs,
                                                          list_extra_words=None,
                                                          return_emb_types=return_emb_types,
                                                          pad_embeddings=self.pad_embeddings,
                                                          hidden_state_layer_weights=hidden_state_layer_weights,
                                                          input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
                # Scale down the gradients flowing back into prompt2token_proj.
                adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
                core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
            elif raw_id_embs is not None:
                # Non-face subject: expand the raw DINO embeddings instead.
                id_embs = self.obj_proj_in(raw_id_embs)
            else:
                breakpoint()
        else:
            # Background branch: project the CLIP image features.
            id_embs = self.bg_proj_in(clip_features)

        if self.placeholder_is_bg:
            id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
            latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)

            with torch.set_grad_enabled(self.training):
                id_embs_out = self.prompt_translator(latent_queries, id_embs)

            id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
            adaface_subj_embs = id_embs_out * self.output_scale
        else:
            # Subject branch: replicate the core ID embeddings across the output layers.
            adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)

        # If out_id_embs_scale < 1, interpolate between the subject embeddings and pad embeddings.
        if out_id_embs_scale != 1:
            pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
            adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
                                + pad_embeddings * (1 - out_id_embs_scale)

        return adaface_subj_embs, adaface_prompt_embs

    def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
        if learnable_hidden_state_weights_scheme == 'none':
            self.hidden_state_layer_weights = None
            # A gradient scaler of 1 is a no-op.
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
            print("hidden_state_layer_weights is set to None.")
        elif learnable_hidden_state_weights_scheme == 'per-layer':
            # Learnable weights over three hidden-state layers, initialized as [1, 2, 4].
            self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
                                                           requires_grad=True)
            self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
            print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
        else:
            breakpoint()

    def generate_pad_embeddings(self):
        # clip_embeddings: the CLIPTextEmbeddings module of the prompt2token_proj text model.
        clip_embeddings = self.prompt2token_proj.text_model.embeddings
        # A sequence of 77 pad tokens, on the same device as the token embedding weights.
        pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
        # pad_embeddings: (77, 768), the token + position embeddings of the all-pad sequence.
        pad_embeddings = clip_embeddings(pad_tokens)[0]
        # Pad embeddings are constant, so detach them from the graph.
        self.pad_embeddings = pad_embeddings.detach()

    def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
        if multiplier > 1:
            num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
            self.prompt2token_proj_attention_multiplier = multiplier
            print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")

    def freeze_prompt2token_proj(self):
        if self.prompt2token_proj is not None:
            frozen_param_names = []
            for param_name, param in self.prompt2token_proj.named_parameters():
                if param.requires_grad:
                    param.requires_grad = False
                    frozen_param_names.append(param_name)

            print(f"{len(frozen_param_names)} params in Subj prompt2token_proj are frozen.")

    def __repr__(self):
        type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
        # Fill in defaults for attributes that may be missing, e.g. on instances
        # loaded from older checkpoints.
        if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
            self.bg_prompt_translator_has_to_out_proj = False
        if not hasattr(self, 'num_out_embs'):
            self.num_out_embs = -1
        return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
               f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"

@dataclass
class BaseModelOutputWithPooling2(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    Same fields as transformers' BaseModelOutputWithPooling, plus an extra `attn_mask` field.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        attn_mask (`torch.FloatTensor`, *optional*):
            The token-level attention mask that was applied to the vision tokens, if one was passed in.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attn_mask: Optional[torch.FloatTensor] = None

# A replacement for CLIPVisionTransformer.forward() that additionally supports an
# attention mask over the image patches.
def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
                                  output_attentions = None,
                                  output_hidden_states = None, return_dict = None):

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if pixel_values is None:
        raise ValueError("You have to specify pixel_values")

    hidden_states = self.embeddings(pixel_values)
    hidden_states = self.pre_layrnorm(hidden_states)

    if attn_mask is not None:
        # feat_edge_size: edge size of the patch grid, e.g. 16 for 256 patches + 1 [CLS] token.
        feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
        # Resize attn_mask to the patch grid size.
        attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
        # Flatten the spatial dims: (B, 1, N_patches).
        attn_mask = attn_mask.flatten(2)
        # Prepend a 1 for the [CLS] token: (B, 1, N_patches + 1).
        attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
        # Outer product over token pairs: (B, 1, N+1, N+1), broadcastable over heads.
        attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
    else:
        attn_mask_pairs = None

    encoder_outputs = self.encoder(
        inputs_embeds=hidden_states,
        attention_mask=attn_mask_pairs,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    last_hidden_state = encoder_outputs[0]
    # The pooled output is the [CLS] token's hidden state.
    pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.post_layernorm(pooled_output)

    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPooling2(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        # Also return the flattened token-level mask, permuted to (B, N+1, 1).
        attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
    )

class CLIPVisionModelWithMask(CLIPVisionModel):
    def __init__(self, config):
        super().__init__(config)
        # Replace vision_model.forward() with the mask-aware version defined above.
        self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)

    def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
                output_hidden_states = None, return_dict = None):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            attn_mask=attn_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

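# Illustrative usage sketch (not part of the original module). Loading the weights
# downloads 'openai/clip-vit-large-patch14'; the attn_mask is assumed to be a
# (B, H, W) foreground mask in [0, 1] over the input image.
def _demo_clip_vision_with_mask():
    model = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
    pixel_values = torch.randn(1, 3, 224, 224)
    attn_mask = torch.ones(1, 224, 224)
    out = model(pixel_values=pixel_values, attn_mask=attn_mask)
    # For ViT-L/14 at 224x224: 256 patch tokens + 1 [CLS] token, hidden size 1024.
    assert out.last_hidden_state.shape == (1, 257, 1024)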