import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

from .vision_transformer_misc import ConvNormActivation
from .vision_transformer_utils import _log_api_usage_once

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}


class ConvStemConfig(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU
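

# A minimal usage sketch, not part of the original module: conv_stem_configs lets
# VisionTransformer replace the single patchify convolution with a stack of
# ConvNormActivation blocks. The helper name and the channel/stride schedule below
# are illustrative assumptions; the one hard requirement is that the combined
# stride of the stem equals patch_size, so that the stem output has
# (image_size // patch_size) ** 2 spatial positions. VisionTransformer is defined
# later in this module and is resolved at call time.
def _example_conv_stem_vit():
    stem = [
        ConvStemConfig(out_channels=64, kernel_size=3, stride=2),
        ConvStemConfig(out_channels=128, kernel_size=3, stride=2),
        ConvStemConfig(out_channels=256, kernel_size=3, stride=2),
        ConvStemConfig(out_channels=512, kernel_size=3, stride=2),
    ]  # total stride 2 ** 4 = 16 == patch_size
    return VisionTransformer(
        image_size=224,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        conv_stem_configs=stem,
    )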


class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # batch_first=True is used in the EncoderBlock attention, so batch_size
        # is on the first dim of the positional embedding as well.
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # Build a convolutional stem in place of the single patchify convolution
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    ConvNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, "Wrong image height!")
        torch._assert(w == self.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size,
        # and E is the embedding dimension.
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        out = {}

        # Reshape and permute the input tensor into a sequence of patch embeddings
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Patch tokens reshaped back into a 2D feature map: (n, hidden_dim, H, W)
        img_feature = x[:, 1:]
        H = W = self.image_size // self.patch_size
        out["f4"] = img_feature.view(n, H, W, self.hidden_dim).permute(0, 3, 1, 2)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        out["penultimate"] = x

        x = self.heads(x)
        out["logits"] = x

        return out
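

# A minimal sketch, not part of the original API, of what forward() returns: a dict
# with the patch-token feature map ("f4"), the class-token embedding ("penultimate"),
# and the classification output ("logits"). The helper name is an illustrative
# assumption; the shapes noted below follow from the ViT-B/16 configuration at
# 224x224 input.
def _example_vit_outputs():
    model = VisionTransformer(
        image_size=224, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072
    )
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))
    # out["f4"]: (2, 768, 14, 14), out["penultimate"]: (2, 768), out["logits"]: (2, 1000)
    return {k: tuple(v.shape) for k, v in out.items()}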


def _vision_transformer(
    arch: str,
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if pretrained:
        if arch not in model_urls:
            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)

    return model


def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_b_16",
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_b_32",
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_l_16",
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_l_32",
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model to images with a different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classification heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the position embeddings to a 2d grid, performing
    # an interpolation in the 2d space, and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it off.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
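

# A minimal usage sketch, not an official recipe: adapt the 224x224 ViT-B/16
# checkpoint from model_urls to a model built for 384x384 inputs. The helper name
# is an illustrative assumption, and it assumes the downloaded file is a plain
# state dict with the key layout produced by this module (e.g. "encoder.pos_embedding").
def _example_interpolate_checkpoint():
    model = vit_b_16(pretrained=False, image_size=384)
    state_dict = load_state_dict_from_url(model_urls["vit_b_16"], progress=True)
    state_dict = interpolate_embeddings(image_size=384, patch_size=16, model_state=state_dict)
    model.load_state_dict(state_dict)
    return model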