from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig

from .config import InternVideo2Config as config

import warnings

import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from transformers.utils import logging

warnings.filterwarnings("ignore")

from .internvideo2_clip_vision import InternVideo2
from .mobile_clip import TextTransformer, ClipTokenizer

logger = logging.get_logger(__name__)


class InternVideo2_CLIP_small(PreTrainedModel):
    config_class = config

    def __init__(self, config, tokenizer=None, is_pretrain=True):
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        print(config)
        if tokenizer is None:
            self.tokenizer = ClipTokenizer(self.config.model.text_encoder)

        self.vision_encoder = self.build_vision_encoder()

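        # Project the pooled visual feature from the CLIP embedding size (clip_embed_dim)
        # into the dimension used for video-text alignment (align_dim).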
        self.vision_align = nn.Sequential(
            nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
            nn.Linear(
                self.config.model.vision_encoder.clip_embed_dim,
                self.config.model.vision_encoder.align_dim
            ),
        )
        self.text_encoder = self.build_text_encoder(
            cfg=self.config.model.text_encoder['text_cfg'],
            projection_dim=self.config.model.text_encoder["embed_dim"],
        )

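        # Learnable temperature for the contrastive loss, clamped to `temp_min`
        # by `clip_contrastive_temperature()` during pre-training.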
        self.temp = nn.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min

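        # Optionally freeze the vision / text backbones; the projector heads stay trainable
        # when `open_vision_clip_projector` / `open_text_projection` are enabled.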
        if self.config.model.freeze_vision:
            for name, p in self.vision_encoder.named_parameters():
                if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False
        if self.config.model.freeze_text:
            for name, p in self.text_encoder.named_parameters():
                if self.config.model.open_text_projection and name.startswith('projection_layer'):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False

        img_size = self.config.model.vision_encoder.img_size
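        # Frame preprocessing: bicubic resize to img_size, rescale 0-255 inputs to [0, 1],
        # then normalize with the ImageNet mean / std.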
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Clamp the contrastive temperature to `temp_min`; seems only used during pre-training."""
        self.temp.clamp_(min=self.temp_min)

    def encode_vision(self, image, test=False):
        """Encode images / videos into aligned features.

        Args:
            image (torch.Tensor): The input images / video frames. Shape: [B,T,C,H,W].
            test (bool): Whether in testing mode.

        Returns:
            vision_embeds (torch.Tensor): The aligned visual features. Shape: [B,C].
        """
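        # The vision encoder expects [B,C,T,H,W]; inputs arrive as [B,T,C,H,W],
        # hence the permute below. A single frame (T == 1) is treated as an image.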
        T = image.shape[1]
        use_image = (T == 1)
        image = image.permute(0, 2, 1, 3, 4)

        vision_embeds = self.vision_encoder(image, use_image=use_image)
        vision_embeds = self.vision_align(vision_embeds)
        return vision_embeds

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of huggingface's `PreTrainedTokenizer`. Contains keys:
                - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicating padded tokens. Shape: [B,L]. 0 is padded token.
                - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".

        Returns:
            text_embeds (torch.Tensor): The projected text embeddings. Shape: [B,C].
        """
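        # `text` is passed straight to the mobile_clip TextTransformer, which returns
        # the projected text embedding.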
        text_embeds = self.text_encoder(text)
        return text_embeds

    def build_vision_encoder(self):
        """Build the vision encoder.

        Returns: nn.Module. The InternVideo2 vision encoder.
        """
        vision_encoder = InternVideo2(
            in_chans=self.config.model.vision_encoder.in_chans,
            patch_size=self.config.model.vision_encoder.patch_size,
            img_size=self.config.model.vision_encoder.img_size,
            qkv_bias=self.config.model.vision_encoder.qkv_bias,
            drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
            head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
            embed_dim=self.config.model.vision_encoder.embed_dim,
            num_heads=self.config.model.vision_encoder.num_heads,
            mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
            init_values=self.config.model.vision_encoder.init_values,
            qk_normalization=self.config.model.vision_encoder.qk_normalization,
            depth=self.config.model.vision_encoder.depth,
            use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
            use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
            use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
            fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
            attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
            clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
            layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
            num_frames=self.config.model.vision_encoder.num_frames,
            tubelet_size=self.config.model.vision_encoder.tubelet_size,
            sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
            use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
            checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
        )
        return vision_encoder

    def build_text_encoder(self, cfg, projection_dim):
        """Build the text encoder.

        Returns: nn.Module. The text encoder.
        """
        text_encoder = TextTransformer(cfg, projection_dim)
        return text_encoder


if __name__ == "__main__": |
|
model_config = config() |
|
model = InternVideo2Stage2VideoEncoder(model_config) |
|
x = torch.randn(2, 3, 8, 224, 224, dtype=torch.float16).to(model_config.device) |
|
output = model(x) |
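    print(output.shape)  # expected: [2, align_dim]

    # Hedged sketch (assumption, not from the original file): the encode_text docstring
    # implies a HuggingFace-style tokenizer call; the actual ClipTokenizer interface may differ.
    # text = model.tokenizer(["a person riding a bicycle"], return_tensors="pt")
    # text_embeds = model.encode_text(text)
    # sim = torch.nn.functional.cosine_similarity(output, text_embeds)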