| import numpy as np |
| import pickle |
| import os |
| import torch |
| from torch.utils.data import TensorDataset |
| from torchvision.datasets import ImageFolder |
| import torchvision.transforms as transforms |
| from sklearn.model_selection import train_test_split |
|
|
|
|
def set_up_data(H):
    """Load the dataset named by ``H.dataset`` and build preprocessing state.

    Mutates ``H`` in place (sets ``H.image_size`` and ``H.image_channels``)
    and returns ``(H, train_data, valid_data, preprocess_func)`` where
    ``preprocess_func`` moves a batch to the GPU and yields both the
    normalized network input and the target tensor used by the loss.
    """
    # Default loss normalization maps raw uint8 pixels in [0, 255]
    # to roughly [-1, 1]; ffhq_1024 overrides this below because its
    # images arrive as floats in [0, 1] from transforms.ToTensor().
    shift_loss = -127.5
    scale_loss = 1. / 127.5
    if H.dataset == 'imagenet32':
        trX, vaX, teX = imagenet32(H.data_root)
        H.image_size = 32
        H.image_channels = 3
        # shift/scale look like precomputed per-dataset mean and 1/stddev
        # statistics used to whiten the input — TODO confirm provenance.
        shift = -116.2373
        scale = 1. / 69.37404
    elif H.dataset == 'imagenet64':
        trX, vaX, teX = imagenet64(H.data_root)
        H.image_size = 64
        H.image_channels = 3
        shift = -115.92961967
        scale = 1. / 69.37404
    elif H.dataset == 'ffhq_256':
        trX, vaX, teX = ffhq256(H.data_root)
        H.image_size = 256
        H.image_channels = 3
        shift = -112.8666757481
        scale = 1. / 69.84780273
    elif H.dataset == 'ffhq_1024':
        # ffhq1024() returns directory paths, not arrays; loading is
        # deferred to ImageFolder below.
        trX, vaX, teX = ffhq1024(H.data_root)
        H.image_size = 1024
        H.image_channels = 3
        # Statistics in [0, 1] float scale (ToTensor output), unlike the
        # [0, 255] uint8 scale of the other datasets.
        shift = -0.4387
        scale = 1.0 / 0.2743
        shift_loss = -0.5
        scale_loss = 2.0
    elif H.dataset == 'cifar10':
        (trX, _), (vaX, _), (teX, _) = cifar10(H.data_root, one_hot=False)
        H.image_size = 32
        H.image_channels = 3
        shift = -120.63838
        scale = 1. / 64.16736
    else:
        raise ValueError('unknown dataset: ', H.dataset)

    # ffhq_256 is modeled at reduced (5-bit) precision on the loss side.
    do_low_bit = H.dataset in ['ffhq_256']

    if H.test_eval:
        print('DOING TEST')
        eval_dataset = teX
    else:
        eval_dataset = vaX

    # Broadcastable (1, 1, 1, 1) GPU tensors so the elementwise ops in
    # preprocess_func broadcast over (batch, H, W, C) batches.
    shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1)
    scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1)
    shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1)
    scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1)

    if H.dataset == 'ffhq_1024':
        # Images are loaded lazily from disk; ToTensor yields CHW floats,
        # which preprocess_func untransposes back to HWC.
        train_data = ImageFolder(trX, transforms.ToTensor())
        valid_data = ImageFolder(eval_dataset, transforms.ToTensor())
        untranspose = True
    else:
        train_data = TensorDataset(torch.as_tensor(trX))
        valid_data = TensorDataset(torch.as_tensor(eval_dataset))
        untranspose = False

    def preprocess_func(x):
        # NOTE(review): these nonlocals are read-only captures, so the
        # declarations are not strictly required; kept for clarity.
        nonlocal shift
        nonlocal scale
        nonlocal shift_loss
        nonlocal scale_loss
        nonlocal do_low_bit
        nonlocal untranspose
        'takes in a data example and returns the preprocessed input'
        'as well as the input processed for the loss'
        if untranspose:
            # CHW (ToTensor layout) -> HWC, matching the array datasets.
            x[0] = x[0].permute(0, 2, 3, 1)
        inp = x[0].cuda(non_blocking=True).float()
        # Clone before in-place normalization so the loss target keeps
        # the unshifted pixel values.
        out = inp.clone()
        inp.add_(shift).mul_(scale)
        if do_low_bit:
            # Quantize to 5 bits of precision (floor to multiples of 8).
            out.mul_(1. / 8.).floor_().mul_(8.)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out

    return H, train_data, valid_data, preprocess_func
|
|
|
|
def mkdir_p(path):
    """Create *path* and any missing parent directories; no-op if it already exists."""
    os.makedirs(path, exist_ok=True)
|
|
|
|
def flatten(outer):
    """Concatenate the sub-iterables of *outer* into one flat list."""
    flat = []
    for inner in outer:
        flat.extend(inner)
    return flat
