| | import cv2 |
| | import numpy as np |
| | import torchvision.datasets as datasets |
| | import torchvision.transforms as transforms |
| | import torchvision.transforms.functional as TF |
| | from torch.utils.data import Dataset |
| | from random import random, choice, shuffle |
| | from io import BytesIO |
| | from PIL import Image |
| | from PIL import ImageFile |
| | from scipy.ndimage.filters import gaussian_filter |
| | import pickle |
| | import os |
| | from skimage.io import imread |
| | from copy import deepcopy |
| |
|
# Ask PIL to decode image files whose data is truncated rather than raising
# on them, so a partially written image does not abort data loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True
| |
|
| |
|
# Per-channel normalization statistics passed to transforms.Normalize.
# The key ("imagenet" or "clip") is chosen from the prefix of opt.arch.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
| |
|
| |
|
| |
|
| |
|
def recursively_read(rootdir, must_contain, exts=("png", "jpg", "JPEG", "jpeg")):
    """Collect image file paths found anywhere below ``rootdir``.

    Args:
        rootdir: Directory that is walked recursively with ``os.walk``.
        must_contain: Substring that has to occur in the joined path
            (directory plus file name) for the file to be kept. An empty
            string keeps every file with a matching extension.
        exts: Accepted extensions, without the leading dot. The comparison
            is case-sensitive.

    Returns:
        A list of paths (directory joined with file name), in the order
        ``os.walk`` yields them.
    """
    out = []
    for dirpath, _dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # Use the text after the LAST dot as the extension. Names without
            # any dot are skipped instead of raising IndexError, and names
            # with several dots (e.g. "img.v2.png") are classified correctly.
            _stem, dot, ext = filename.rpartition('.')
            if not dot or ext not in exts:
                continue
            full_path = os.path.join(dirpath, filename)
            if must_contain in full_path:
                out.append(full_path)
    return out
| |
|
| |
|
def get_list(path, must_contain=''):
    """Return the image paths described by ``path``.

    If ``path`` contains the substring ".pickle" it is opened and unpickled,
    and the loaded items that contain ``must_contain`` are returned.
    Otherwise ``path`` is handed to ``recursively_read`` together with
    ``must_contain``.
    """
    if ".pickle" not in path:
        return recursively_read(path, must_contain)

    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at list files that come from a trusted source.
    with open(path, 'rb') as handle:
        entries = pickle.load(handle)
    return [entry for entry in entries if must_contain in entry]
