# Flux-Consistancy-v2 / projection.py
import torch
from torch import nn
from torch.nn import functional as F
from loguru import logger
# from prodigyopt import Prodigy
from torch.utils.checkpoint import checkpoint
from transformers import pipeline
# from sbp.nn.model_paths import MODEL_PATHS
# from sbp.nn.torch.models.qformer import ModifiedQFormer
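
# Projects a variable number of reference images per sample into a fixed-length token
# sequence (seq_len slots of img_seq tokens, width output_dim) using a truncated EVA-02
# backbone, optionally with a feature pyramid, global pooling, or a Q-Former head.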
class ImageEncoder(nn.Module):
    def __init__(self, output_dim, base_model='eva02_base_patch14_224.mim_in22k', layer_num=6, seq_len=3, device='cpu', use_pe=False, use_pyramid=False, use_global_feature=False, use_qformer_dim=0):
super().__init__()
self.output_dim = output_dim
import timm
# paths = {
# 'eva02_large_patch14_448.mim_in22k_ft_in1k': MODEL_PATHS.EVA02_LARGE_448_MIM_IN22K,
# 'eva02_base_patch14_224.mim_in22k': MODEL_PATHS.EVA02_BASE_224_MIM_IN22K,
# }
if base_model == 'eva02_base_patch14_224.mim_in22k':
self.img_seq = 257
elif base_model == 'eva02_large_patch14_448.mim_in22k_ft_in1k':
self.img_seq = 1025
elif base_model == 'siglip2':
self.img_seq = 1024
else:
raise ValueError(f" unknown {base_model}, supported: {list(paths.keys())}")
# self.base_model = timm.create_model(base_model, pretrained=True, pretrained_cfg_overlay={'file': paths[base_model], 'custom_load': False})
self.base_model = timm.create_model(base_model, pretrained=False)
        # drop the classification head / final norms and keep only the first `layer_num` transformer blocks
        del self.base_model.norm, self.base_model.fc_norm, self.base_model.head, self.base_model.head_drop
        del self.base_model.blocks[layer_num:]
dim_mult = 3 if use_pyramid else 1
image_output_dim = self.base_model.num_features * dim_mult
self.seq_len = seq_len
self.device = device
self.use_pe = use_pe
self.use_pyramid = use_pyramid
self.use_global_feature = use_global_feature
self.use_qformer = use_qformer_dim > 0
        if self.use_pe:
            # fixed (non-learned) position encoding marking which of the seq_len image slots a
            # token belongs to: slot i adds 0.05 on every seq_len-th channel starting at channel i.
            # Note: a plain tensor, not a registered buffer, so it does not follow .to() / .cuda().
            self.pe = torch.zeros([1, self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
            for i in range(self.seq_len):
                self.pe[:, i * self.img_seq: (i + 1) * self.img_seq, i::self.seq_len] = 0.05
        if self.use_qformer:
            logger.info("using QFormer for image projection ...")
            # ModifiedQFormer comes from the sbp package (import commented out at the top of this
            # file); this branch requires it.
            self.qformer = ModifiedQFormer(
input_dim=image_output_dim,
hidden_dim=use_qformer_dim,
num_heads=12,
num_layers=6,
output_dim=output_dim,
num_queries=512,
use_self_attention=False
).cuda()
else:
self.project = nn.Linear(image_output_dim, output_dim)
self.final_norm = nn.LayerNorm(output_dim)
    def apply_feature_pyramid(self, original_tokens, original_grid_size=32, downsample=(1, 4, 32)):
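        # Builds a per-token feature pyramid from the given feature maps: for each level,
        # the 2D patch grid is average-pooled by downsample[i], upsampled back to the original
        # grid, rescaled by 1/sqrt(downsample[i]), and the levels are concatenated along the
        # channel dimension (the leading CLS token is carried through unpooled).
        # A downsample factor of 1 keeps the tokens as-is; 0 drops that level entirely.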
B, seq_len, D = original_tokens[0].shape
H = W = original_grid_size
token_lst = []
for i, tokens in enumerate(original_tokens):
downsample_size = downsample[i]
if downsample_size == 0:
pass
elif downsample_size == 1:
token_lst.append(tokens)
else:
                # split off the leading CLS token; the remaining H*W patch tokens form the grid
                head, tokens = torch.split(tokens, [1, H * W], dim=1)
tokens_2d = tokens.view(B, H, W, D).permute(0, 3, 1, 2) # Reshape tokens to 2D grid (B, D, H, W)
pooled = F.avg_pool2d(tokens_2d, kernel_size=downsample_size, stride=downsample_size) # (B, D, 32//ds, 32//ds)
up = F.interpolate(pooled, size=(H, W), mode='nearest') # (B, D, 32, 32)
up = up.permute(0, 2, 3, 1).reshape(B, seq_len - 1, D)
up = torch.cat([head, up], dim=1)
token_lst.append(up / downsample_size ** 0.5)
combined_tokens = torch.cat(token_lst, dim=2)
return combined_tokens
def apply_global_feature(self, original_tokens, original_grid_size=32, pool_size=4):
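        # Average-pools the patch grid by pool_size in each spatial dimension, shrinking the
        # sequence length by pool_size**2. Assumes the tokens form an exact
        # original_grid_size x original_grid_size grid (i.e. no CLS token is present).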
B, seq_len, D = original_tokens.shape
H = W = original_grid_size
tokens_2d = original_tokens.view(B, H, W, D).permute(0, 3, 1, 2)
pooled = F.avg_pool2d(tokens_2d, kernel_size=pool_size, stride=pool_size) # (B, D, 8, 8)
pooled = pooled.permute((0, 2, 3, 1)).reshape((B, seq_len // pool_size // pool_size, D))
return pooled
def forward(self, image_list):
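        # image_list: one tensor of shape (n_i, 3, H, W) per sample, with n_i <= self.seq_len.
        # All reference images are batched through the truncated backbone, projected to
        # output_dim, then split back per sample, zero-padded to self.seq_len images, and
        # flattened to a (batch, self.seq_len * tokens_per_image, output_dim) sequence.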
splits = [len(lst) for lst in image_list]
if sum(splits) == 0:
return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
x = self.base_model.patch_embed(x)
x, rot_pos_embed = self.base_model._pos_embed(x)
intermediates = []
for i, blk in enumerate(self.base_model.blocks):
x = blk(x, rope=rot_pos_embed)
            if i in [11]:
                # collect block 11's output as an extra (mid-depth) feature level; only reached when layer_num > 11
                intermediates.append(x)
        # always include the final block's output
        intermediates.append(x)
if self.use_pyramid:
x = self.apply_feature_pyramid(intermediates + [x])
elif self.use_global_feature:
x = self.apply_global_feature(x)
if self.use_qformer:
x = self.qformer(x)
else:
x = self.project(x)
x = self.final_norm(x)
        b, seq_len, c = x.shape  # seq_len here is the per-image token count, not self.seq_len
        split_patches = torch.split(x, splits, dim=0)
        # zero-pad each sample along the image axis up to self.seq_len reference images
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
x = torch.stack(split_patches, dim=0)
x = x.reshape((len(splits), self.seq_len * seq_len, c))
if self.use_pe:
x = x + self.pe
return x
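
# Same interface as ImageEncoder, but backed by a truncated SigLIP / SigLIP2 vision tower
# loaded through the transformers zero-shot-image-classification pipeline.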
class ImageEncoderWithSiglip(nn.Module):
def __init__(self, output_dim, base_model="siglip2-so400m-patch16-512", layer_num=6, seq_len=3, device='cpu', use_pe=False):
super().__init__()
self.output_dim = output_dim
        # MODEL_PATHS comes from the sbp package (import commented out at the top of this file);
        # without it, replace these entries with local checkpoint paths or hub model ids.
        ckpt = {
'siglip-so400m-patch14-384': MODEL_PATHS.SIGLIP_SO400M_384,
'siglip2-so400m-patch16-512': MODEL_PATHS.SIGLIP2_SO400M_512
}[base_model]
image_classifier = pipeline(model=ckpt, task="zero-shot-image-classification", device='cpu')
logger.info(f"using {layer_num} / {len(image_classifier.model.vision_model.encoder.layers)} layers of {base_model} ... ")
del image_classifier.model.vision_model.encoder.layers[layer_num:]
num_features = image_classifier.model.vision_model.post_layernorm.normalized_shape[0]
self.base_model = image_classifier.model.vision_model
self.project = nn.Linear(num_features, output_dim)
self.final_norm = nn.LayerNorm(output_dim)
self.seq_len = seq_len
self.device = device
self.use_pe = use_pe
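        # The original code never defines self.img_seq (used by the empty-input fallback in
        # forward) or self.pe (used when use_pe=True). Assumed reconstruction mirroring
        # ImageEncoder above: derive the patch count from the SigLIP vision config (no CLS
        # token) and reuse the same slot-marking position encoding.
        cfg = self.base_model.config
        self.img_seq = (cfg.image_size // cfg.patch_size) ** 2
        if self.use_pe:
            self.pe = torch.zeros([1, self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
            for i in range(self.seq_len):
                self.pe[:, i * self.img_seq: (i + 1) * self.img_seq, i::self.seq_len] = 0.05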
def forward(self, image_list):
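        # Mirrors ImageEncoder.forward, but runs the truncated SigLIP vision tower end-to-end
        # instead of iterating over blocks manually.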
splits = [len(lst) for lst in image_list]
if sum(splits) == 0:
return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
x = self.base_model(x).last_hidden_state
x = self.project(x)
x = self.final_norm(x)
        b, seq_len, c = x.shape  # per-image token count
split_patches = torch.split(x, splits, dim=0)
split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
x = torch.stack(split_patches, dim=0)
x = x.reshape((len(splits), self.seq_len * seq_len, c))
if self.use_pe:
x = x + self.pe
return x
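

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training code; output_dim and shapes
    # are arbitrary example values): build the EVA-02 based encoder without pretrained weights
    # and encode two samples with a different number of reference images each. Assumes timm is
    # installed; the default eva02_base_patch14_224 backbone yields 257 tokens per image.
    encoder = ImageEncoder(output_dim=3072, layer_num=6, seq_len=3, device='cpu', use_pe=True)
    encoder = encoder.to(torch.bfloat16)  # forward() casts inputs to bfloat16
    image_list = [
        torch.randn(1, 3, 224, 224),  # sample 1: one reference image
        torch.randn(2, 3, 224, 224),  # sample 2: two reference images
    ]
    with torch.no_grad():
        tokens = encoder(image_list)
    print(tokens.shape)  # expected: (2, 3 * 257, 3072)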