| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| | import math |
| | import shutil |
| | import time |
| | import sys |
| | import types |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | import torch.distributed as dist |
| | |
| | from torch.cuda.amp import autocast |
| |
|
| | |
| |
|
| |
|
class AvgrageMeter(object):
    """Tracks a weighted running average (sum, count, and mean)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # avg = weighted mean, sum = weighted total, cnt = total weight
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Fold in `val` observed with weight `n` and refresh the mean."""
        self.sum = self.sum + val * n
        self.cnt = self.cnt + n
        self.avg = self.sum / self.cnt
| |
|
| |
|
class ExpMovingAvgrageMeter(object):
    """Exponential moving average of a scalar.

    NOTE: `momentum` here weights the NEW value, i.e.
    avg <- (1 - momentum) * avg + momentum * val.
    """

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.avg = 0

    def update(self, val):
        new_weight = self.momentum
        self.avg = (1. - new_weight) * self.avg + new_weight * val
| |
|
| |
|
class DummyDDP(nn.Module):
    """Single-process stand-in for DistributedDataParallel.

    Exposes the wrapped network as `.module` (like real DDP) and forwards
    all calls to it unchanged.
    """

    def __init__(self, model):
        super(DummyDDP, self).__init__()
        self.module = model

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
| |
|
| |
|
def count_parameters_in_M(model):
    """Return the number of trainable parameters in millions.

    Parameters whose name contains 'auxiliary' are excluded (auxiliary heads
    are not part of the deployed model). Uses the builtin `sum` over
    `Tensor.numel()` instead of `np.sum` over a generator, which is deprecated
    in NumPy and only worked via an object-iteration fallback.
    """
    return sum(v.numel() for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
| |
|
| |
|
def save_checkpoint(state, is_best, save):
    """Serialize `state` to <save>/checkpoint.pth.tar; when `is_best`,
    also duplicate it as <save>/model_best.pth.tar."""
    ckpt_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(save, 'model_best.pth.tar'))
| |
|
| |
|
def save(model, model_path):
    """Persist only the model's state dict (not the module object) to `model_path`."""
    state = model.state_dict()
    torch.save(state, model_path)
| |
|
| |
|
def load(model, model_path):
    """Restore `model`'s parameters in place from a state dict saved at `model_path`."""
    state = torch.load(model_path)
    model.load_state_dict(state)
