# from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from .config import InternVideo2Config as config
import warnings
import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from transformers.utils import logging

warnings.filterwarnings("ignore")

from .internvideo2_clip_vision import InternVideo2
from .mobile_clip import TextTransformer, ClipTokenizer

logger = logging.get_logger(__name__)


class InternVideo2_CLIP_small(PreTrainedModel):
    config_class = config

    def __init__(self, config, tokenizer=None, is_pretrain=True):
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        print(config)
        if tokenizer is None:
            self.tokenizer = ClipTokenizer(self.config.model.text_encoder)

        # self.model = IV2S2(self.config).to('cpu').to(torch.float16)
        self.vision_encoder = self.build_vision_encoder()
        # Project the vision CLIP embedding into the shared alignment space.
        self.vision_align = nn.Sequential(
            nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
            nn.Linear(
                self.config.model.vision_encoder.clip_embed_dim,
                self.config.model.vision_encoder.align_dim,
            ),
        )
        self.text_encoder = self.build_text_encoder(
            cfg=self.config.model.text_encoder['text_cfg'],
            projection_dim=self.config.model.text_encoder["embed_dim"],
        )

        # Learnable contrastive temperature; adopt 1 / 100. as in ViCLIP.
        self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min

        # Optionally freeze the vision tower, keeping the CLIP projector trainable.
        if self.config.model.freeze_vision:
            for name, p in self.vision_encoder.named_parameters():
                if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False

        # Optionally freeze the text tower, keeping the projection layer trainable.
        if self.config.model.freeze_text:
            for name, p in self.text_encoder.named_parameters():
                if self.config.model.open_text_projection and name.startswith('projection_layer'):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False

        # Preprocessing for raw uint8 frames: resize, scale to [0, 1], normalize.
        img_size = self.config.model.vision_encoder.img_size
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Clamp the contrastive temperature; seems only used during pre-training."""
        self.temp.clamp_(min=self.temp_min)

    def encode_vision(self, image, test=False):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images or videos. Shape: [B,T,C,H,W].
            test (bool): Whether testing.

        Returns:
            vision_embeds (torch.Tensor): The aligned visual features. Shape: [B,C].
        """
        T = image.shape[1]
        use_image = True if T == 1 else False
        image = image.permute(0, 2, 1, 3, 4)  # [B,T,C,H,W] -> [B,C,T,H,W]
        vision_embeds = self.vision_encoder(image, use_image=use_image)
        vision_embeds = self.vision_align(vision_embeds)
        return vision_embeds

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of huggingface's `PreTrainedTokenizer`. Contains keys:
                - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicating padded tokens. Shape: [B,L]. 0 is padded token.
                - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".

        Returns:
            text_embeds (torch.Tensor): The projected text features. Shape: [B,C].
""" text_embeds = self.text_encoder(text) return text_embeds def build_vision_encoder(self): """build vision encoder Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`. """ vision_encoder = InternVideo2( in_chans=self.config.model.vision_encoder.in_chans, patch_size=self.config.model.vision_encoder.patch_size, img_size=self.config.model.vision_encoder.img_size, qkv_bias=self.config.model.vision_encoder.qkv_bias, drop_path_rate=self.config.model.vision_encoder.drop_path_rate, head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate, embed_dim=self.config.model.vision_encoder.embed_dim, num_heads=self.config.model.vision_encoder.num_heads, mlp_ratio=self.config.model.vision_encoder.mlp_ratio, init_values=self.config.model.vision_encoder.init_values, qk_normalization=self.config.model.vision_encoder.qk_normalization, depth=self.config.model.vision_encoder.depth, use_flash_attn=self.config.model.vision_encoder.use_flash_attn, use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm, use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp, fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic, attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads, clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim, layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32, num_frames=self.config.model.vision_encoder.num_frames, tubelet_size=self.config.model.vision_encoder.tubelet_size, sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed, use_checkpoint=self.config.model.vision_encoder.use_checkpoint, checkpoint_num=self.config.model.vision_encoder.checkpoint_num, ) return vision_encoder def build_text_encoder(self, cfg, projection_dim): """build text_encoder and possiblly video-to-text multimodal fusion encoder. Returns: nn.Module. The text encoder """ text_encoder = TextTransformer(cfg, projection_dim) return text_encoder if __name__ == "__main__": model_config = config() model = InternVideo2Stage2VideoEncoder(model_config) x = torch.randn(2, 3, 8, 224, 224, dtype=torch.float16).to(model_config.device) output = model(x)