| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import math |
| |
|
| | from einops import rearrange |
| | from pdb import set_trace as st |
| |
|
| | |
| |
|
| | from .dit_models_xformers import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer |
| | |
| |
|
| |
|
def modulate2(x, shift, scale):
    """Apply an adaptive affine (FiLM-style) modulation to `x`.

    Unlike the broadcast-based `modulate` helper imported above, the
    `shift`/`scale` tensors are used as-is, so they must already be
    broadcastable against `x` (per-token conditioning).
    """
    scaled = x * (scale + 1)
    return scaled + shift
| |
|
| |
|
class DiTBlock2(DiTBlock):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Variant of the base DiTBlock whose modulation tensors are applied as-is
    via `modulate2` (per-token conditioning) rather than broadcast across
    the sequence.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
        super().__init__(hidden_size, num_heads, mlp_ratio, **block_kwargs)

    def forward(self, x, c):
        # One projection of the conditioning yields all six modulation tensors.
        params = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params

        # Gated residual attention branch.
        attn_out = self.attn(modulate2(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa * attn_out

        # Gated residual MLP branch.
        mlp_out = self.mlp(modulate2(self.norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp * mlp_out
        return x
| |
|
| |
|
class FinalLayer2(FinalLayer):
    """
    The final layer of DiT, basically the decoder_pred in MAE with adaLN.

    Same as the base FinalLayer except the shift/scale pair is applied
    per-token via `modulate2` (no sequence broadcasting).
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__(hidden_size, patch_size, out_channels)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate2(self.norm_final(x), shift, scale)
        return self.linear(modulated)
| |
|
| |
|
class DiT2(DiT):
    """DiT variant driven purely by a conditioning sequence.

    The patch embedder, timestep embedder, and final projection of the base
    DiT are deleted after construction; the learned positional embedding is
    used as the initial token sequence, and every DiTBlock2 modulates it
    per-token from the conditioning input `c`.
    """

    def __init__(self,
                 input_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 mlp_ratio=4,
                 class_dropout_prob=0.1,
                 num_classes=1000,
                 learn_sigma=True,
                 mixing_logit_init=-3,
                 mixed_prediction=True,
                 context_dim=False,
                 roll_out=False,
                 plane_n=3,
                 return_all_layers=False,
                 vit_blk=...):  # accepted for signature compatibility but ignored: DiTBlock2 is always used
        super().__init__(input_size,
                         patch_size,
                         in_channels,
                         hidden_size,
                         depth,
                         num_heads,
                         mlp_ratio,
                         class_dropout_prob,
                         num_classes,
                         learn_sigma,
                         mixing_logit_init,
                         mixed_prediction,
                         context_dim,
                         roll_out,
                         vit_blk=DiTBlock2,
                         final_layer_blk=FinalLayer2)

        # These sub-modules built by the base class are unused in this
        # conditioning-only variant; drop them to free their parameters.
        del self.x_embedder
        del self.t_embedder
        del self.final_layer
        torch.cuda.empty_cache()
        self.clip_text_proj = None
        # Number of planes the token sequence is split into when roll_out
        # alternates between per-plane and cross-plane blocks (see forward).
        self.plane_n = plane_n
        self.return_all_layers = return_all_layers

    def forward(self, c, *args, **kwargs):
        """
        Forward pass of DiT.
        c: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        """
        # NOTE(review): despite the docstring above, `c` is consumed below as a
        # token sequence `b (n l) c` (it is rearranged and fed to the blocks as
        # conditioning) — confirm the actual layout against callers.
        # The positional embedding itself is the initial token sequence,
        # replicated once per batch element.
        x = self.pos_embed.repeat(
            c.shape[0], 1, 1)

        if self.return_all_layers:
            all_layers = []

        for blk_idx, block in enumerate(self.blocks):
            if self.roll_out:
                if blk_idx % 2 == 0:
                    # Even blocks: fold the plane axis into the batch so
                    # attention runs independently within each plane.
                    x = rearrange(x, 'b (n l) c -> (b n) l c ', n=self.plane_n)
                    x = block(x,
                              rearrange(c,
                                        'b (n l) c -> (b n) l c ',
                                        n=self.plane_n))

                    if self.return_all_layers:
                        all_layers.append(x)
                else:
                    # Odd blocks: concatenate the planes back into one long
                    # sequence so attention mixes information across planes.
                    x = rearrange(x, '(b n) l c -> b (n l) c ', n=self.plane_n)
                    x = block(x, c)

                    if self.return_all_layers:
                        # Saved activations use a uniform (b n) l c layout
                        # regardless of which branch produced them.
                        all_layers.append(
                            rearrange(x,
                                      'b (n l) c -> (b n) l c',
                                      n=self.plane_n))
            else:
                x = block(x, c)

        if self.return_all_layers:
            return all_layers
        else:
            return x
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| |
|
def DiT2_XL_2(**kwargs):
    """DiT2-XL/2 config: depth 28, hidden 1152, patch 2, 16 heads."""
    config = dict(depth=28, hidden_size=1152, patch_size=2, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_XL_2_half(**kwargs):
    """Half-depth DiT2-XL/2 config: depth 14, hidden 1152, patch 2, 16 heads."""
    config = dict(depth=28 // 2, hidden_size=1152, patch_size=2, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_XL_4(**kwargs):
    """DiT2-XL/4 config: depth 28, hidden 1152, patch 4, 16 heads."""
    config = dict(depth=28, hidden_size=1152, patch_size=4, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_XL_8(**kwargs):
    """DiT2-XL/8 config: depth 28, hidden 1152, patch 8, 16 heads."""
    config = dict(depth=28, hidden_size=1152, patch_size=8, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_L_2(**kwargs):
    """DiT2-L/2 config: depth 24, hidden 1024, patch 2, 16 heads."""
    config = dict(depth=24, hidden_size=1024, patch_size=2, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_L_2_half(**kwargs):
    """Half-depth DiT2-L/2 config: depth 12, hidden 1024, patch 2, 16 heads."""
    config = dict(depth=24 // 2, hidden_size=1024, patch_size=2, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_L_4(**kwargs):
    """DiT2-L/4 config: depth 24, hidden 1024, patch 4, 16 heads."""
    config = dict(depth=24, hidden_size=1024, patch_size=4, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_L_8(**kwargs):
    """DiT2-L/8 config: depth 24, hidden 1024, patch 8, 16 heads."""
    config = dict(depth=24, hidden_size=1024, patch_size=8, num_heads=16)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_B_2(**kwargs):
    """DiT2-B/2 config: depth 12, hidden 768, patch 2, 12 heads."""
    config = dict(depth=12, hidden_size=768, patch_size=2, num_heads=12)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_B_4(**kwargs):
    """DiT2-B/4 config: depth 12, hidden 768, patch 4, 12 heads."""
    config = dict(depth=12, hidden_size=768, patch_size=4, num_heads=12)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_B_8(**kwargs):
    """DiT2-B/8 config: depth 12, hidden 768, patch 8, 12 heads."""
    config = dict(depth=12, hidden_size=768, patch_size=8, num_heads=12)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_B_16(**kwargs):
    """DiT2-B/16 config: depth 12, hidden 768, patch 16, 12 heads."""
    config = dict(depth=12, hidden_size=768, patch_size=16, num_heads=12)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_S_2(**kwargs):
    """DiT2-S/2 config: depth 12, hidden 384, patch 2, 6 heads."""
    config = dict(depth=12, hidden_size=384, patch_size=2, num_heads=6)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_S_4(**kwargs):
    """DiT2-S/4 config: depth 12, hidden 384, patch 4, 6 heads."""
    config = dict(depth=12, hidden_size=384, patch_size=4, num_heads=6)
    return DiT2(**config, **kwargs)
| |
|
| |
|
def DiT2_S_8(**kwargs):
    """DiT2-S/8 config: depth 12, hidden 384, patch 8, 6 heads."""
    config = dict(depth=12, hidden_size=384, patch_size=8, num_heads=6)
    return DiT2(**config, **kwargs)
| |
|
| |
|
# Registry mapping model-size strings (e.g. 'DiT2-XL/2') to factory functions,
# mirroring the naming scheme of `DiT_models` in dit_models_xformers.
DiT2_models = {
    'DiT2-XL/2': DiT2_XL_2,
    'DiT2-XL/2/half': DiT2_XL_2_half,
    'DiT2-XL/4': DiT2_XL_4,
    'DiT2-XL/8': DiT2_XL_8,
    'DiT2-L/2': DiT2_L_2,
    'DiT2-L/2/half': DiT2_L_2_half,
    'DiT2-L/4': DiT2_L_4,
    'DiT2-L/8': DiT2_L_8,
    'DiT2-B/2': DiT2_B_2,
    'DiT2-B/4': DiT2_B_4,
    'DiT2-B/8': DiT2_B_8,
    'DiT2-B/16': DiT2_B_16,
    'DiT2-S/2': DiT2_S_2,
    'DiT2-S/4': DiT2_S_4,
    'DiT2-S/8': DiT2_S_8,
}
| |
|