| |
|
| |
|
def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory `path` and optionally snapshot scripts.

    Each file in `scripts_to_save` is copied into <path>/scripts so the exact
    code of a run is preserved alongside its outputs.

    Both directory creations use makedirs(exist_ok=True) directly: the
    previous exists()-then-mkdir pattern was racy when several ranks or jobs
    started simultaneously.
    """
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)
| |
|
| |
|
class Logger(object):
    """Rank-aware logging facade.

    Rank 0 configures the stdlib `logging` module to write to stdout and to
    <save>/log.txt, and prefixes every message with the wall-clock time since
    construction. All other ranks silently drop messages.
    """

    def __init__(self, rank, save):
        # Reload the logging module before configuring it: an imported library
        # may already have installed handlers, and basicConfig is a no-op once
        # the root logger has any handler attached.
        from importlib import reload
        reload(logging)
        self.rank = rank
        if self.rank == 0:
            log_format = '%(asctime)s %(message)s'
            logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                                format=log_format, datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(save, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
        self.start_time = time.time()

    def info(self, string, *args):
        """Log `string` (with %-style `args`) on rank 0, prefixed by elapsed time."""
        if self.rank == 0:
            elapsed_time = time.time() - self.start_time
            elapsed_time = time.strftime(
                '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time))
            if isinstance(string, str):
                string = elapsed_time + string
            else:
                # Non-string payloads (tensors, dicts, ...) can't be prefixed;
                # log the elapsed time on its own line first.
                logging.info(elapsed_time)
            logging.info(string, *args)
| |
|
| |
|
class Writer(object):
    """Rank-gated TensorBoard writer.

    Rank 0 owns a real SummaryWriter writing to `save`; every other rank
    turns all logging calls into no-ops so training code can call it
    unconditionally.
    """

    def __init__(self, rank, save):
        self.rank = rank
        if self.rank == 0:
            # Lazy import: SummaryWriter was never imported at module level,
            # so rank-0 construction previously raised NameError. Importing
            # here fixes that and spares non-zero ranks the tensorboard
            # dependency entirely.
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=save, flush_secs=20)

    def add_scalar(self, *args, **kwargs):
        if self.rank == 0:
            self.writer.add_scalar(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        if self.rank == 0:
            self.writer.add_figure(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        if self.rank == 0:
            self.writer.add_image(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        if self.rank == 0:
            self.writer.add_histogram(*args, **kwargs)

    def add_histogram_if(self, write, *args, **kwargs):
        # Deliberately disabled in the original code ('and False'); kept so
        # histogram logging stays off unless explicitly re-enabled.
        if write and False:
            self.add_histogram(*args, **kwargs)

    def close(self, *args, **kwargs):
        if self.rank == 0:
            self.writer.close()
| |
|
| |
|
def common_init(rank, seed, save_dir):
    """Seed all RNGs (per-rank offset) and build the rank-aware logger/writer.

    Returns (logger, writer). `rank + seed` decorrelates randomness across
    distributed workers while keeping the run reproducible.
    """
    torch.manual_seed(rank + seed)
    np.random.seed(rank + seed)
    torch.cuda.manual_seed(rank + seed)
    torch.cuda.manual_seed_all(rank + seed)
    # benchmark mode autotunes conv algorithms; fine since input shapes are static.
    torch.backends.cudnn.benchmark = True

    # Renamed the local from `logging` to `logger`: it shadowed the stdlib
    # logging module imported at the top of the file.
    logger = Logger(rank, save_dir)
    writer = Writer(rank, save_dir)

    return logger, writer
| |
|
| |
|
def reduce_tensor(tensor, world_size):
    """Return the mean of `tensor` across all ranks without mutating the input."""
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= world_size
    return averaged
| |
|
| |
|
def get_stride_for_cell_type(cell_type):
    """Map a cell-type name prefix to its spatial stride.

    'normal'/'combiner' keep resolution (1), 'down' halves it (2),
    'up' doubles it (encoded as -1). Unknown prefixes raise.
    """
    if cell_type.startswith(('normal', 'combiner')):
        return 1
    if cell_type.startswith('down'):
        return 2
    if cell_type.startswith('up'):
        return -1
    raise NotImplementedError(cell_type)
| |
|
| |
|
def get_cout(cin, stride):
    """Output channel count for a cell given its stride.

    stride 1 keeps channels, -1 (upsampling) halves them, 2 (downsampling)
    doubles them. Any other stride raises NotImplementedError; previously the
    function fell through and crashed with NameError on the unbound `cout`.
    """
    if stride == 1:
        cout = cin
    elif stride == -1:
        cout = cin // 2
    elif stride == 2:
        cout = 2 * cin
    else:
        raise NotImplementedError('unknown stride: {}'.format(stride))
    return cout
| |
|
| |
|
def kl_balancer_coeff(num_scales, groups_per_scale, fun):
    """Per-group KL weighting coefficients on GPU, normalized so min == 1.

    Groups are ordered from the coarsest scale outward; `fun` selects how the
    weight grows with scale index i: equal, 2^i, sqrt(2^i), or (2^i)^2 / group size.
    """
    sizes = [groups_per_scale[num_scales - i - 1] for i in range(num_scales)]
    if fun == 'equal':
        per_scale = [torch.ones(g) for g in sizes]
    elif fun == 'linear':
        per_scale = [(2 ** i) * torch.ones(g) for i, g in enumerate(sizes)]
    elif fun == 'sqrt':
        per_scale = [np.sqrt(2 ** i) * torch.ones(g) for i, g in enumerate(sizes)]
    elif fun == 'square':
        per_scale = [np.square(2 ** i) / g * torch.ones(g) for i, g in enumerate(sizes)]
    else:
        raise NotImplementedError

    coeff = torch.cat(per_scale, dim=0).cuda()
    # normalize so the smallest coefficient is exactly 1
    coeff /= torch.min(coeff)
    return coeff
| |
|
| |
|
def kl_per_group(kl_all):
    """Per-group KL statistics from a (batch, groups) KL tensor.

    Returns (kl_coeff_i, kl_vals): batch-mean absolute KL per group (shape
    (1, groups), offset by 0.01 to avoid zeros) and the plain batch-mean KL
    per group.
    """
    kl_vals = torch.mean(kl_all, dim=0)
    abs_kl = torch.abs(kl_all)
    kl_coeff_i = torch.mean(abs_kl, dim=0, keepdim=True) + 0.01
    return kl_coeff_i, kl_vals
| |
|
| |
|
def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None):
    """Combine per-group KL terms into one loss, optionally rebalancing groups.

    `kl_all` is a list of per-group KL tensors of shape (batch,). During KL
    warmup (kl_balance and kl_coeff < 1) groups are reweighted by their
    magnitude relative to `alpha_i`; otherwise a plain mean is used.

    Returns (kl_coeff * kl, per-group coefficients, per-group mean KLs).
    """
    stacked = torch.stack(kl_all, dim=1)

    if kl_balance and kl_coeff < 1.0:
        alpha = alpha_i.unsqueeze(0)
        coeff_i, kl_vals = kl_per_group(stacked)
        total_kl = torch.sum(coeff_i)

        # scale by magnitude, then renormalize so coefficients average to 1
        coeff_i = coeff_i / alpha * total_kl
        coeff_i = coeff_i / torch.mean(coeff_i, dim=1, keepdim=True)
        # detach: coefficients modulate the loss but receive no gradient
        kl = torch.sum(stacked * coeff_i.detach(), dim=1)

        kl_coeffs = coeff_i.squeeze(0)
    else:
        kl_vals = torch.mean(stacked, dim=0)
        kl = torch.mean(stacked)
        kl_coeffs = torch.ones(size=(len(kl_vals),))

    return kl_coeff * kl, kl_coeffs, kl_vals
| |
|
| |
|
def kl_per_group_vada(all_log_q, all_neg_log_p):
    """KL statistics for VADA-style training from per-group log-densities.

    For each latent group, KL = E[-log p + log q]. Returns:
      - kl_all_list: per-group tensors of shape (batch,), KL averaged over CHW
      - kl_vals: scalar-per-group batch means, shape (groups,)
      - kl_diag: per-group per-channel means, each of shape (channels,)
    """
    assert len(all_log_q) == len(all_neg_log_p)

    kl_all_list = []
    kl_diag = []
    for log_q, neg_log_p in zip(all_log_q, all_neg_log_p):
        pointwise = neg_log_p + log_q
        kl_diag.append(torch.mean(torch.mean(pointwise, dim=[2, 3]), dim=0))
        kl_all_list.append(torch.mean(pointwise, dim=[1, 2, 3]))

    kl_vals = torch.mean(torch.stack(kl_all_list, dim=1), dim=0)

    return kl_all_list, kl_vals, kl_diag
| |
|
| |
|
def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff):
    """Linear KL-warmup schedule: hold at min for `constant_step` steps, then
    ramp linearly over `total_step`, clamped to [min_kl_coeff, max_kl_coeff]."""
    ramp = min_kl_coeff + (max_kl_coeff - min_kl_coeff) * (step - constant_step) / total_step
    return max(min(ramp, max_kl_coeff), min_kl_coeff)
| |
|
| |
|
def log_iw(decoder, x, log_q, log_p, crop=False):
    """Per-sample importance weight in log space:
    log w = log p(x|z) + log p(z) - log q(z|x)."""
    log_likelihood = -reconstruction_loss(decoder, x, crop)
    return log_likelihood - log_q + log_p
| |
|
| |
|
def reconstruction_loss(decoder, x, crop=False):
    """Negative log-likelihood of `x` under the decoder distribution, per sample."""
    from util.distributions import DiscMixLogistic

    recon = decoder.log_p(x)
    if crop:
        # keep the 28x28 center (drop a 2-pixel border)
        recon = recon[:, :, 2:30, 2:30]

    # DiscMixLogistic log-probs already include the channel dimension.
    sum_dims = [1, 2] if isinstance(decoder, DiscMixLogistic) else [1, 2, 3]
    return -torch.sum(recon, dim=sum_dims)
| |
|
| |
|
def vae_terms(all_log_q, all_eps):
    """Per-group KL statistics plus total log q and log p under N(0, I).

    Returns (log_q, log_p, kl_all, kl_diag) where kl_all holds per-sample KLs
    per group and kl_diag holds per-channel batch-mean KLs per group.
    """
    from util.distributions import log_p_standard_normal

    kl_all, kl_diag = [], []
    log_p, log_q = 0., 0.
    for group_log_q, group_eps in zip(all_log_q, all_eps):
        group_log_p = log_p_standard_normal(group_eps)
        pointwise_kl = group_log_q - group_log_p
        kl_diag.append(torch.mean(torch.sum(pointwise_kl, dim=[2, 3]), dim=0))
        kl_all.append(torch.sum(pointwise_kl, dim=[1, 2, 3]))
        log_q = log_q + torch.sum(group_log_q, dim=[1, 2, 3])
        log_p = log_p + torch.sum(group_log_p, dim=[1, 2, 3])
    return log_q, log_p, kl_all, kl_diag
| |
|
| |
|
def sum_log_q(all_log_q):
    """Total log q(z|x) per sample: sum each group's log-prob over CHW, then over groups."""
    return sum((torch.sum(group, dim=[1, 2, 3]) for group in all_log_q), 0.)
