InternVideo2_CLIP_S / modeling_internvideo2encoder.py
ynhe's picture
[Init] upload model
ff495b4 verified
# from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from .config import InternVideo2Config as config
import warnings
import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from transformers.utils import logging
warnings.filterwarnings("ignore")
from .internvideo2_clip_vision import InternVideo2
from .mobile_clip import TextTransformer, ClipTokenizer
logger = logging.get_logger(__name__)
class InternVideo2_CLIP_small(PreTrainedModel):
    """Small InternVideo2 CLIP-style dual encoder (vision + text).

    The vision tower is an ``InternVideo2`` encoder followed by a
    LayerNorm + Linear alignment head (``clip_embed_dim`` -> ``align_dim``).
    The text tower is a mobile_clip ``TextTransformer``. Both are configured
    from ``config.model``, and either can be frozen via the
    ``freeze_vision`` / ``freeze_text`` config flags.
    """

    config_class = config

    def __init__(self, config, tokenizer=None, is_pretrain=True):
        """Build both encoders, the alignment head, the temperature and the
        frame preprocessing transform.

        Args:
            config: Model configuration. Sub-settings are read from
                ``config.model`` (``vision_encoder``, ``text_encoder``,
                ``temp``, ``temp_min`` and the freeze/open flags).
            tokenizer: Optional tokenizer. When ``None``, a ``ClipTokenizer``
                is built from ``config.model.text_encoder``.
            is_pretrain (bool): Stored as ``self.is_pretrain``.
        """
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        # Report the configuration through the library logger rather than
        # printing to stdout on every instantiation.
        logger.info("%s", config)
        if tokenizer is None:
            self.tokenizer = ClipTokenizer(self.config.model.text_encoder)

        vision_cfg = self.config.model.vision_encoder
        self.vision_encoder = self.build_vision_encoder()
        # Map vision features (clip_embed_dim) into the shared embedding
        # space (align_dim).
        self.vision_align = nn.Sequential(
            nn.LayerNorm(vision_cfg.clip_embed_dim),
            nn.Linear(vision_cfg.clip_embed_dim, vision_cfg.align_dim),
        )

        text_cfg = self.config.model.text_encoder
        self.text_encoder = self.build_text_encoder(
            cfg=text_cfg["text_cfg"],
            projection_dim=text_cfg["embed_dim"],
        )

        # Learnable contrastive temperature, initialised from config.model.temp
        # (original note: adopt 1 / 100. as in ViCLIP).
        self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min

        if self.config.model.freeze_vision:
            self._freeze_except_prefix(
                self.vision_encoder,
                "clip_projector"
                if self.config.model.open_vision_clip_projector
                else None,
            )
        if self.config.model.freeze_text:
            self._freeze_except_prefix(
                self.text_encoder,
                "projection_layer"
                if self.config.model.open_text_projection
                else None,
            )

        img_size = vision_cfg.img_size
        # Frame preprocessing:
        # 1. Bicubic resize to a square of side img_size.
        # 2. Divide by 255 (assumes 0-255 pixel values — TODO confirm with the
        #    caller).
        # 3. Normalize with the ImageNet mean/std constants below.
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize(
                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
                ),
            ]
        )

    @staticmethod
    def _freeze_except_prefix(module, open_prefix=None):
        """Set ``requires_grad = False`` on every parameter of ``module``.

        Parameters whose name starts with ``open_prefix`` (when one is given)
        are left trainable. Each decision is logged.
        """
        for name, param in module.named_parameters():
            if open_prefix is not None and name.startswith(open_prefix):
                logger.info(f"Unfreeze {name}")
            else:
                logger.info(f"Freeze {name}")
                param.requires_grad = False

    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Clamp the learnable temperature in place to at least ``temp_min``.

        Seems only used during pre-training.
        """
        self.temp.clamp_(min=self.temp_min)

    def encode_vision(self, image, test=False):
        """Encode images / videos into aligned embeddings.

        Args:
            image (torch.Tensor): Input laid out as [B,T,C,H,W]. It is
                permuted to [B,C,T,H,W] before the vision encoder. ``T == 1``
                is treated as a single image (``use_image=True``).
            test (bool): Unused; kept for interface compatibility.

        Returns:
            torch.Tensor: ``vision_align`` applied to the vision encoder
            output (presumably [B, align_dim] — confirm against InternVideo2).
        """
        use_image = image.shape[1] == 1
        image = image.permute(0, 2, 1, 3, 4)  # [B,T,C,H,W] -> [B,C,T,H,W]
        vision_embeds = self.vision_encoder(image, use_image=use_image)
        return self.vision_align(vision_embeds)

    def encode_text(self, text):
        """Encode text with the text tower.

        Args:
            text: Tokenized text, passed unchanged to ``self.text_encoder``
                (a mobile_clip ``TextTransformer``). Presumably the token-id
                tensor produced by ``self.tokenizer`` — verify against
                mobile_clip.

        Returns:
            The ``TextTransformer`` output (presumably [B, embed_dim] text
            features — confirm).
        """
        return self.text_encoder(text)

    def build_vision_encoder(self):
        """Build the vision encoder from ``config.model.vision_encoder``.

        Returns:
            InternVideo2: The vision encoder module (a single ``nn.Module``).
        """
        vcfg = self.config.model.vision_encoder
        return InternVideo2(
            in_chans=vcfg.in_chans,
            patch_size=vcfg.patch_size,
            img_size=vcfg.img_size,
            qkv_bias=vcfg.qkv_bias,
            drop_path_rate=vcfg.drop_path_rate,
            head_drop_path_rate=vcfg.head_drop_path_rate,
            embed_dim=vcfg.embed_dim,
            num_heads=vcfg.num_heads,
            mlp_ratio=vcfg.mlp_ratio,
            init_values=vcfg.init_values,
            qk_normalization=vcfg.qk_normalization,
            depth=vcfg.depth,
            use_flash_attn=vcfg.use_flash_attn,
            use_fused_rmsnorm=vcfg.use_fused_rmsnorm,
            use_fused_mlp=vcfg.use_fused_mlp,
            fused_mlp_heuristic=vcfg.fused_mlp_heuristic,
            attn_pool_num_heads=vcfg.attn_pool_num_heads,
            clip_embed_dim=vcfg.clip_embed_dim,
            layerscale_no_force_fp32=vcfg.layerscale_no_force_fp32,
            num_frames=vcfg.num_frames,
            tubelet_size=vcfg.tubelet_size,
            sep_pos_embed=vcfg.sep_pos_embed,
            use_checkpoint=vcfg.use_checkpoint,
            checkpoint_num=vcfg.checkpoint_num,
        )

    def build_text_encoder(self, cfg, projection_dim):
        """Build the text encoder.

        Args:
            cfg: Text-transformer settings (``config.model.text_encoder['text_cfg']``).
            projection_dim: Output projection size
                (``config.model.text_encoder['embed_dim']``).

        Returns:
            TextTransformer: The text encoder module.
        """
        return TextTransformer(cfg, projection_dim)
if __name__ == "__main__":
    # Smoke test: build the model from the default config and encode one
    # random clip through the vision tower.
    model_config = config()
    model = InternVideo2_CLIP_small(model_config)
    # Fall back to CPU when the config carries no ``device`` attribute.
    # NOTE(review): if flash-attn / fused kernels are enabled in the config,
    # a CUDA device (and possibly half precision) may be required — confirm.
    device = getattr(model_config, "device", "cpu")
    model = model.to(device).eval()

    vision_cfg = model_config.model.vision_encoder
    # encode_vision expects [B, T, C, H, W]; size the input from the config
    # so that it matches the frame count, channel count and resolution the
    # encoder was built for.
    x = torch.randn(
        2,
        vision_cfg.num_frames,
        vision_cfg.in_chans,
        vision_cfg.img_size,
        vision_cfg.img_size,
        device=device,
    )
    with torch.no_grad():
        # The model defines no ``forward``; use the vision entry point.
        output = model.encode_vision(x)
    print(output.shape)