| | |
| | from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig |
| | from .config import InternVideo2Config as config |
| | import warnings |
| | import torch |
| | from torch import nn |
| | import torchvision.transforms as transforms |
| | from torchvision.transforms import InterpolationMode |
| | from transformers.utils import logging |
| | warnings.filterwarnings("ignore") |
| | from .internvideo2_clip_vision import InternVideo2 |
| | from .mobile_clip import TextTransformer, ClipTokenizer |
| | logger = logging.get_logger(__name__) |
| |
|
class InternVideo2_CLIP_small(PreTrainedModel):
    """CLIP-style video-text dual encoder.

    Pairs an InternVideo2 vision backbone with a MobileCLIP text transformer.
    Pooled vision features are projected into the shared embedding space by
    ``vision_align`` (LayerNorm + Linear), and a learnable contrastive
    temperature ``temp`` is kept clamped above ``temp_min``.

    Args:
        config: model configuration (see ``config_class``).
        tokenizer: optional pre-built tokenizer; when ``None`` a
            ``ClipTokenizer`` is created from the text-encoder config.
        is_pretrain (bool): stored flag; not otherwise used in this class.
    """

    config_class = config

    def __init__(self, config, tokenizer=None, is_pretrain=True):
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        # Fix: leftover debug `print(config)` replaced with the module logger.
        logger.info("Model config: %s", config)
        if tokenizer is None:
            self.tokenizer = ClipTokenizer(self.config.model.text_encoder)

        self.vision_encoder = self.build_vision_encoder()

        # Project pooled CLIP features into the shared vision-text space.
        self.vision_align = nn.Sequential(
            nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
            nn.Linear(
                self.config.model.vision_encoder.clip_embed_dim,
                self.config.model.vision_encoder.align_dim,
            ),
        )
        self.text_encoder = self.build_text_encoder(
            cfg=self.config.model.text_encoder['text_cfg'],
            projection_dim=self.config.model.text_encoder["embed_dim"],
        )

        # Learnable contrastive temperature (nn.Parameter is the public alias
        # of nn.parameter.Parameter), clamped from below during training.
        self.temp = nn.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min

        # Optionally freeze the two towers, keeping their projection heads
        # trainable when the corresponding "open_*" flags are set.
        if self.config.model.freeze_vision:
            self._freeze_params(
                self.vision_encoder,
                keep_prefix='clip_projector'
                if self.config.model.open_vision_clip_projector else None,
            )
        if self.config.model.freeze_text:
            self._freeze_params(
                self.text_encoder,
                keep_prefix='projection_layer'
                if self.config.model.open_text_projection else None,
            )

        img_size = self.config.model.vision_encoder.img_size
        # uint8 frames -> resized, [0,1]-scaled, ImageNet-normalized tensors.
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    @staticmethod
    def _freeze_params(module, keep_prefix=None):
        """Freeze all parameters of ``module`` except those whose name starts
        with ``keep_prefix`` (left trainable when given).

        Factored out of ``__init__``, which previously duplicated this loop
        for the vision and text towers.
        """
        for name, p in module.named_parameters():
            if keep_prefix is not None and name.startswith(keep_prefix):
                logger.info(f"Unfreeze {name}")
            else:
                logger.info(f"Freeze {name}")
                p.requires_grad = False

    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Clamp the learnable temperature from below.

        Seems only used during pre-training.
        """
        self.temp.clamp_(min=self.temp_min)

    def encode_vision(self, image, test=False):
        """Encode images / videos as aligned features.

        Args:
            image (torch.Tensor): input frames; assumed [B, T, C, H, W] since
                dim 1 is read as the frame count — TODO confirm against callers.
            test (bool): whether testing; currently unused here.

        Returns:
            torch.Tensor: aligned vision embeddings, shape [B, align_dim].
        """
        T = image.shape[1]
        use_image = T == 1  # a single frame is treated as a still image
        # [B, T, C, H, W] -> [B, C, T, H, W] expected by the vision encoder.
        image = image.permute(0, 2, 1, 3, 4)

        vision_embeds = self.vision_encoder(image, use_image=use_image)
        vision_embeds = self.vision_align(vision_embeds)
        return vision_embeds

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): output of a huggingface ``PreTrainedTokenizer``,
                containing at least:
                - input_ids (torch.Tensor): token ids, shape [B, L].
                - attention_mask (torch.Tensor): padding mask, shape [B, L];
                  0 marks padded tokens.

        Returns:
            torch.Tensor: text embeddings, shape [B, C].
        """
        text_embeds = self.text_encoder(text)
        return text_embeds

    def build_vision_encoder(self):
        """Build the InternVideo2 vision encoder from config.

        Returns:
            nn.Module: the vision encoder. (Docstring fixed: the original
            claimed a ``(vision_encoder, vision_layernorm)`` tuple, but a
            single module is returned.)
        """
        vision_encoder = InternVideo2(
            in_chans=self.config.model.vision_encoder.in_chans,
            patch_size=self.config.model.vision_encoder.patch_size,
            img_size=self.config.model.vision_encoder.img_size,
            qkv_bias=self.config.model.vision_encoder.qkv_bias,
            drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
            head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
            embed_dim=self.config.model.vision_encoder.embed_dim,
            num_heads=self.config.model.vision_encoder.num_heads,
            mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
            init_values=self.config.model.vision_encoder.init_values,
            qk_normalization=self.config.model.vision_encoder.qk_normalization,
            depth=self.config.model.vision_encoder.depth,
            use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
            use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
            use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
            fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
            attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
            clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
            layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
            num_frames=self.config.model.vision_encoder.num_frames,
            tubelet_size=self.config.model.vision_encoder.tubelet_size,
            sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
            use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
            checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
        )
        return vision_encoder

    def build_text_encoder(self, cfg, projection_dim):
        """Build the text encoder (and possibly a video-to-text fusion encoder).

        Args:
            cfg: MobileCLIP text-transformer config dict.
            projection_dim (int): output embedding dimension.

        Returns:
            nn.Module: the text encoder.
        """
        text_encoder = TextTransformer(cfg, projection_dim)

        return text_encoder
| | |
if __name__ == "__main__":
    # Smoke test: build the model from a default config and run one forward pass.
    model_config = config()
    # Bug fix: the original instantiated `InternVideo2Stage2VideoEncoder`,
    # a name not defined in this module (NameError). The class defined here
    # is InternVideo2_CLIP_small.
    model = InternVideo2_CLIP_small(model_config)
    x = torch.randn(2, 3, 8, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)