| |
|
| |
|
def cross_entropy_normal(all_eps):
    """Cross entropy of eps samples against a standard normal prior.

    Returns the per-sample total plus the per-group pointwise negative
    log-probabilities (kept for diagnostics).
    """
    from util.distributions import log_p_standard_normal

    neg_log_p_per_group = [-log_p_standard_normal(eps) for eps in all_eps]
    cross_entropy = 0.
    for neg_log_p_conv in neg_log_p_per_group:
        cross_entropy += torch.sum(neg_log_p_conv, dim=[1, 2, 3])

    return cross_entropy, neg_log_p_per_group
| |
|
| |
|
def tile_image(batch_image, n, m=None):
    """Tile a batch of n*m images (B, C, H, W) into one (C, n*H, m*W) grid."""
    if m is None:
        m = n
    assert n * m == batch_image.size(0)
    channels, height, width = batch_image.size(1), batch_image.size(2), batch_image.size(3)
    grid = batch_image.view(n, m, channels, height, width)
    # channels first, then interleave grid rows with height and cols with width
    grid = grid.permute(2, 0, 3, 1, 4).contiguous()
    return grid.view(channels, n * height, m * width)
| |
|
| |
|
def average_gradients_naive(params, is_distributed):
    """Average gradients across ranks, one all_reduce per parameter.

    Divides each grad by the world size locally, then sum-reduces, which
    together computes the mean. No-op when not distributed.
    """
    if not is_distributed:
        return
    world_size = float(dist.get_world_size())
    for param in params:
        if param.requires_grad:
            param.grad.data /= world_size
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
| |
|
| |
|
def average_gradients(params, is_distributed):
    """Average gradients across ranks with one fused all_reduce.

    Gradients of all trainable params are flattened into a single buffer,
    mean-reduced, and scattered back into the original shapes. No-op when
    not distributed.

    Bug fix: the write-back loop previously indexed grad_size/grad_shapes with
    `enumerate` over ALL params, while those lists only hold entries for
    params with requires_grad=True — misaligning (or crashing) whenever any
    parameter was frozen. A separate counter over trainable params fixes it.
    """
    if is_distributed:
        if isinstance(params, types.GeneratorType):
            params = [p for p in params]

        size = float(dist.get_world_size())
        grad_data = []
        grad_size = []
        grad_shapes = []
        for param in params:
            if param.requires_grad:
                grad_size.append(param.grad.data.numel())
                grad_shapes.append(list(param.grad.data.shape))
                grad_data.append(param.grad.data.flatten())
        grad_data = torch.cat(grad_data).contiguous()

        # divide locally, then sum across ranks == mean across ranks
        grad_data /= size
        dist.all_reduce(grad_data, op=dist.ReduceOp.SUM)

        # scatter the reduced buffer back into the per-parameter grads
        base = 0
        idx = 0
        for param in params:
            if param.requires_grad:
                param.grad.data = grad_data[base:base + grad_size[idx]].view(grad_shapes[idx])
                base += grad_size[idx]
                idx += 1
