import torch
from torch import nn
from torch.nn import functional as F
from loguru import logger
# from prodigyopt import Prodigy
from torch.utils.checkpoint import checkpoint
from transformers import pipeline
# from sbp.nn.model_paths import MODEL_PATHS
# from sbp.nn.torch.models.qformer import ModifiedQFormer


class ImageEncoder(nn.Module):
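    """Image feature encoder built on a truncated timm EVA-02 backbone.

    Encodes a per-sample list of reference images into a fixed-length token
    sequence of width ``output_dim``. Only the first ``layer_num`` transformer
    blocks of the backbone are kept; feature-pyramid pooling, global average
    pooling, a Q-Former projection head, and an additive per-slot positional
    code can be switched on via the constructor flags.
    """
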
    def __init__(self, output_dim, base_model='eva02_base_patch14_224.mim_in22k', layer_num=6, seq_len=3, device='cpu', use_pe=False, use_pyramid=False, use_global_feature=False, use_qformer_dim=0):
        super().__init__()
        self.output_dim = output_dim
        import timm
        # paths = {
        #     'eva02_large_patch14_448.mim_in22k_ft_in1k': MODEL_PATHS.EVA02_LARGE_448_MIM_IN22K,
        #     'eva02_base_patch14_224.mim_in22k': MODEL_PATHS.EVA02_BASE_224_MIM_IN22K,
        # }
        if base_model == 'eva02_base_patch14_224.mim_in22k':
            self.img_seq = 257
        elif base_model == 'eva02_large_patch14_448.mim_in22k_ft_in1k':
            self.img_seq = 1025
        elif base_model == 'siglip2':
            self.img_seq = 1024
        else:
            raise ValueError(f"unknown base_model {base_model}, supported: ['eva02_base_patch14_224.mim_in22k', 'eva02_large_patch14_448.mim_in22k_ft_in1k', 'siglip2']")
        # self.base_model = timm.create_model(base_model, pretrained=True, pretrained_cfg_overlay={'file': paths[base_model], 'custom_load': False})
        self.base_model = timm.create_model(base_model, pretrained=False)
        del self.base_model.norm, self.base_model.fc_norm, self.base_model.head, self.base_model.head_drop
        del self.base_model.blocks[layer_num:]
        dim_mult = 3 if use_pyramid else 1
        image_output_dim = self.base_model.num_features * dim_mult
        self.seq_len = seq_len
        self.device = device
        self.use_pe = use_pe
        self.use_pyramid = use_pyramid
        self.use_global_feature = use_global_feature
        self.use_qformer = use_qformer_dim > 0
        if self.use_pe:
            self.pe = torch.zeros([1, self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
            for i in range(self.seq_len):
                self.pe[:, i * self.img_seq: (i + 1) * self.img_seq, i::self.seq_len] = 0.05
        if self.use_qformer:
            logger.info("image projection use qformer ...")
            # NOTE: needs the ModifiedQFormer import that is commented out at the top of this file.
            self.qformer = ModifiedQFormer(
                input_dim=image_output_dim,
                hidden_dim=use_qformer_dim,
                num_heads=12,
                num_layers=6,
                output_dim=output_dim,
                num_queries=512,
                use_self_attention=False
            ).cuda()
        else:
            self.project = nn.Linear(image_output_dim, output_dim)
        self.final_norm = nn.LayerNorm(output_dim)
    def apply_feature_pyramid(self, original_tokens, original_grid_size=32, downsample=(1, 4, 32)):
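        """Build a channel-wise feature pyramid from a list of token maps.

        Each entry of ``original_tokens`` is a (B, 1 + H*W, D) sequence whose
        leading token is a class/prefix token. Entries with downsample factor 1
        are kept as-is; larger factors are average-pooled on the H x W grid and
        upsampled back, and all levels are concatenated along the channel
        dimension (a factor of 0 skips that level).
        """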
        B, seq_len, D = original_tokens[0].shape
        H = W = original_grid_size
        token_lst = []
        for i, tokens in enumerate(original_tokens):
            downsample_size = downsample[i]
            if downsample_size == 0:
                pass
            elif downsample_size == 1:
                token_lst.append(tokens)
            else:
                head, tokens = torch.split(tokens, [1, seq_len - 1], dim=1)
                tokens_2d = tokens.view(B, H, W, D).permute(0, 3, 1, 2)  # reshape tokens to 2D grid (B, D, H, W)
                pooled = F.avg_pool2d(tokens_2d, kernel_size=downsample_size, stride=downsample_size)  # (B, D, H//ds, W//ds)
                up = F.interpolate(pooled, size=(H, W), mode='nearest')  # (B, D, H, W)
                up = up.permute(0, 2, 3, 1).reshape(B, seq_len - 1, D)
                up = torch.cat([head, up], dim=1)
                token_lst.append(up / downsample_size ** 0.5)
        combined_tokens = torch.cat(token_lst, dim=2)
        return combined_tokens
    def apply_global_feature(self, original_tokens, original_grid_size=32, pool_size=4):
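        """Average-pool a (B, H*W, D) token grid by ``pool_size`` along each spatial axis.

        Assumes the sequence holds exactly ``original_grid_size ** 2`` grid
        tokens (no class/prefix token). Returns (B, H*W // pool_size**2, D).
        """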
        B, seq_len, D = original_tokens.shape
        H = W = original_grid_size
        tokens_2d = original_tokens.view(B, H, W, D).permute(0, 3, 1, 2)
        pooled = F.avg_pool2d(tokens_2d, kernel_size=pool_size, stride=pool_size)  # (B, D, H//pool_size, W//pool_size)
        pooled = pooled.permute((0, 2, 3, 1)).reshape((B, seq_len // pool_size // pool_size, D))
        return pooled
    def forward(self, image_list):
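        """Encode a batch given as a list of per-sample image stacks.

        ``image_list[i]`` is an (n_i, 3, H, W) tensor of reference images for
        sample i (n_i may be 0). Each sample is padded to ``seq_len`` image
        slots and the result is returned as a
        (batch, seq_len * tokens_per_image, output_dim) tensor.
        """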
        splits = [len(lst) for lst in image_list]
        if sum(splits) == 0:
            return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
        x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
        x = self.base_model.patch_embed(x)
        x, rot_pos_embed = self.base_model._pos_embed(x)
        intermediates = []
        for i, blk in enumerate(self.base_model.blocks):
            x = blk(x, rope=rot_pos_embed)
            if i in [11]:  # keep block-11 activations as an extra pyramid level (only reached when layer_num > 11)
                intermediates.append(x)
        intermediates.append(x)
        if self.use_pyramid:
            x = self.apply_feature_pyramid(intermediates + [x])
        elif self.use_global_feature:
            x = self.apply_global_feature(x)
        if self.use_qformer:
            x = self.qformer(x)
        else:
            x = self.project(x)
        x = self.final_norm(x)
        b, seq_len, c = x.shape
        split_patches = torch.split(x, splits, dim=0)
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
        x = torch.stack(split_patches, dim=0)
        x = x.reshape((len(splits), self.seq_len * seq_len, c))
        if self.use_pe:
            x = x + self.pe
        return x


class ImageEncoderWithSiglip(nn.Module):
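    """Variant of ImageEncoder that uses a truncated SigLIP vision tower (loaded via transformers) as the backbone."""
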
def __init__(self, output_dim, base_model="siglip2-so400m-patch16-512", layer_num=6, seq_len=3, device='cpu', use_pe=False): | |
super().__init__() | |
self.output_dim = output_dim | |
ckpt = { | |
'siglip-so400m-patch14-384': MODEL_PATHS.SIGLIP_SO400M_384, | |
'siglip2-so400m-patch16-512': MODEL_PATHS.SIGLIP2_SO400M_512 | |
}[base_model] | |
image_classifier = pipeline(model=ckpt, task="zero-shot-image-classification", device='cpu') | |
logger.info(f"using {layer_num} / {len(image_classifier.model.vision_model.encoder.layers)} layers of {base_model} ... ") | |
del image_classifier.model.vision_model.encoder.layers[layer_num:] | |
num_features = image_classifier.model.vision_model.post_layernorm.normalized_shape[0] | |
self.base_model = image_classifier.model.vision_model | |
self.project = nn.Linear(num_features, output_dim) | |
self.final_norm = nn.LayerNorm(output_dim) | |
self.seq_len = seq_len | |
self.device = device | |
self.use_pe = use_pe | |
    def forward(self, image_list):
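        """Same contract as ImageEncoder.forward: a list of per-sample (n_i, 3, H, W) image stacks in, a (batch, seq_len * tokens_per_image, output_dim) tensor out."""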
        splits = [len(lst) for lst in image_list]
        if sum(splits) == 0:
            return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
        x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
        x = self.base_model(x).last_hidden_state
        x = self.project(x)
        x = self.final_norm(x)
        b, seq_len, c = x.shape
        split_patches = torch.split(x, splits, dim=0)
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
        x = torch.stack(split_patches, dim=0)
        x = x.reshape((len(splits), self.seq_len * seq_len, c))
        if self.use_pe:
            x = x + self.pe
        return x
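

# Minimal usage sketch: build an ImageEncoder with random (non-pretrained) weights
# and run it on dummy inputs. The 224x224 input size and the 257-tokens-per-image
# figure come from the eva02_base_patch14_224 default above; a recent timm/torch
# release with bfloat16 CPU support is assumed.
if __name__ == "__main__":
    encoder = ImageEncoder(output_dim=768, device='cpu').to(torch.bfloat16)
    image_list = [
        torch.randn(2, 3, 224, 224),  # sample 0: two reference images
        torch.randn(0, 3, 224, 224),  # sample 1: no reference images (padded up to seq_len)
    ]
    out = encoder(image_list)
    print(out.shape)  # expected: torch.Size([2, 771, 768]) == (batch, seq_len * 257, output_dim)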