| import torch |
| from torch import nn |
| from nsr.triplane import Triplane_fg_bg_plane |
| |
| from vit.vit_triplane import Triplane, ViTTriplaneDecomposed |
| import argparse |
| import inspect |
| import dnnlib |
| from guided_diffusion import dist_util |
|
|
| from pdb import set_trace as st |
|
|
| import vit.vision_transformer as vits |
| from guided_diffusion import logger |
| from .confnet import ConfNet |
|
|
| from ldm.modules.diffusionmodules.model import Encoder, MVEncoder, MVEncoderGS |
| from ldm.modules.diffusionmodules.mv_unet import MVUNet, LGM_MVEncoder |
|
|
| |
|
|
| |
|
|
|
|
class AE(torch.nn.Module):
    """ViT-based auto-encoder: image -> latent tokens -> triplane -> rendering.

    `encoder` is a DINO v1/v2 ViT, a CLIP visual backbone, or an SD-style conv
    encoder (selected by `dino_version` / `use_clip`); `decoder` is a
    ViT-triplane decoder. Every sub-operation is routed through `forward`
    (behaviour string) so DDP sees a single entry point.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 img_size,
                 encoder_cls_token,
                 decoder_cls_token,
                 preprocess,
                 use_clip,
                 dino_version='v1',
                 clip_dtype=None,
                 no_dim_up_mlp=False,
                 dim_up_mlp_as_func=False,
                 uvit_skip_encoder=False,
                 confnet=None) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.img_size = img_size
        self.encoder_cls_token = encoder_cls_token
        self.decoder_cls_token = decoder_cls_token
        self.use_clip = use_clip
        self.dino_version = dino_version
        self.confnet = confnet
        self.preprocess = preprocess

        # BUGFIX: these must be initialised *before* the use_clip branch
        # below. Previously `self.dim_up_mlp = None` ran after the branch and
        # unconditionally discarded the nn.Linear created there, and the log
        # line read `self.dim_up_mlp_as_func` before it was assigned
        # (AttributeError).
        self.dim_up_mlp = None  # created below only when widths mismatch
        self.dim_up_mlp_as_func = dim_up_mlp_as_func

        if self.dino_version == 'v2':
            # mask tokens are unused here; drop them so DDP finds no unused
            # parameters
            self.encoder.mask_token = None
            self.decoder.vit_decoder.mask_token = None

        if 'sd' not in self.dino_version:
            self.uvit_skip_encoder = uvit_skip_encoder
            if uvit_skip_encoder:
                logger.log(
                    f'enables uvit: length of vit_encoder.blocks: {len(self.encoder.blocks)}'
                )
                # U-ViT: second half of the encoder blocks receives long skip
                # connections from the first half (see encode_dinov2_uvit)
                for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
                    blk.skip_linear = nn.Linear(2 * self.encoder.embed_dim,
                                                self.encoder.embed_dim)
                    # zero init so training starts from the identity path
                    nn.init.constant_(blk.skip_linear.weight, 0)
                    if blk.skip_linear.bias is not None:
                        nn.init.constant_(blk.skip_linear.bias, 0)
            else:
                logger.log(f'disable uvit')
        else:
            if 'dit' not in self.dino_version:
                # SD-style conv encoder: prune ViT-decoder pieces that would
                # otherwise be unused parameters under DDP
                self.decoder.vit_decoder.cls_token = None
                self.decoder.vit_decoder.patch_embed.proj = nn.Identity()
                self.decoder.triplane_decoder.planes = None
                self.decoder.vit_decoder.mask_token = None

        if self.use_clip:
            self.clip_dtype = clip_dtype
        else:
            # bridge encoder width -> decoder width when they differ
            if not no_dim_up_mlp and self.encoder.embed_dim != self.decoder.vit_decoder.embed_dim:
                self.dim_up_mlp = nn.Linear(
                    self.encoder.embed_dim,
                    self.decoder.vit_decoder.embed_dim)
                logger.log(
                    f"dim_up_mlp: {self.encoder.embed_dim} -> {self.decoder.vit_decoder.embed_dim}, as_func: {self.dim_up_mlp_as_func}"
                )
            else:
                logger.log('ignore dim_up_mlp: ', no_dim_up_mlp)

        torch.cuda.empty_cache()

    def encode(self, *args, **kwargs):
        """Dispatch to the encoder path matching `dino_version` / `use_clip`."""
        if not self.use_clip:
            if self.dino_version == 'v1':
                latent = self.encode_dinov1(*args, **kwargs)
            elif self.dino_version == 'v2':
                if self.uvit_skip_encoder:
                    latent = self.encode_dinov2_uvit(*args, **kwargs)
                else:
                    latent = self.encode_dinov2(*args, **kwargs)
            else:
                # SD/conv encoder is a plain nn.Module; kwargs intentionally
                # not forwarded (original behaviour)
                latent = self.encoder(*args)
        else:
            latent = self.encode_clip(*args, **kwargs)
        return latent

    def encode_dinov1(self, x):
        """DINO v1 forward; drops the CLS token unless `encoder_cls_token`."""
        x = self.encoder.prepare_tokens(x)
        for blk in self.encoder.blocks:
            x = blk(x)
        x = self.encoder.norm(x)
        if not self.encoder_cls_token:
            return x[:, 1:]
        return x

    def encode_dinov2(self, x):
        """DINO v2 forward; drops the CLS token unless `encoder_cls_token`."""
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)
        for blk in self.encoder.blocks:
            x = blk(x)
        x_norm = self.encoder.norm(x)
        if not self.encoder_cls_token:
            return x_norm[:, 1:]
        return x_norm

    def encode_dinov2_uvit(self, x):
        """DINO v2 forward with U-ViT style long skips between block halves."""
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)

        # first half (minus the middle block): record activations for skips
        skips = [x]
        for blk in self.encoder.blocks[0:len(self.encoder.blocks) // 2 - 1]:
            x = blk(x)
            skips.append(x)

        # middle block: no skip recorded
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2 -
                                       1:len(self.encoder.blocks) // 2]:
            x = blk(x)

        # second half: fuse popped skips through the zero-init projections
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1))
            x = blk(x)

        x_norm = self.encoder.norm(x)

        # NOTE(review): gates on decoder_cls_token while the other encode_*
        # paths gate on encoder_cls_token — looks inconsistent; confirm intent.
        if not self.decoder_cls_token:
            return x_norm[:, 1:]
        return x_norm

    def encode_clip(self, x):
        """CLIP visual forward, returning patch tokens (CLS token dropped)."""
        x = self.encoder.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # (B, C, grid**2)
        x = x.permute(0, 2, 1)  # (B, L, C)
        # prepend the class embedding token
        x = torch.cat([
            self.encoder.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)
        x = x + self.encoder.positional_embedding.to(x.dtype)
        x = self.encoder.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND for CLIP's transformer
        x = self.encoder.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.encoder.ln_post(x[:, 1:, :])  # drop CLS token
        return x

    def decode_wo_triplane(self, latent, c=None, img_size=None):
        """ViT-decode `latent` (no triplane rendering)."""
        if img_size is None:
            img_size = self.img_size

        if self.dim_up_mlp is not None:
            if not self.dim_up_mlp_as_func:
                latent = self.dim_up_mlp(latent)
            else:
                # hand the projection to the decoder instead of applying here
                return self.decoder.vit_decode(latent, img_size,
                                               dim_up_mlp=self.dim_up_mlp)
        return self.decoder.vit_decode(latent, img_size, c=c)

    def decode(self, latent, c, img_size=None, return_raw_only=False):
        """ViT-decode then render the triplane under camera `c`."""
        latent = self.decode_wo_triplane(latent, img_size=img_size, c=c)
        return self.decoder.triplane_decode(latent, c)

    def decode_after_vae_no_render(
            self,
            ret_dict,
            img_size=None,
    ):
        """Run decoder backbone + postprocess on VAE output; no rendering."""
        if img_size is None:
            img_size = self.img_size

        assert self.dim_up_mlp is None
        latent = self.decoder.vit_decode_backbone(ret_dict, img_size)
        ret_dict = self.decoder.vit_decode_postprocess(latent, ret_dict)
        return ret_dict

    def decode_after_vae(
            self,
            ret_dict,
            c,
            img_size=None,
            return_raw_only=False):
        """Backbone + postprocess + triplane rendering under camera `c`."""
        ret_dict = self.decode_after_vae_no_render(ret_dict, img_size)
        return self.decoder.triplane_decode(ret_dict, c)

    def decode_confmap(self, img):
        """Predict the per-pixel confidence map for `img` via the ConfNet."""
        assert self.confnet is not None
        return self.confnet(img)

    def encode_decode(self, img, c, return_raw_only=False):
        """Full round trip: encode `img`, render under `c`, add confidence."""
        latent = self.encode(img)
        pred = self.decode(latent, c, return_raw_only=return_raw_only)
        if self.confnet is not None:
            pred.update({'conf_sigma': self.decode_confmap(img)})
        return pred

    def forward(self,
                img=None,
                c=None,
                latent=None,
                behaviour='enc_dec',
                coordinates=None,
                directions=None,
                return_raw_only=False,
                *args,
                **kwargs):
        """wrap all operations inside forward() for DDP use.
        """
        if behaviour == 'enc_dec':
            pred = self.encode_decode(img, c, return_raw_only=return_raw_only)

        elif behaviour == 'enc':
            pred = self.encode(img)

        elif behaviour == 'dec':
            assert latent is not None
            pred = self.decode(latent,
                               c,
                               self.img_size,
                               return_raw_only=return_raw_only)

        elif behaviour == 'dec_wo_triplane':
            assert latent is not None
            # BUGFIX: img_size must be passed by keyword; it was previously
            # passed positionally into the `c` slot of decode_wo_triplane.
            pred = self.decode_wo_triplane(latent, img_size=self.img_size)

        elif behaviour == 'enc_dec_wo_triplane':
            latent = self.encode(img)
            pred = self.decode_wo_triplane(latent,
                                           img_size=self.img_size,
                                           c=c)

        elif behaviour == 'encoder_vae':
            latent = self.encode(img)
            pred = self.decoder.vae_reparameterization(latent, True)

        elif behaviour == 'decode_after_vae_no_render':
            pred = self.decode_after_vae_no_render(latent, self.img_size)

        elif behaviour == 'decode_after_vae':
            pred = self.decode_after_vae(latent, c, self.img_size)

        elif behaviour == 'triplane_dec':
            assert latent is not None
            pred = self.decoder.triplane_decode(
                latent, c, return_raw_only=return_raw_only, **kwargs)

        elif behaviour == 'triplane_decode_grid':
            assert latent is not None
            pred = self.decoder.triplane_decode_grid(latent, **kwargs)

        elif behaviour == 'vit_postprocess_triplane_dec':
            assert latent is not None
            latent = self.decoder.vit_decode_postprocess(latent)
            pred = self.decoder.triplane_decode(latent, c)

        elif behaviour == 'triplane_renderer':
            assert latent is not None
            pred = self.decoder.triplane_renderer(latent, coordinates,
                                                  directions)

        elif behaviour == 'get_rendering_kwargs':
            pred = self.decoder.triplane_decoder.rendering_kwargs

        else:
            # BUGFIX: an unknown behaviour previously fell through to
            # `return pred` and raised UnboundLocalError; fail loudly instead.
            raise NotImplementedError(f'unknown behaviour: {behaviour}')

        return pred
|
|
|
|
class AE_CLIPEncoder(AE):
    """AE variant whose encoder is a CLIP visual backbone."""

    def __init__(self, encoder, decoder, img_size, cls_token) -> None:
        # BUGFIX: AE.__init__ requires the two cls-token flags plus
        # `preprocess` and `use_clip`; the previous 4-argument call raised
        # TypeError on every construction. `cls_token` is used for both
        # encoder and decoder; a CLIP encoder takes the use_clip path.
        super().__init__(encoder,
                         decoder,
                         img_size,
                         encoder_cls_token=cls_token,
                         decoder_cls_token=cls_token,
                         preprocess=None,
                         use_clip=True)
|
|
|
|
class AE_with_Diffusion(torch.nn.Module):
    """Bundles a pre-built auto-encoder with a denoising model so both can be
    driven through one forward() (DDP-friendly), selected by `behaviour`.
    """

    def __init__(self, auto_encoder, denoise_model) -> None:
        super().__init__()
        self.auto_encoder = auto_encoder
        self.denoise_model = denoise_model

    def forward(self,
                img,
                c,
                behaviour='enc_dec',
                latent=None,
                *args,
                **kwargs):
        """Route to AE encode/decode or the denoiser depending on `behaviour`.

        Raises NotImplementedError for an unrecognised behaviour instead of
        silently returning None.
        """
        if behaviour == 'enc_dec':
            pred = self.auto_encoder(img, c)
            return pred
        elif behaviour == 'enc':
            latent = self.auto_encoder.encode(img)
            if self.auto_encoder.dim_up_mlp is not None:
                latent = self.auto_encoder.dim_up_mlp(latent)
            return latent
        elif behaviour == 'dec':
            assert latent is not None
            # BUGFIX: img_size lives on the wrapped AE, not on this wrapper;
            # `self.img_size` raised AttributeError here.
            pred: dict = self.auto_encoder.decode(latent, c,
                                                  self.auto_encoder.img_size)
            return pred
        elif behaviour == 'denoise':
            assert latent is not None
            pred: dict = self.denoise_model(*args, **kwargs)
            return pred
        else:
            raise NotImplementedError(f'unknown behaviour: {behaviour}')
|
|
|
|
def eg3d_options_default():
    """Return the default EG3D generator options as an attribute-accessible
    EasyDict (capacity base/cap, mapping depth, generator class, fp16 levels).
    """
    return dnnlib.EasyDict(
        cbase=32768,
        cmax=512,
        map_depth=2,
        g_class_name='nsr.triplane.TriPlaneGenerator',
        g_num_fp16_res=0,
    )
|
|
|
|
def rendering_options_defaults(opts):
    """Build the volume-rendering option dict for the config named `opts.cfg`.

    Starts from shared defaults (sampling mode, density regularisation read
    from `opts`, SR module) and overlays per-dataset settings. Fields such as
    `opts.ray_start` / `opts.patch_rendering_resolution` are only accessed
    inside the branches that need them, so `opts` may omit them for other
    configs. An unknown `opts.cfg` silently keeps the shared defaults
    (original behaviour). `return_sampling_details_flag` is always forced to
    True at the end.
    """
    rendering_options = {
        'image_resolution': 256,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'c_gen_conditioning_zero': True,  # do not condition G on camera pose
        'c_scale': opts.c_scale,
        'superresolution_noise_mode': 'none',
        'density_reg': opts.density_reg,
        'density_reg_p_dist': opts.density_reg_p_dist,
        'reg_type': opts.reg_type,
        'decoder_lr_mul': 1,
        'decoder_activation': 'sigmoid',
        'sr_antialias': True,
        'return_triplane_features': False,
        'return_sampling_details_flag': False,  # overridden to True below
        'superresolution_module':
        'utils.torch_utils.components.NearestConvSR',
    }

    if opts.cfg == 'ffhq':
        rendering_options.update({
            'superresolution_module':
            'nsr.superresolution.SuperresolutionHybrid8XDC',
            'focal': 2985.29 / 700,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'bg_depth_resolution': 16,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, 0.2],
            'superresolution_noise_mode': 'random',
        })
    elif opts.cfg == 'afhq':
        rendering_options.update({
            'superresolution_module':
            'nsr.superresolution.SuperresolutionHybrid8X',
            'superresolution_noise_mode': 'random',
            'focal': 4.2647,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, -0.06],
        })
    elif opts.cfg == 'shapenet':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.2,
            'ray_end': 2.2,
            'box_warp': 2,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution':
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_64':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'srn_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 1.25,
            'ray_end': 2.75,
            'box_warp': 1.5,
            'white_back': True,
            'avg_camera_radius': 2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128_residualSR':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 0.1,
            'ray_end': 1.9,
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })
    elif opts.cfg == 'shapenet_tuneray':
        # ray extent comes from the CLI; box side matches the ray span
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution':
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestSR':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR_patch':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
            # patch-based ray sampling for patch-GAN style supervision
            'PatchRaySampler': True,
            'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_nearestSR':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.946,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR',
        })
    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_auto':
        # per-ray auto near/far from the camera radius and the sampling bbox
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 'auto',
            'ray_end': 'auto',
            'box_warp': 0.9,
            'white_back': True,
            'radius_range': [1.5, 2],
            'sampler_bbox_min': -0.45,
            'sampler_bbox_max': 0.45,
            'filter_out_of_bbox': True,
            'PatchRaySampler': True,
            'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
        # clip planes derived from the closest/farthest camera positions
        rendering_options['z_near'] = rendering_options['radius_range'][
            0] + rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][
            1] + rendering_options['sampler_bbox_max']
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestResidualSR':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestResidualSR':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module':
            'utils.torch_utils.components.NearestConvSR_Residual',
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_104':
        rendering_options.update({
            'depth_resolution': 104,
            'depth_resolution_importance': 104,
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })

    # always expose sampling details to the caller
    # (FIX: this update was previously duplicated back-to-back)
    rendering_options.update({'return_sampling_details_flag': True})

    return rendering_options
|
|
|
|
def model_encoder_defaults():
    """Default hyper-parameters for the encoder side of the auto-encoder
    (architecture names, optimiser settings, VAE/latent dims, SD-encoder and
    DiT-decoder options).
    """
    return {
        'use_clip': False,
        'arch_encoder': 'vits',
        'arch_decoder': 'vits',
        'load_pretrain_encoder': False,
        'encoder_lr': 1e-5,
        'encoder_weight_decay': 0.001,
        'no_dim_up_mlp': False,
        'dim_up_mlp_as_func': False,
        'decoder_load_pretrained': True,
        'uvit_skip_encoder': False,
        # VAE latent configuration
        'vae_p': 1,
        'ldm_z_channels': 4,
        'ldm_embed_dim': 4,
        'use_conf_map': False,
        # SD-style conv encoder
        'sd_E_ch': 64,
        'z_channels': 3 * 4,  # 3 planes x 4 channels
        'sd_E_num_res_blocks': 1,
        # DiT decoder
        'arch_dit_decoder': 'DiT2-B/2',
        'return_all_dit_layers': False,
        'lrm_decoder': False,
        'gs_rendering': False,
    }
|
|
|
|
def triplane_decoder_defaults():
    """Default hyper-parameters for the triplane decoder / renderer
    (EG3D-style regularisation, learning rates, plane geometry, rendering).
    """
    return {
        'triplane_fg_bg': False,
        'cfg': 'shapenet',
        'density_reg': 0.25,
        'density_reg_p_dist': 0.004,
        'reg_type': 'l1',
        'triplane_decoder_lr': 0.0025,
        'super_resolution_lr': 0.0025,
        'c_scale': 1,
        'nsr_lr': 0.02,
        'triplane_size': 224,
        'decoder_in_chans': 32,
        'triplane_in_chans': -1,  # -1: fall back to decoder_in_chans
        'decoder_output_dim': 3,
        'out_chans': 96,
        'c_dim': 25,  # camera conditioning dim (16 extrinsic + 9 intrinsic)
        'ray_start': 0.6,
        'ray_end': 1.8,
        'rendering_kwargs': {},
        'sr_training': False,
        'bcg_synthesis': False,
        'bcg_synthesis_kwargs': {},
        'image_size': 128,  # rendering resolution
        'patch_rendering_resolution': 45,
    }
|
|
|
|
def vit_decoder_defaults():
    """Default optimiser settings for the ViT decoder."""
    return {
        'vit_decoder_lr': 1e-5,
        'vit_decoder_wd': 0.001,
    }
|
|
|
|
def nsr_decoder_defaults():
    """Combined decoder defaults: decomposition flag merged with the triplane
    and ViT-decoder option dicts.
    """
    return {
        'decomposed': False,
        **triplane_decoder_defaults(),
        **vit_decoder_defaults(),
    }
|
|
|
|
| def loss_defaults(): |
| opt = dict( |
| color_criterion='mse', |
| l2_lambda=1.0, |
| lpips_lambda=0., |
| lpips_delay_iter=0, |
| sr_delay_iter=0, |
| |
| kl_anneal=False, |
| latent_lambda=0., |
| latent_criterion='mse', |
| kl_lambda=0.0, |
| |
| ssim_lambda=0., |
| l1_lambda=0., |
| id_lambda=0.0, |
| depth_lambda=0.0, |
| alpha_lambda=0.0, |
| fg_mse=False, |
| bg_lamdba=0.0, |
| density_reg=0.0, |
| density_reg_p_dist=0.004, |
| density_reg_every=4, |
|
|
| |
| shape_uniform_lambda=0.005, |
| shape_importance_lambda=0.01, |
| shape_depth_lambda=0., |
|
|
| |
| rec_cvD_lambda=0.01, |
| nvs_cvD_lambda=0.025, |
| patchgan_disc_factor=0.01, |
| patchgan_disc_g_weight=0.2, |
| r1_gamma=1.0, |
| sds_lamdba=1.0, |
| nvs_D_lr_mul=1, |
| cano_D_lr_mul=1, |
|
|
| |
| ce_balanced_kl=1., |
| p_eps_lambda=1, |
| |
| symmetry_loss=False, |
| depth_smoothness_lambda=0.0, |
| ce_lambda=1.0, |
| negative_entropy_lambda=1.0, |
| grad_clip=False, |
| online_mask=False, |
| ) |
| return opt |
|
|
|
|
def dataset_defaults():
    """Default flags for dataset loading (storage backend, webdataset shards,
    multi-view sampling, latent caching, camera format).
    """
    return {
        # storage backend
        'use_lmdb': False,
        'use_wds': False,
        'use_lmdb_compressed': True,
        'compile': False,
        'interval': 1,
        'objv_dataset': False,
        'decode_encode_img_only': False,
        # webdataset shards
        'load_wds_diff': False,
        'load_wds_latent': False,
        'eval_load_wds_instance': True,
        'shards_lst': "",
        'eval_shards_lst': "",
        # multi-view sampling
        'mv_input': False,
        'duplicate_sample': True,
        'orthog_duplicate': False,
        'split_chunk_input': False,
        'load_real': False,
        'four_view_for_latent': False,
        'single_view_for_i23d': False,
        'shuffle_across_cls': False,
        'load_extra_36_view': False,
        'mv_latent_dir': '',
        # extra conditioning
        'append_depth': False,
        'plucker_embedding': False,
        'gs_cam_format': False,
    }
|
|
|
|
def encoder_and_nsr_defaults():
    """
    Defaults for image training.
    """
    # base ViT configuration shared by encoder and decoder
    res = {
        'dino_version': 'v1',
        'encoder_in_channels': 3,
        'img_size': [224],
        'patch_size': 16,
        'in_chans': 384,
        'num_classes': 0,
        'embed_dim': 384,  # ViT-S default
        'depth': 6,
        'num_heads': 16,
        'mlp_ratio': 4.,
        'qkv_bias': False,
        'qk_scale': None,
        'drop_rate': 0.1,
        'attn_drop_rate': 0.,
        'drop_path_rate': 0.,
        'norm_layer': 'nn.LayerNorm',
        'cls_token': False,
        'encoder_cls_token': False,
        'decoder_cls_token': False,
        'sr_kwargs': {},
        'sr_ratio': 2,
    }
    # overlay encoder and decoder option groups, then pin the AE class
    res.update(model_encoder_defaults())
    res.update(nsr_decoder_defaults())
    res.update(ae_classname='vit.vit_triplane.ViTTriplaneDecomposed')
    return res
|
|
|
|
def create_3DAE_model(
        arch_encoder,
        arch_decoder,
        dino_version='v1',
        img_size=[224],
        patch_size=16,
        in_chans=384,
        num_classes=0,
        embed_dim=1024,  # ViT-L default
        depth=6,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer='nn.LayerNorm',
        out_chans=96,
        decoder_in_chans=32,
        triplane_in_chans=-1,
        decoder_output_dim=32,
        encoder_cls_token=False,
        decoder_cls_token=False,
        c_dim=25,  # conditioning label dimension
        image_size=128,  # rendering resolution
        img_channels=3,
        rendering_kwargs={},
        load_pretrain_encoder=False,
        decomposed=True,
        triplane_size=224,
        ae_classname='ViTTriplaneDecomposed',
        use_clip=False,
        sr_kwargs={},
        sr_ratio=2,
        no_dim_up_mlp=False,
        dim_up_mlp_as_func=False,
        decoder_load_pretrained=True,
        uvit_skip_encoder=False,
        bcg_synthesis_kwargs={},
        vae_p=1,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        use_conf_map=False,
        triplane_fg_bg=False,
        encoder_in_channels=3,
        sd_E_ch=64,
        z_channels=3*4,
        sd_E_num_res_blocks=1,
        arch_dit_decoder='DiT2-B/2',
        lrm_decoder=False,
        gs_rendering=False,
        return_all_dit_layers=False,
        *args,
        **kwargs):
    """Assemble the full AE: build the encoder (DINO / CLIP / SD conv),
    the triplane renderer, the ViT decoder, then wrap everything in `AE`.

    NOTE(review): several defaults are mutable (`img_size=[224]`,
    `rendering_kwargs={}`, ...) — safe only as long as callers never mutate
    them in place.
    """

    # ---- encoder ---------------------------------------------------------
    preprocess = None  # only set for the CLIP path (clip.load returns it)
    clip_dtype = None
    if load_pretrain_encoder:
        if not use_clip:
            if dino_version == 'v1':
                encoder = torch.hub.load(
                    'facebookresearch/dino:main',
                    'dino_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v1 ViT-S{patch_size} encoder ckpt'
                )
            elif dino_version == 'v2':
                encoder = torch.hub.load(
                    'facebookresearch/dinov2',
                    'dinov2_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v2 {arch_encoder}{patch_size} encoder ckpt'
                )
            elif 'sd' in dino_version:
                if 'mv' in dino_version:
                    if 'lgm' in dino_version:
                        # NOTE(review): this builds an MVUNet *instance* and
                        # assigns it to `encoder_cls`, which is then *called*
                        # with Encoder-style kwargs below — that call would
                        # hit MVUNet.forward with the wrong arguments.
                        # Looks like a latent bug; confirm this branch is
                        # actually exercised.
                        encoder_cls = MVUNet(
                            input_size=256,
                            up_channels=(1024, 1024, 512, 256,
                                         128),
                            up_attention=(True, True, True, False, False),
                            splat_size=128,
                            output_size=
                            512,
                            batch_size=8,
                            num_views=8,
                            gradient_accumulation_steps=1,
                        )
                    elif 'gs' in dino_version:
                        # NOTE(review): same class as the else-branch; the
                        # non-pretrained path below uses MVEncoderGS for
                        # 'gs' — possibly intended here too. Confirm.
                        encoder_cls = MVEncoder
                    else:
                        encoder_cls = MVEncoder
                else:
                    encoder_cls = Encoder

                encoder = encoder_cls(
                    double_z=True,
                    resolution=256,
                    in_channels=encoder_in_channels,
                    ch=64,
                    ch_mult=[1, 2, 4, 4],
                    num_res_blocks=1,
                    dropout=0.0,
                    attn_resolutions=[],
                    out_ch=3,
                    z_channels=4 * 3,  # 3 planes x 4 channels
                )
            else:
                raise NotImplementedError()
        else:
            # frozen CLIP visual backbone as encoder
            import clip
            model, preprocess = clip.load("ViT-B/16", device=dist_util.dev())
            model.float()  # clip.load gives fp16 on GPU; upcast for training
            clip_dtype = model.dtype
            encoder = getattr(
                model, 'visual')
            encoder.requires_grad_(False)  # keep CLIP frozen
            logger.log(
                f'loaded pre-trained CLIP ViT-B{patch_size} encoder, fixed.')

    elif 'sd' in dino_version:
        # SD-style conv encoder trained from scratch
        attn_kwargs = {}
        if 'mv' in dino_version:
            if 'lgm' in dino_version:
                encoder = LGM_MVEncoder(
                    in_channels=9,
                    up_channels=(1024, 1024, 512, 256,
                                 128),
                    up_attention=(True, True, True, False, False),
                )
            elif 'gs' in dino_version:
                encoder_cls = MVEncoderGS
                attn_kwargs = {
                    'n_heads': 8,
                    'd_head': 64,
                }
            else:
                encoder_cls = MVEncoder
                attn_kwargs = {
                    'n_heads': 8,
                    'd_head': 64,
                }
        else:
            encoder_cls = Encoder

        # LGM_MVEncoder above is already fully constructed
        if 'lgm' not in dino_version:
            encoder = encoder_cls(
                double_z=True,
                resolution=256,
                in_channels=encoder_in_channels,
                ch=sd_E_ch,
                ch_mult=[1, 2, 4, 4],
                num_res_blocks=sd_E_num_res_blocks,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=3,
                z_channels=z_channels,
                attn_kwargs=attn_kwargs,
            )

    else:
        # plain (untrained) DINO-style ViT from the local vits registry
        encoder = vits.__dict__[arch_encoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,
            img_size=img_size)

    # ---- triplane renderer ----------------------------------------------
    if triplane_in_chans == -1:
        triplane_in_chans = decoder_in_chans

    triplane_renderer_cls = Triplane

    triplane_decoder = triplane_renderer_cls(
        c_dim,
        image_size,
        img_channels,
        rendering_kwargs=rendering_kwargs,
        out_chans=out_chans,
        triplane_size=triplane_size,
        decoder_in_chans=triplane_in_chans,
        decoder_output_dim=decoder_output_dim,
        sr_kwargs=sr_kwargs,
        bcg_synthesis_kwargs=bcg_synthesis_kwargs,
        lrm_decoder=lrm_decoder)

    # ---- ViT decoder -----------------------------------------------------
    if load_pretrain_encoder:
        if dino_version == 'v1':
            vit_decoder = torch.hub.load(
                'facebookresearch/dino:main',
                'dino_{}{}'.format(arch_decoder, patch_size))
            logger.log(
                'loaded pre-trained decoder',
                "facebookresearch/dino:main', 'dino_{}{}".format(
                    arch_decoder, patch_size))
        else:
            vit_decoder = torch.hub.load(
                'facebookresearch/dinov2',
                'dinov2_{}{}'.format(arch_decoder, patch_size),
                pretrained=decoder_load_pretrained)
            logger.log(
                'loaded pre-trained decoder',
                "facebookresearch/dinov2', 'dinov2_{}{}".format(
                    arch_decoder,
                    patch_size), 'pretrianed=', decoder_load_pretrained)
    elif 'dit' in dino_version:
        from dit.dit_decoder import DiT2_models

        # 4 planes for Gaussian-splatting variants, 3 for plain triplane
        vit_decoder = DiT2_models[arch_dit_decoder](
            input_size=16,
            num_classes=0,
            learn_sigma=False,
            in_channels=embed_dim,
            mixed_prediction=False,
            context_dim=None,
            roll_out=True, plane_n=4 if
            'gs' in dino_version else 3,
            return_all_layers=return_all_dit_layers,
        )
    else:
        vit_decoder = vits.__dict__[arch_decoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,
            img_size=img_size)

    # ---- assemble decoder wrapper + AE ----------------------------------
    # decoder class is resolved dynamically from `ae_classname`
    decoder_kwargs = dict(
        class_name=ae_classname,
        vit_decoder=vit_decoder,
        triplane_decoder=triplane_decoder,
        cls_token=decoder_cls_token,
        sr_ratio=sr_ratio,
        vae_p=vae_p,
        ldm_z_channels=ldm_z_channels,
        ldm_embed_dim=ldm_embed_dim,
    )
    decoder = dnnlib.util.construct_class_by_name(**decoder_kwargs)

    # optional confidence network for uncertainty-weighted reconstruction
    if use_conf_map:
        confnet = ConfNet(cin=3, cout=1, nf=64, zdim=128)
    else:
        confnet = None

    auto_encoder = AE(
        encoder,
        decoder,
        img_size[0],
        encoder_cls_token,
        decoder_cls_token,
        preprocess,
        use_clip,
        dino_version,
        clip_dtype,
        no_dim_up_mlp=no_dim_up_mlp,
        dim_up_mlp_as_func=dim_up_mlp_as_func,
        uvit_skip_encoder=uvit_skip_encoder,
        confnet=confnet,
    )

    logger.log(auto_encoder)
    torch.cuda.empty_cache()

    return auto_encoder
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def create_Triplane(
        c_dim=25,
        img_resolution=128,
        img_channels=3,
        rendering_kwargs={},
        decoder_output_dim=32,
        *args,
        **kwargs):
    """Build a standalone Triplane renderer (with its own learned planes)."""
    return Triplane(c_dim,
                    img_resolution,
                    img_channels,
                    rendering_kwargs=rendering_kwargs,
                    create_triplane=True,
                    decoder_output_dim=decoder_output_dim)
|
|
|
|
def DiT_defaults():
    """Default DiT diffusion-model settings (architecture name, VAE variant)."""
    return dict(
        dit_model="DiT-B/16",
        vae="ema",
    )
|
|