| |
|
| |
|
def average_params(params, is_distributed):
    """Average parameter values (not grads) across ranks in place; no-op otherwise."""
    if not is_distributed:
        return
    world_size = float(dist.get_world_size())
    for param in params:
        param.data /= world_size
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
| |
|
| |
|
def average_tensor(t, is_distributed):
    """In-place mean of tensor `t` across all ranks; no-op when not distributed."""
    if not is_distributed:
        return
    dist.all_reduce(t.data, op=dist.ReduceOp.SUM)
    t.data /= float(dist.get_world_size())
| |
|
| |
|
def broadcast_params(params, is_distributed):
    """Copy rank 0's parameter values to every other rank; no-op otherwise."""
    if not is_distributed:
        return
    for param in params:
        dist.broadcast(param.data, src=0)
| |
|
| |
|
def num_output(dataset):
    """Total number of output scalars (channels * H * W) for a dataset name.

    Sized datasets encode their resolution as a suffix, e.g. 'celeba_64'.
    """
    if dataset in {'mnist', 'omniglot'}:
        return 28 * 28
    if dataset == 'cifar10':
        return 3 * 32 * 32
    if dataset.startswith(('celeba', 'imagenet', 'lsun')):
        side = int(dataset.split('_')[-1])
        return 3 * side * side
    if dataset == 'ffhq':
        return 3 * 256 * 256
    raise NotImplementedError
