import torch
from torch import nn
from torch.nn import functional as F
from loguru import logger
# from prodigyopt import Prodigy
from torch.utils.checkpoint import checkpoint
from transformers import pipeline
from sbp.nn.model_paths import MODEL_PATHS
from sbp.nn.torch.models.qformer import ModifiedQFormer


class ImageEncoder(nn.Module):
    """Truncated EVA-02 backbone that projects reference-image tokens to `output_dim`."""

    def __init__(self, output_dim, base_model='eva02_base_patch14_224.mim_in22k', layer_num=6, seq_len=3,
                 device='cpu', use_pe=False, use_pyramid=False, use_global_feature=False, use_qformer_dim=0):
        super().__init__()
        self.output_dim = output_dim
        import timm
        # paths = {
        #     'eva02_large_patch14_448.mim_in22k_ft_in1k': MODEL_PATHS.EVA02_LARGE_448_MIM_IN22K,
        #     'eva02_base_patch14_224.mim_in22k': MODEL_PATHS.EVA02_BASE_224_MIM_IN22K,
        # }
        if base_model == 'eva02_base_patch14_224.mim_in22k':
            self.img_seq = 257
        elif base_model == 'eva02_large_patch14_448.mim_in22k_ft_in1k':
            self.img_seq = 1025
        elif base_model == 'siglip2':
            self.img_seq = 1024
        else:
            raise ValueError(
                f"unknown base_model {base_model}, supported: "
                "eva02_base_patch14_224.mim_in22k, eva02_large_patch14_448.mim_in22k_ft_in1k, siglip2")
        # self.base_model = timm.create_model(base_model, pretrained=True, pretrained_cfg_overlay={'file': paths[base_model], 'custom_load': False})
        self.base_model = timm.create_model(base_model, pretrained=False)
        # Keep only the first `layer_num` transformer blocks and drop the classification head.
        del self.base_model.norm, self.base_model.fc_norm, self.base_model.head, self.base_model.head_drop
        del self.base_model.blocks[layer_num:]
        dim_mult = 3 if use_pyramid else 1
        image_output_dim = self.base_model.num_features * dim_mult
        self.seq_len = seq_len
        self.device = device
        self.use_pe = use_pe
        self.use_pyramid = use_pyramid
        self.use_global_feature = use_global_feature
        self.use_qformer = use_qformer_dim > 0
        if self.use_pe:
            # Constant slot encoding: each of the `seq_len` image slots gets a small offset
            # on an interleaved subset of channels so downstream layers can tell slots apart.
            self.pe = torch.zeros([1, self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
            for i in range(self.seq_len):
                self.pe[:, i * self.img_seq: (i + 1) * self.img_seq, i::self.seq_len] = 0.05
        if self.use_qformer:
            logger.info("image projection uses qformer ...")
            self.qformer = ModifiedQFormer(
                input_dim=image_output_dim,
                hidden_dim=use_qformer_dim,
                num_heads=12,
                num_layers=6,
                output_dim=output_dim,
                num_queries=512,
                use_self_attention=False,
            ).cuda()
        else:
            self.project = nn.Linear(image_output_dim, output_dim)
        self.final_norm = nn.LayerNorm(output_dim)

    def apply_feature_pyramid(self, original_tokens, original_grid_size=32, downsample=(1, 4, 32)):
        """Concatenate several token sets along the channel dim, each average-pooled by its factor in `downsample`."""
        B, seq_len, D = original_tokens[0].shape
        H = W = original_grid_size
        token_lst = []
        for i, tokens in enumerate(original_tokens):
            downsample_size = downsample[i]
            if downsample_size == 0:
                pass  # factor 0 means: skip this token set entirely
            elif downsample_size == 1:
                token_lst.append(tokens)
            else:
                head, tokens = torch.split(tokens, [1, seq_len - 1], dim=1)     # split off the class token
                tokens_2d = tokens.view(B, H, W, D).permute(0, 3, 1, 2)         # reshape patch tokens to a 2D grid (B, D, H, W)
                pooled = F.avg_pool2d(tokens_2d, kernel_size=downsample_size, stride=downsample_size)  # (B, D, H//ds, W//ds)
                up = F.interpolate(pooled, size=(H, W), mode='nearest')         # back to (B, D, H, W)
                up = up.permute(0, 2, 3, 1).reshape(B, seq_len - 1, D)
                up = torch.cat([head, up], dim=1)
                token_lst.append(up / downsample_size ** 0.5)
        combined_tokens = torch.cat(token_lst, dim=2)
        return combined_tokens

    def apply_global_feature(self, original_tokens, original_grid_size=32, pool_size=4):
        """Average-pool the patch-token grid; assumes `original_tokens` carries no class token."""
        B, seq_len, D = original_tokens.shape
        H = W = original_grid_size
        tokens_2d = original_tokens.view(B, H, W, D).permute(0, 3, 1, 2)
        pooled = F.avg_pool2d(tokens_2d, kernel_size=pool_size, stride=pool_size)  # (B, D, H//pool, W//pool)
        pooled = pooled.permute((0, 2, 3, 1)).reshape((B, seq_len // pool_size // pool_size, D))
        return pooled

    def forward(self, image_list):
        # `image_list` holds one image batch per sample, each of shape (n_i, 3, H, W).
        splits = [len(lst) for lst in image_list]
        if sum(splits) == 0:
            return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
        x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
        x = self.base_model.patch_embed(x)
        x, rot_pos_embed = self.base_model._pos_embed(x)
        intermediates = []
        for i, blk in enumerate(self.base_model.blocks):
            x = blk(x, rope=rot_pos_embed)
            if i in [11]:
                # mid-depth feature for the pyramid (only reached when layer_num > 11)
                intermediates.append(x)
        intermediates.append(x)
        if self.use_pyramid:
            x = self.apply_feature_pyramid(intermediates + [x])
        elif self.use_global_feature:
            x = self.apply_global_feature(x)
        if self.use_qformer:
            x = self.qformer(x)
        else:
            x = self.project(x)
        x = self.final_norm(x)
        b, seq_len, c = x.shape
        # Re-group tokens per sample and zero-pad every sample to `seq_len` image slots.
        split_patches = torch.split(x, splits, dim=0)
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
        x = torch.stack(split_patches, dim=0)
        x = x.reshape((len(splits), self.seq_len * seq_len, c))
        if self.use_pe:
            x = x + self.pe
        return x


class ImageEncoderWithSiglip(nn.Module):
    """Truncated SigLIP vision tower that projects reference-image tokens to `output_dim`."""

    def __init__(self, output_dim, base_model="siglip2-so400m-patch16-512", layer_num=6, seq_len=3, device='cpu', use_pe=False):
        super().__init__()
        self.output_dim = output_dim
        ckpt = {
            'siglip-so400m-patch14-384': MODEL_PATHS.SIGLIP_SO400M_384,
            'siglip2-so400m-patch16-512': MODEL_PATHS.SIGLIP2_SO400M_512,
        }[base_model]
        image_classifier = pipeline(model=ckpt, task="zero-shot-image-classification", device='cpu')
        logger.info(f"using {layer_num} / {len(image_classifier.model.vision_model.encoder.layers)} layers of {base_model} ... ")
        del image_classifier.model.vision_model.encoder.layers[layer_num:]
        num_features = image_classifier.model.vision_model.post_layernorm.normalized_shape[0]
        self.base_model = image_classifier.model.vision_model
        self.img_seq = self.base_model.embeddings.num_patches  # SigLIP has no class token
        self.project = nn.Linear(num_features, output_dim)
        self.final_norm = nn.LayerNorm(output_dim)
        self.seq_len = seq_len
        self.device = device
        self.use_pe = use_pe
        if self.use_pe:
            # Same constant slot encoding as ImageEncoder.
            self.pe = torch.zeros([1, self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
            for i in range(self.seq_len):
                self.pe[:, i * self.img_seq: (i + 1) * self.img_seq, i::self.seq_len] = 0.05

    def forward(self, image_list):
        splits = [len(lst) for lst in image_list]
        if sum(splits) == 0:
            return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
        x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
        x = self.base_model(x).last_hidden_state
        x = self.project(x)
        x = self.final_norm(x)
        b, seq_len, c = x.shape
        split_patches = torch.split(x, splits, dim=0)
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
        x = torch.stack(split_patches, dim=0)
        x = x.reshape((len(splits), self.seq_len * seq_len, c))
        if self.use_pe:
            x = x + self.pe
        return x
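

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original module).
# It assumes `timm` is installed and that the eva02 architecture can be built
# offline with pretrained=False (no weights downloaded). Since forward() casts
# inputs to bfloat16, the module is converted to bfloat16 to match. Shapes
# follow the defaults above: 224x224 inputs -> 257 tokens per image, and each
# sample is padded to seq_len=3 reference-image slots.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = ImageEncoder(output_dim=768, layer_num=6, seq_len=3, device='cpu')
    encoder = encoder.to(dtype=torch.bfloat16)

    # Batch of two samples: the first has two reference images, the second has one.
    image_list = [
        torch.randn(2, 3, 224, 224),
        torch.randn(1, 3, 224, 224),
    ]
    tokens = encoder(image_list)
    print(tokens.shape)  # expected: (2, 3 * 257, 768)

    # With no reference images at all, the encoder returns all-zero tokens.
    empty = encoder([torch.zeros(0, 3, 224, 224), torch.zeros(0, 3, 224, 224)])
    print(empty.shape)  # expected: (2, 3 * 257, 768)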