| |
|
| |
|
| |
|
| |
|
class RealFakeDataset(Dataset):
    """Dataset over real and fake image paths, labelled 0 (real) and 1 (fake).

    Attributes set by ``__init__``:
        data_label: ``opt.data_label``, either "train" or "val".
        labels_dict: Maps each image path to its label (0 real, 1 fake).
        total_list: All real and fake paths, shuffled once at construction.
        transform: The composed torchvision transform, or ``None`` when
            ``'2b'`` occurs in ``opt.arch``. In that case a transform has to
            be assigned to ``transform`` before items are requested.
    """

    def __init__(self, opt):
        """Gather the image paths and build the preprocessing pipeline.

        Args:
            opt: Options object. Read here: ``data_label``, ``data_mode``,
                the list/data path options of the selected mode,
                ``isTrain``, ``cropSize``, ``no_crop``, ``no_flip``,
                ``no_resize`` and ``arch``. ``opt`` is also captured by the
                resize and augmentation lambdas, which pass it on to
                ``custom_resize`` and ``data_augment`` for every image.

        Raises:
            ValueError: If ``opt.data_label`` or ``opt.data_mode`` is not one
                of the supported values.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if opt.data_label not in ("train", "val"):
            raise ValueError(
                "opt.data_label must be 'train' or 'val', got {!r}".format(opt.data_label)
            )
        self.data_label = opt.data_label

        real_list, fake_list = self._collect_paths(opt)

        # Real entries first, then fake: a path present in both lists ends up
        # labelled fake, exactly as with two consecutive assignment loops.
        self.labels_dict = {path: 0 for path in real_list}
        self.labels_dict.update((path, 1) for path in fake_list)

        self.total_list = real_list + fake_list
        shuffle(self.total_list)

        if opt.isTrain:
            crop_func = transforms.RandomCrop(opt.cropSize)
        elif opt.no_crop:
            crop_func = transforms.Lambda(lambda img: img)
        else:
            crop_func = transforms.CenterCrop(opt.cropSize)

        if opt.isTrain and not opt.no_flip:
            flip_func = transforms.RandomHorizontalFlip()
        else:
            flip_func = transforms.Lambda(lambda img: img)

        if not opt.isTrain and opt.no_resize:
            rz_func = transforms.Lambda(lambda img: img)
        else:
            rz_func = transforms.Lambda(lambda img: custom_resize(img, opt))

        stat_from = "imagenet" if opt.arch.lower().startswith("imagenet") else "clip"

        print("mean and std stats are from: ", stat_from)
        if '2b' not in opt.arch:
            print ("using Official CLIP's normalization")
            self.transform = transforms.Compose([
                rz_func,
                transforms.Lambda(lambda img: data_augment(img, opt)),
                crop_func,
                flip_func,
                transforms.ToTensor(),
                transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ),
            ])
        else:
            print ("Using CLIP 2B transform")
            # No pipeline is built for this case; __getitem__ refuses to run
            # until a transform has been assigned to ``self.transform``.
            self.transform = None

    @staticmethod
    def _collect_paths(opt):
        """Return ``(real_list, fake_list)`` for the mode in ``opt.data_mode``."""
        mode = opt.data_mode
        if mode not in ("ours", "wang2020", "ours_wang2020"):
            # Without this check an unknown mode only surfaces later as an
            # UnboundLocalError on the path lists.
            raise ValueError(
                "opt.data_mode must be 'ours', 'wang2020' or 'ours_wang2020', "
                "got {!r}".format(mode)
            )

        real_list, fake_list = [], []
        if mode in ("ours", "ours_wang2020"):
            pickle_name = "train.pickle" if opt.data_label == "train" else "val.pickle"
            real_list += get_list(os.path.join(opt.real_list_path, pickle_name))
            fake_list += get_list(os.path.join(opt.fake_list_path, pickle_name))
        if mode in ("wang2020", "ours_wang2020"):
            subdir = 'train/progan' if opt.data_label == 'train' else 'test/progan'
            root = os.path.join(opt.wang2020_data_path, subdir)
            real_list += get_list(root, must_contain='0_real')
            fake_list += get_list(root, must_contain='1_fake')
        return real_list, fake_list

    def __len__(self):
        """Return the total number of image paths (real plus fake)."""
        return len(self.total_list)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``.

        Returns:
            A ``(transformed_image, label)`` pair, where label is 0 for real
            and 1 for fake.

        Raises:
            RuntimeError: If no transform is configured (``self.transform``
                is ``None``), instead of failing with an opaque
                "'NoneType' object is not callable".
        """
        if self.transform is None:
            raise RuntimeError(
                "RealFakeDataset.transform is None (no transform is built when "
                "'2b' is in opt.arch); assign a transform before requesting items."
            )
        img_path = self.total_list[idx]
        label = self.labels_dict[img_path]
        # The context manager closes the underlying file promptly; convert()
        # returns an independent, fully loaded image.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        img = self.transform(img)
        return img, label
