| | from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2 |
| | from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig |
| | from .config import InternVideo2Config as config |
| | import warnings |
| | import torch |
| | |
# Silence every warning category for the whole process as soon as this module
# is imported. NOTE(review): this also hides warnings raised by unrelated code
# in the importing application — confirm that the blanket filter is intended.
warnings.filterwarnings("ignore")
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """Hugging Face wrapper exposing the vision branch of InternVideo2 stage 2.

    The wrapped model is built from the supplied config, placed on CPU and
    cast to ``torch.float16``; ``forward`` returns one mean-pooled embedding
    per frame.
    """

    config_class = config

    # Clips longer than this many frames are encoded in consecutive chunks.
    _MAX_FRAMES_PER_CHUNK = 8
    # Number of leading tokens kept per frame before mean-pooling.
    # NOTE(review): presumably the patch-token count per frame of the vision
    # backbone — confirm against the encoder configuration.
    _TOKENS_PER_FRAME = 256

    def __init__(self, config):
        """Build the wrapped stage-2 model from ``config``.

        Args:
            config: Configuration instance forwarded to the stage-2 model.
        """
        super().__init__(config)
        self.config = config

        # Weights start on CPU in half precision; callers move the module to
        # their target device afterwards.
        self.model = IV2S2(self.config).to('cpu').to(torch.float16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode frames into one pooled embedding per frame.

        Args:
            x: Tensor of shape (B, N, C, H, W) for clips, or (B, C, H, W)
                for single images. Presumably it must match the wrapped
                model's dtype (float16) and device — confirm with callers.

        Returns:
            Tensor of shape (B, N, hidden_size) for 5-D input, or
            (B, hidden_size) for 4-D input.

        Raises:
            ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        rank = x.dim()
        if rank not in (4, 5):
            raise ValueError(
                "expected input of shape (B, N, C, H, W) or (B, C, H, W), "
                f"got {tuple(x.shape)}"
            )

        chunk = self._MAX_FRAMES_PER_CHUNK
        if rank == 5 and x.shape[1] > chunk:
            # Long clips are split along the frame axis, encoded chunk by
            # chunk, and the per-frame embeddings are stitched back together.
            num_frames = x.shape[1]
            pieces = [
                self.forward(x[:, start:start + chunk])
                for start in range(0, num_frames, chunk)
            ]
            return torch.cat(pieces, dim=1)

        is_image = rank == 4
        if is_image:
            # Treat a single image as a one-frame clip.
            x = x.unsqueeze(1)
        batch, num_frames = x.shape[0], x.shape[1]

        # The second element of encode_vision's result holds the pooled
        # vision embeddings used here.
        pooled_vision_embeds = self.model.encode_vision(x)[1]
        tokens = pooled_vision_embeds[:, :self._TOKENS_PER_FRAME * num_frames, :]
        # Group the tokens by frame, then average within each frame.
        per_frame = tokens.reshape(batch, num_frames, self._TOKENS_PER_FRAME, -1)
        embeddings = per_frame.mean(dim=2)

        if is_image:
            embeddings = embeddings.squeeze(1)
        return embeddings
| | |
if __name__ == "__main__":
    # Smoke test: build the encoder from a default config and run one forward
    # pass on random input.
    model_config = config()
    # The constructor leaves the wrapped model on CPU, so move the module to
    # the same device the input is sent to; otherwise weights and input can
    # end up on different devices.
    # NOTE(review): assumes model_config.device names an available device.
    model = InternVideo2Stage2VideoEncoder(model_config).to(model_config.device)
    # forward() documents clip input as (B, N, C, H, W): 2 clips of 8 RGB
    # frames at 224x224.
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)