|
|
|
|
def unpickle_cifar10(file):
    """Load one pickled CIFAR-10 batch file and return its dict with str keys.

    The batch files store keys as bytes (e.g. ``b'data'``); they are decoded
    so callers can index with plain strings like ``'data'`` / ``'labels'``.

    :param file: path to a CIFAR-10 batch file.
    :returns: dict mapping decoded string keys to the original values.
    """
    # Context manager ensures the handle is closed even if pickle.load raises
    # (the original open()/close() pair leaked the descriptor on error).
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return {k.decode(): v for k, v in data.items()}
|
|
|
|
def imagenet32(data_root):
    """Return (train, valid, test) ImageNet-32 arrays.

    The training npy is memory-mapped and split deterministically: the last
    5000 images of a seeded shuffle become the validation set.  The official
    validation npy serves as the test set.
    """
    full_train = np.load(os.path.join(data_root, 'imagenet32-train.npy'), mmap_mode='r')
    np.random.seed(42)  # fixed seed keeps the train/valid split reproducible
    order = np.random.permutation(full_train.shape[0])
    test = np.load(os.path.join(data_root, 'imagenet32-valid.npy'), mmap_mode='r')
    return full_train[order[:-5000]], full_train[order[-5000:]], test
|
|
|
|
def imagenet64(data_root):
    """Return (train, valid, test) ImageNet-64 arrays.

    Mirrors imagenet32: the memory-mapped training npy is shuffled with a
    fixed seed and its last 5000 images held out for validation; the official
    validation npy is used as the test set.
    """
    full_train = np.load(os.path.join(data_root, 'imagenet64-train.npy'), mmap_mode='r')
    np.random.seed(42)  # fixed seed keeps the train/valid split reproducible
    order = np.random.permutation(full_train.shape[0])
    test = np.load(os.path.join(data_root, 'imagenet64-valid.npy'), mmap_mode='r')
    return full_train[order[:-5000]], full_train[order[-5000:]], test
|
|
|
|
def ffhq1024(data_root):
    """Return (train_dir, valid_dir, test_dir) paths for FFHQ-1024.

    Unlike the array-returning loaders, this yields directory paths for
    torchvision's ImageFolder.  There is no separate test split, so the
    validation directory is reused as the test set.
    """
    train_dir = os.path.join(data_root, 'ffhq1024/train')
    valid_dir = os.path.join(data_root, 'ffhq1024/valid')
    return train_dir, valid_dir, valid_dir
|
|
|
|
def ffhq256(data_root):
    """Return (train, valid, test) FFHQ-256 arrays.

    The single npy is memory-mapped and split with a fixed seed: the last
    7000 images of the shuffle become validation.  No separate test split
    exists, so validation doubles as test.
    """
    images = np.load(os.path.join(data_root, 'ffhq-256.npy'), mmap_mode='r')
    np.random.seed(5)  # fixed seed keeps the train/valid split reproducible
    order = np.random.permutation(images.shape[0])
    train = images[order[:-7000]]
    valid = images[order[-7000:]]
    return train, valid, valid
|
|
|
|
def cifar10(data_root, one_hot=True):
    """Load CIFAR-10 from the standard python batch files.

    :param data_root: directory containing ``cifar-10-batches-py/``.
    :param one_hot: if True, labels are one-hot float32 arrays of shape
        (N, 10); otherwise they are integer column vectors of shape (N, 1).
    :returns: ``((trX, trY), (vaX, vaY), (teX, teY))`` with images as
        (N, 32, 32, 3) arrays and a deterministic 5000-image validation
        split carved out of the training set.
    """
    tr_data = [unpickle_cifar10(os.path.join(data_root, 'cifar-10-batches-py/', 'data_batch_%d' % i)) for i in range(1, 6)]
    # BUG FIX: np.vstack requires a sequence of arrays; passing a generator
    # is deprecated and raises on modern NumPy, so materialize a list.
    trX = np.vstack([data['data'] for data in tr_data])
    trY = np.asarray(flatten([data['labels'] for data in tr_data]))
    te_data = unpickle_cifar10(os.path.join(data_root, 'cifar-10-batches-py/', 'test_batch'))
    teX = np.asarray(te_data['data'])
    teY = np.asarray(te_data['labels'])
    # Batch files store flat rows in CHW order; reshape to NHWC images.
    trX = trX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    teX = teX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    # Fixed random_state keeps the train/valid split reproducible across runs.
    trX, vaX, trY, vaY = train_test_split(trX, trY, test_size=5000, random_state=11172018)
    if one_hot:
        trY = np.eye(10, dtype=np.float32)[trY]
        vaY = np.eye(10, dtype=np.float32)[vaY]
        teY = np.eye(10, dtype=np.float32)[teY]
    else:
        trY = np.reshape(trY, [-1, 1])
        vaY = np.reshape(vaY, [-1, 1])
        teY = np.reshape(teY, [-1, 1])
    return (trX, trY), (vaX, vaY), (teX, teY)
|
|