import torch
import torch.nn.functional as F
from torch import nn
from .backbone import build_backbone
import pdb
import numpy as np
from typing import Optional


class TokenOCR(nn.Module):
    """Align pooled per-token visual features with frozen language-token embeddings.

    The backbone feature map is upsampled 4x and projected by an MLP into the
    2048-d language-embedding space. Each token's visual feature is the mean of
    the projected features under that token's mask. It is trained (cosine and
    L2 losses) to match the token's frozen embedding.
    """

    def __init__(self, backbone):
        """Initialize the model.

        Parameters:
            backbone: torch module producing the visual feature map. Its
                output is indexed with key ``'0'`` in ``forward_tokenocr``
                (see backbone.py). The upsampling head expects 2048 channels.
        """
        super().__init__()
        # Vocabulary size and width are hard-coded; presumably they match the
        # tokenizer whose embedding weights are loaded here -- TODO confirm.
        self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
        # At this point the embedding table is the only registered submodule,
        # so this freezes exactly it; modules registered below stay trainable.
        self.language_embedding.requires_grad_(False)
        self.backbone = backbone
        # Not read anywhere in this file; kept so external code that reads
        # the attribute keeps working.
        self.kb = True
        # Each ConvTranspose2d(k=4, s=2, p=1) exactly doubles H and W, so the
        # feature map is upsampled 4x overall (see forward_tokenocr).
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=2048,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
            nn.ConvTranspose2d(
                in_channels=512,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
        )
        self.ocr_mlp = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
        )

    def forward(self,
                pixel_values: torch.FloatTensor,
                input_ids: torch.LongTensor = None,
                image_flags: Optional[torch.LongTensor] = None,
                mask_values: Optional[torch.LongTensor] = None,
                masks_flags: Optional[torch.LongTensor] = None,
                mask_nums: Optional[torch.LongTensor] = None,
                ):
        """Compute the token/visual alignment losses.

        Parameters:
            pixel_values: image crops for the whole batch, stacked along dim 0.
            input_ids: (B, N) token ids, one row per sample.
            image_flags: number of crops belonging to each sample. Flattened,
                it must sum to the number of crops in ``pixel_values`` and
                ``mask_values``.
            mask_values: (num_crops, N, H, W) masks, one channel per token,
                aligned with ``pixel_values`` along dim 0.
            masks_flags: accepted for interface compatibility; not used.
            mask_nums: only reported in the error message on a mask-count
                mismatch.

        Returns:
            (loss1, loss2): the mean cosine distance and the mean Euclidean
            distance between each token's pooled visual feature and its
            language embedding. Both are averaged over tokens whose mask is
            non-empty. If no token has a non-empty mask, both are a zero
            tensor still attached to the autograd graph.

        Raises:
            ValueError: if a required input is missing, or the number of masks
                does not equal the total number of crops in ``image_flags``.
        """
        if input_ids is None or image_flags is None or mask_values is None:
            raise ValueError(
                "TokenOCR.forward requires input_ids, image_flags and mask_values"
            )
        split_sizes = image_flags.reshape(-1).tolist()
        if sum(split_sizes) != mask_values.shape[0]:
            raise ValueError(
                f"mask count {tuple(mask_values.shape)} does not match "
                f"image_flags {split_sizes} (mask_nums={mask_nums})"
            )

        # The embedding table is frozen, so these are fixed regression targets.
        input_embeds = self.language_embedding(input_ids)  # (B, N, C)

        vit_embeds, (h, w) = self.extract_feature_custom(pixel_values)
        nb, _, nd = vit_embeds.shape
        per_sample_feats = vit_embeds.reshape(nb, h, w, nd).split(split_sizes)

        # Resize the masks to the (upsampled) feature-map resolution.
        mask_values = F.interpolate(
            mask_values.float(), size=(h, w), mode='bilinear', align_corners=False
        )
        per_sample_masks = mask_values.split(split_sizes)

        pooled, targets, valid_flags = [], [], []
        for i, (feats, masks) in enumerate(zip(per_sample_feats, per_sample_masks)):
            num_tokens = masks.shape[1]
            # (crops, tokens, h, w) -> (tokens, crops*h*w), matching the
            # flattening order of feats.reshape(-1, C) below.
            token_masks = masks.permute(1, 0, 2, 3).reshape(num_tokens, -1) > 0
            valid = token_masks.any(-1)
            # Empty masks would divide 0 by 0; give them uniform weights
            # instead. Those rows are dropped from both losses via `valid`.
            token_masks[~valid] = True
            weights = token_masks / token_masks.sum(-1, keepdim=True)
            pooled.append(weights.to(feats) @ feats.reshape(-1, feats.shape[-1]))
            targets.append(input_embeds[i])
            valid_flags.append(valid)

        token_features = torch.cat(pooled)
        input_embedings = torch.cat(targets)
        valid = torch.cat(valid_flags)

        if not valid.any():
            # mean() over zero rows would be NaN; return a graph-attached zero.
            zero = token_features.sum() * 0.0
            return zero, zero

        # Select valid rows *before* any normalization. The padding-idx
        # embedding row has zero norm, and keeping such rows in the graph can
        # inject NaN gradients even when they are masked out afterwards.
        token_features = token_features[valid]
        input_embedings = input_embedings[valid]

        # Per-token Euclidean distance, averaged over valid tokens.
        loss2 = F.mse_loss(
            token_features, input_embedings, reduction='none'
        ).sum(1).sqrt().mean()
        # cosine_similarity normalizes internally (eps-guarded), so explicit
        # L2 normalization beforehand is unnecessary.
        similarity = F.cosine_similarity(token_features, input_embedings, dim=1)
        loss1 = (1 - similarity).mean()
        return loss1, loss2

    def extract_feature_custom(self, pixel_values):
        """Feature-extraction hook called by ``forward``; delegates to ``forward_tokenocr``."""
        return self.forward_tokenocr(pixel_values)

    def forward_tokenocr(self, pixel_values):
        """Return per-location visual features and the upsampled feature-map size.

        Returns:
            (features, (h, w)): ``features`` has shape (batch, h*w, 2048).
            ``(h, w)`` is the spatial size after the 4x upsampling head.
        """
        feature_map = self.backbone(pixel_values)['0']
        h, w = feature_map.shape[2], feature_map.shape[3]
        feature_map = self.upsample(feature_map)
        # (B, 512, H, W) -> (B, H*W, 512) so the MLP acts on each location.
        tokens = feature_map.reshape(
            feature_map.shape[0], -1, feature_map.shape[-2] * feature_map.shape[-1]
        ).permute(0, 2, 1)
        return self.ocr_mlp(tokens), (h * 4, w * 4)


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        # ReLU after every layer except the last.
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build(args):
    """Build a TokenOCR model around the backbone described by ``args``."""
    backbone = build_backbone(args)
    model = TokenOCR(backbone)
    return model