| |
|
| |
|
def get_input_size(dataset):
    """Spatial side length of network input for a dataset name.

    NOTE: mnist/omniglot return 32 (the 28x28 images are padded), not 28.
    """
    if dataset in {'mnist', 'omniglot'}:
        return 32
    if dataset == 'cifar10':
        return 32
    if dataset.startswith(('celeba', 'imagenet', 'lsun')):
        return int(dataset.split('_')[-1])
    if dataset == 'ffhq':
        return 256
    raise NotImplementedError
| |
|
| |
|
def get_bpd_coeff(dataset):
    """Factor converting a nats-per-image loss into bits per dimension."""
    dims = num_output(dataset)
    return 1. / np.log(2.) / dims
| |
|
| |
|
def get_channel_multiplier(dataset, num_scales):
    """Per-scale channel multipliers for the diffusion backbone.

    Raises NotImplementedError for unsupported datasets AND for unsupported
    num_scales in the high-resolution branch; the latter previously fell
    through and crashed with NameError on the unbound `mult`.
    """
    if dataset in {'cifar10', 'omniglot'}:
        mult = (1, 1, 1)
    elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}:
        if num_scales == 3:
            mult = (1, 1, 1)
        elif num_scales == 4:
            mult = (1, 2, 2, 2)
        elif num_scales == 5:
            mult = (1, 1, 2, 2, 2)
        else:
            raise NotImplementedError('unsupported num_scales: {}'.format(num_scales))
    elif dataset == 'mnist':
        mult = (1, 1)
    else:
        raise NotImplementedError

    return mult
| |
|
| |
|
def get_attention_scales(dataset):
    """Per-scale flags saying which backbone scales get attention layers."""
    if dataset in {'cifar10', 'omniglot'}:
        attn = (True, False, False)
    elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}:
        # attention only at the middle (16x16-ish) scale
        attn = (False, False, True, False, False)
    elif dataset == 'mnist':
        attn = (True, False)
    else:
        raise NotImplementedError

    return attn
| |
|
| |
|
def change_bit_length(x, num_bits):
    """Requantize x (floats in [0, 1], 8-bit source) to `num_bits` levels.

    Returns the input unchanged for num_bits == 8; otherwise returns a new
    tensor with values snapped to the coarser grid and rescaled into [0, 1].
    """
    if num_bits == 8:
        return x
    quantized = torch.floor(x * 255 / 2 ** (8 - num_bits))
    return quantized / (2 ** num_bits - 1)
| |
|
| |
|
def view4D(t, size, inplace=True):
    """
    Equal to view(-1, 1, 1, 1).expand(size)
    Designed because of this bug:
    https://github.com/pytorch/pytorch/pull/48696
    """
    if not inplace:
        return t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(size)
    # in-place variant mutates t's shape before expanding
    return t.unsqueeze_(-1).unsqueeze_(-1).unsqueeze_(-1).expand(size)