| |
|
| |
|
def data_augment(img, opt):
    """Randomly blur and/or JPEG-compress ``img`` and return a PIL image.

    The image is converted to a NumPy array; a 2-D array is expanded to
    three identical channels. With probability ``opt.blur_prob`` it is
    blurred in place with a sigma drawn from ``opt.blur_sig``. With
    probability ``opt.jpg_prob`` it is then re-encoded as JPEG using a
    method drawn from ``opt.jpg_method`` and a quality drawn from
    ``opt.jpg_qual``.
    """
    pixels = np.array(img)
    if pixels.ndim == 2:
        # Add a channel axis, then copy the single channel three times.
        pixels = np.repeat(np.expand_dims(pixels, axis=2), 3, axis=2)

    if random() < opt.blur_prob:
        # gaussian_blur works in place on `pixels`.
        gaussian_blur(pixels, sample_continuous(opt.blur_sig))

    if random() < opt.jpg_prob:
        # Draw the method before the quality to keep the random sequence.
        codec = sample_discrete(opt.jpg_method)
        quality = sample_discrete(opt.jpg_qual)
        pixels = jpeg_from_key(pixels, quality, codec)

    return Image.fromarray(pixels)
| |
|
| |
|
def sample_continuous(s):
    """Return a value from the one- or two-element sequence ``s``.

    One element: that element is returned as is. Two elements: a uniform
    random value starting at ``s[0]`` and spanning ``s[1] - s[0]``.

    Raises:
        ValueError: If ``s`` has any other length.
    """
    count = len(s)
    if count == 2:
        low, high = s[0], s[1]
        return random() * (high - low) + low
    if count == 1:
        return s[0]
    raise ValueError("Length of iterable s should be 1 or 2.")
| |
|
| |
|
def sample_discrete(s):
    """Return the only element of ``s``, or a random element if it has more.

    A single-element sequence is returned directly, without consulting the
    random number generator.
    """
    return s[0] if len(s) == 1 else choice(s)
| |
|
| |
|
def gaussian_blur(img, sigma):
    """Blur the first three channels of ``img`` in place.

    Each of the channels ``img[:, :, 0]`` to ``img[:, :, 2]`` is filtered
    separately with ``gaussian_filter`` and the result is written back into
    the same slice. Nothing is returned.
    """
    for channel in range(3):
        plane = img[:, :, channel]
        gaussian_filter(plane, output=plane, sigma=sigma)
| |
|
| |
|
def cv2_jpg(img, compress_val):
    """JPEG-encode and decode ``img`` with OpenCV at quality ``compress_val``.

    The channel axis is reversed before encoding and again after decoding,
    so the returned array has the same channel order as the input.
    """
    # presumably RGB -> BGR, since OpenCV works on BGR data; confirm with caller
    reversed_in = img[:, :, ::-1]
    params = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    _, encoded = cv2.imencode('.jpg', reversed_in, params)
    decoded = cv2.imdecode(encoded, 1)
    return decoded[:, :, ::-1]
| |
|
| |
|
def pil_jpg(img, compress_val):
    """JPEG-encode and decode ``img`` with PIL at quality ``compress_val``.

    The array is saved as JPEG into an in-memory buffer, read back from
    that buffer and returned as a NumPy array.
    """
    with BytesIO() as buffer:
        Image.fromarray(img).save(buffer, format='jpeg', quality=compress_val)
        # The pixel data must be materialised while the buffer is still open.
        decoded = np.array(Image.open(buffer))
    return decoded
| |
|
| |
|
# Maps a JPEG method name to the function implementing it.
jpeg_dict = {
    'cv2': cv2_jpg,
    'pil': pil_jpg,
}


def jpeg_from_key(img, compress_val, key):
    """Run the JPEG round trip registered under ``key`` in ``jpeg_dict``.

    Raises:
        KeyError: If ``key`` is not a key of ``jpeg_dict``.
    """
    return jpeg_dict[key](img, compress_val)
| |
|
| |
|
# Maps an interpolation name to the PIL resampling constant of the same name.
rz_dict = {
    mode_name: getattr(Image, mode_name.upper())
    for mode_name in ('bilinear', 'bicubic', 'lanczos', 'nearest')
}


def custom_resize(img, opt):
    """Resize ``img`` to ``opt.loadSize``.

    The interpolation name is drawn from ``opt.rz_interp`` with
    ``sample_discrete`` and translated through ``rz_dict``.
    """
    chosen = sample_discrete(opt.rz_interp)
    return TF.resize(img, opt.loadSize, interpolation=rz_dict[chosen])
| |
|