| |
|
| |
|
def get_arch_cells(arch_type, use_se):
    """Build the per-cell-type architecture spec table for a given arch name.

    Every arch maps the eight cell types (normal/down/up x enc/dec/pre/post)
    to a conv-branch recipe plus squeeze-excitation flag; attention variants
    additionally tag selected cells with attn_type='attn'.
    """
    def _spec(branch, attn=False):
        spec = {'conv_branch': list(branch), 'se': use_se}
        if attn:
            spec['attn_type'] = 'attn'
        return spec

    res2 = ('res_bnswish', 'res_bnswish')
    res_x2 = ('res_bnswish_x2',)
    mb6 = ('mconv_e6k5g0',)
    mb3 = ('mconv_e3k5g0',)

    if arch_type == 'res_mbconv':
        enc, dec, pre, post = res2, mb6, res2, mb3
        enc_attn = dec_attn = False
    elif arch_type == 'res_bnswish':
        enc = dec = pre = post = res2
        enc_attn = dec_attn = False
    elif arch_type == 'res_bnswish2':
        enc = dec = pre = post = res_x2
        enc_attn = dec_attn = False
    elif arch_type == 'res_mbconv_attn':
        enc, dec, pre, post = res2, mb6, res2, mb3
        enc_attn = dec_attn = True
    elif arch_type == 'res_mbconv_attn_half':
        # attention on decoder cells only
        enc, dec, pre, post = res2, mb6, res2, mb3
        enc_attn, dec_attn = False, True
    else:
        raise NotImplementedError

    arch_cells = {
        'normal_enc': _spec(enc, enc_attn),
        'down_enc': _spec(enc, enc_attn),
        'normal_dec': _spec(dec, dec_attn),
        'up_dec': _spec(dec, dec_attn),
        'normal_pre': _spec(pre),
        'down_pre': _spec(pre),
        'normal_post': _spec(post),
        'up_post': _spec(post),
        'ar_nn': [''],
    }
    return arch_cells
| |
|
| |
|
def groups_per_scale(num_scales, num_groups_per_scale):
    """Constant number of latent groups at every scale.

    Returns a list of length `num_scales`, each entry `num_groups_per_scale`.
    """
    if num_scales > 0:
        assert num_groups_per_scale >= 1
    return [num_groups_per_scale] * num_scales
| |
|
| |
|
class PositionalEmbedding(nn.Module):
    """Sinusoidal (transformer-style) embedding of scalar timesteps."""

    def __init__(self, embedding_dim, scale):
        super(PositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.scale = scale

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps into (len, embedding_dim) sin/cos features."""
        assert len(timesteps.shape) == 1
        scaled = timesteps * self.scale
        half_dim = self.embedding_dim // 2
        # geometric frequency ladder from 1 down to 1/10000
        log_freq_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim) * -log_freq_step)
        freqs = freqs.to(device=timesteps.device)
        angles = scaled[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
| |
|
| |
|
class RandomFourierEmbedding(nn.Module):
    """Random Fourier feature embedding of scalar timesteps.

    Frequencies are drawn once at construction and frozen (requires_grad=False).
    """

    def __init__(self, embedding_dim, scale):
        super(RandomFourierEmbedding, self).__init__()
        self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False)

    def forward(self, timesteps):
        """Embed a 1-D timestep tensor into (len, 2 * (embedding_dim // 2)) features."""
        phases = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359)
        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=1)
| |
|
| |
|
def init_temb_fun(embedding_type, embedding_scale, embedding_dim):
    """Factory for the timestep-embedding module ('positional' or 'fourier')."""
    if embedding_type == 'positional':
        return PositionalEmbedding(embedding_dim, embedding_scale)
    if embedding_type == 'fourier':
        return RandomFourierEmbedding(embedding_dim, embedding_scale)
    raise NotImplementedError
| |
|
def get_dae_model(args, num_input_channels):
    """Build the denoising network; only the NCSN++ backbone is supported."""
    if args.dae_arch != 'ncsnpp':
        raise NotImplementedError
    # imported lazily to avoid a hard dependency when the DAE is unused
    from score_sde.ncsnpp import NCSNpp
    return NCSNpp(args, num_input_channels)
| |
|
def symmetrize_image_data(images):
    """Map pixel values from [0, 1] to the symmetric range [-1, 1]."""
    return images * 2.0 - 1.0
| |
|
| |
|
def unsymmetrize_image_data(images):
    """Map pixel values from [-1, 1] back to [0, 1] (inverse of symmetrize)."""
    return (images + 1.) / 2.
| |
|
| |
|
def normalize_symmetric(images):
    """
    Normalize images by dividing the largest intensity. Used for visualizing the intermediate steps.
    """
    batch = images.shape[0]
    # per-sample peak absolute intensity; epsilon guards all-zero images
    peak, _ = torch.abs(images).view(batch, -1).max(dim=1)
    images /= (peak.view(batch, 1, 1, 1) + 1e-3)
    return images
| |
|
| |
|
@torch.jit.script
def soft_clamp5(x: torch.Tensor):
    """Smoothly clamp x into (-5, 5) via a scaled tanh; input is not modified."""
    return torch.tanh(x / 5.) * 5.
| |
|
@torch.jit.script
def soft_clamp(x: torch.Tensor, a: torch.Tensor):
    """Smoothly clamp x into (-a, a) via a scaled tanh; input is not modified."""
    return torch.tanh(x / a) * a
| |
|
class SoftClamp5(nn.Module):
    """Module wrapper around soft_clamp5 so it can sit inside nn.Sequential."""

    def __init__(self):
        super(SoftClamp5, self).__init__()

    def forward(self, x):
        return soft_clamp5(x)
| |
|
| |
|
def override_architecture_fields(args, stored_args, logging):
    """Force architecture-defining fields on `args` to match a loaded checkpoint.

    Any mismatch (or missing attribute) on `args` is overwritten from
    `stored_args`, since the network must be rebuilt exactly as trained.
    `logging` is any object with an `.info(fmt, *args)` method.
    """
    architecture_fields = ['arch_instance', 'num_nf', 'num_latent_scales', 'num_groups_per_scale',
                           'num_latent_per_group', 'num_channels_enc', 'num_preprocess_blocks',
                           'num_preprocess_cells', 'num_cell_per_cond_enc', 'num_channels_dec',
                           'num_postprocess_blocks', 'num_postprocess_cells', 'num_cell_per_cond_dec',
                           'decoder_dist', 'num_x_bits', 'log_sig_q_scale',
                           'progressive_input_vae', 'channel_mult']

    # Backfill fields that predate old checkpoints. (Earlier manual backfills
    # for log_sig_q_scale / latent_grad_cutoff / progressive_*_vae were
    # dropped when backward compatibility was broken.)
    if not hasattr(stored_args, 'num_x_bits'):
        logging.info('*** Setting %s manually ****', 'num_x_bits')
        setattr(stored_args, 'num_x_bits', 8)

    if not hasattr(stored_args, 'channel_mult'):
        logging.info('*** Setting %s manually ****', 'channel_mult')
        setattr(stored_args, 'channel_mult', [1, 2])

    for field in architecture_fields:
        if not hasattr(args, field) or getattr(args, field) != getattr(stored_args, field):
            logging.info('Setting %s from loaded checkpoint', field)
            setattr(args, field, getattr(stored_args, field))
| |
|
| |
|
def init_processes(rank, size, fn, args):
    """ Initialize the distributed environment. """
    # NOTE(review): the master port is hard-coded, so two concurrent jobs on
    # the same node will collide — consider making it configurable.
    os.environ['MASTER_ADDR'] = args.master_address
    os.environ['MASTER_PORT'] = '6020'
    # bind this process to its GPU before NCCL initialization
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=size)
    fn(args)  # run the worker entry point
    dist.barrier()  # wait for all ranks before tearing the group down
    dist.destroy_process_group()
| |
|
| |
|
def sample_rademacher_like(y):
    """Sample a tensor of +/-1 (Rademacher) values with y's shape.

    Generalized: the sample is drawn on `y.device` instead of a hard-coded
    'cuda', so CPU tensors work too; behavior is unchanged for CUDA callers.
    """
    return torch.randint(low=0, high=2, size=y.shape, device=y.device) * 2 - 1
| |
|
| |
|
def sample_gaussian_like(y):
    """Sample standard-normal noise with y's shape, dtype, and device.

    Generalized: `randn_like` inherits `y.device` instead of forcing 'cuda',
    so CPU tensors work too; behavior is unchanged for CUDA callers.
    """
    return torch.randn_like(y)
| |
|
| |
|
def trace_df_dx_hutchinson(f, x, noise, no_autograd):
    """
    Hutchinson's trace estimator for Jacobian df/dx, O(1) call to autograd
    """
    if no_autograd:
        # backward() accumulates the vector-Jacobian product into x.grad
        (f * noise).sum().backward()
        vjp = x.grad
        trJ = torch.sum(vjp * noise, dim=[1, 2, 3])
        x.grad = None  # clear so later backward passes start clean
    else:
        vjp = torch.autograd.grad(f, x, noise, create_graph=False)[0]
        trJ = torch.sum(vjp * noise, dim=[1, 2, 3])

    return trJ
| |
|
def different_p_q_objectives(iw_sample_p, iw_sample_q):
    """Return True when the SGM prior (p) and encoder (q) use different
    importance-sampling objectives and thus need separate updates.

    Only the likelihood-based p objectives combined with reweighted p samples
    can share a single objective.
    """
    assert iw_sample_p in ['ll_uniform', 'drop_all_uniform', 'll_iw', 'drop_all_iw', 'drop_sigma2t_iw', 'rescale_iw',
                           'drop_sigma2t_uniform']
    assert iw_sample_q in ['reweight_p_samples', 'll_uniform', 'll_iw']

    shared = iw_sample_p in ['ll_uniform', 'll_iw'] and iw_sample_q == 'reweight_p_samples'
    return not shared
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def get_mixed_prediction(mixed_prediction, param, mixing_logit, mixing_component=None):
    """Optionally blend a prediction with a mixing component.

    When `mixed_prediction` is enabled, returns
    sigmoid(mixing_logit) * param + (1 - sigmoid(mixing_logit)) * mixing_component;
    otherwise returns `param` untouched.
    """
    if not mixed_prediction:
        return param
    assert mixing_component is not None, 'Provide mixing component when mixed_prediction is enabled.'
    coeff = torch.sigmoid(mixing_logit)
    return (1 - coeff) * mixing_component + coeff * param
| |
|
| |
|
def set_vesde_sigma_max(args, vae, train_queue, logging, is_distributed):
    """Set args.sigma2_max for the VESDE to the maximum squared pairwise
    distance between training-set latents.

    Runs the VAE encoder (in eval mode, no gradients) over the full
    `train_queue`, gathers the eps latents from every rank, computes the max
    pairwise distance on CPU in full precision, and writes the result back
    onto `args`. Returns the updated `args`.

    Bug fix: the all_gather output list was built as
    `[torch.zeros_like(t)] * world_size`, which repeats ONE tensor object, so
    every gather slot aliased the same buffer and the gathered latents were
    corrupted. Distinct tensors are now allocated per rank.
    """
    logging.info('')
    logging.info('Calculating max. pairwise distance in latent space to set sigma2_max for VESDE...')

    eps_list = []
    vae.eval()
    for step, x in enumerate(train_queue):
        x = x[0] if len(x) > 1 else x  # loaders may yield (image, label)
        x = x.cuda()
        x = symmetrize_image_data(x)

        # encode without gradients; autocast mirrors the training setup
        with autocast(enabled=args.autocast_train):
            with torch.set_grad_enabled(False):
                logits, all_log_q, all_eps = vae(x)
                eps = torch.cat(all_eps, dim=1)

        eps_list.append(eps.detach())

    eps_this_rank = torch.cat(eps_list, dim=0)
    if is_distributed:
        # one distinct buffer per rank (aliasing bug fixed, see docstring)
        eps_all_gathered = [torch.zeros_like(eps_this_rank) for _ in range(dist.get_world_size())]
        dist.all_gather(eps_all_gathered, eps_this_rank)
        eps_full = torch.cat(eps_all_gathered, dim=0)
    else:
        eps_full = eps_this_rank

    # O(N^2) pairwise distances on CPU in float32 to avoid GPU OOM / AMP error
    eps_full = eps_full.cpu().float()
    eps_full = eps_full.flatten(start_dim=1).unsqueeze(0)
    max_pairwise_dist_sqr = torch.cdist(eps_full, eps_full).square().max()
    max_pairwise_dist_sqr = max_pairwise_dist_sqr.cuda()

    # every rank adopts rank 0's value so the SDE is identical everywhere
    if is_distributed:
        dist.broadcast(max_pairwise_dist_sqr, src=0)

    args.sigma2_max = max_pairwise_dist_sqr.item()

    logging.info('Done! Set args.sigma2_max set to {}'.format(args.sigma2_max))
    logging.info('')
    return args
| |
|
| |
|
def mask_inactive_variables(x, is_active):
    """Zero out latent dimensions whose `is_active` mask entry is 0."""
    return x * is_active
| |
|
| |
|
def common_x_operations(x, num_x_bits):
    """Standard input pipeline: unpack the loader batch, move it to GPU,
    requantize to `num_x_bits`, and map pixel values into [-1, 1]."""
    batch = x[0] if len(x) > 1 else x  # loaders may yield (image, label)
    batch = batch.cuda()

    batch = change_bit_length(batch, num_x_bits)
    return symmetrize_image_data(batch)
| |
|