| | import argparse |
| | from ast import arg |
| | import os |
| | import csv |
| | import torch |
| | import torchvision.transforms as transforms |
| | import torch.utils.data |
| | import numpy as np |
| | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score |
| | from torch.utils.data import Dataset |
| | import sys |
| | from models import get_model |
| | from PIL import Image |
| | import pickle |
| | from tqdm import tqdm |
| | from io import BytesIO |
| | from copy import deepcopy |
| | from dataset_paths import DATASET_PATHS |
| | import random |
| | import shutil |
| | from scipy.ndimage.filters import gaussian_filter |
| |
|
| | SEED = 0 |
| | def set_seed(): |
| | torch.manual_seed(SEED) |
| | torch.cuda.manual_seed(SEED) |
| | np.random.seed(SEED) |
| | random.seed(SEED) |
| |
|
| |
|
| | MEAN = { |
| | "imagenet":[0.485, 0.456, 0.406], |
| | "clip":[0.48145466, 0.4578275, 0.40821073] |
| | } |
| |
|
| | STD = { |
| | "imagenet":[0.229, 0.224, 0.225], |
| | "clip":[0.26862954, 0.26130258, 0.27577711] |
| | } |
| |
|
| |
|
| |
|
| |
|
| |
|
| | def find_best_threshold(y_true, y_pred): |
| | "We assume first half is real 0, and the second half is fake 1" |
| |
|
| | N = y_true.shape[0] |
| |
|
| | if y_pred[0:N//2].max() <= y_pred[N//2:N].min(): |
| | return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2 |
| |
|
| | best_acc = 0 |
| | best_thres = 0 |
| | for thres in y_pred: |
| | temp = deepcopy(y_pred) |
| | temp[temp>=thres] = 1 |
| | temp[temp<thres] = 0 |
| |
|
| | acc = (temp == y_true).sum() / N |
| | if acc >= best_acc: |
| | best_thres = thres |
| | best_acc = acc |
| | |
| | return best_thres |
| | |
| |
|
| | |
| | def png2jpg(img, quality): |
| | out = BytesIO() |
| | img.save(out, format='jpeg', quality=quality) |
| | img = Image.open(out) |
| | |
| | img = np.array(img) |
| | out.close() |
| | return Image.fromarray(img) |
| |
|
| |
|
| | def gaussian_blur(img, sigma): |
| | img = np.array(img) |
| |
|
| | gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) |
| | gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) |
| | gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) |
| |
|
| | return Image.fromarray(img) |
| |
|
| |
|
| |
|
| | def calculate_acc(y_true, y_pred, thres): |
| | r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > thres) |
| | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > thres) |
| | acc = accuracy_score(y_true, y_pred > thres) |
| | return r_acc, f_acc, acc |
| |
|
| |
|
| | def validate(model, loader, find_thres=False): |
| |
|
| | with torch.no_grad(): |
| | y_true, y_pred = [], [] |
| | print ("Length of dataset: %d" %(len(loader))) |
| | for img, label in loader: |
| | in_tens = img.cuda() |
| |
|
| | y_pred.extend(model(in_tens).sigmoid().flatten().tolist()) |
| | y_true.extend(label.flatten().tolist()) |
| |
|
| | y_true, y_pred = np.array(y_true), np.array(y_pred) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | ap = average_precision_score(y_true, y_pred) |
| |
|
| | |
| | r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5) |
| | if not find_thres: |
| | return ap, r_acc0, f_acc0, acc0 |
| |
|
| |
|
| | |
| | best_thres = find_best_threshold(y_true, y_pred) |
| | r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres) |
| |
|
| | return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres |
| |
|
| | |
| | |
| |
|
| |
|
| |
|
| | |
| |
|
| |
|
| |
|
| |
|
| | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", "bmp"]): |
| | out = [] |
| | for r, d, f in os.walk(rootdir): |
| | for file in f: |
| | if (file.split('.')[1] in exts) and (must_contain in os.path.join(r, file)): |
| | out.append(os.path.join(r, file)) |
| | return out |
| |
|
| |
|
| | def get_list(path, must_contain=''): |
| | if ".pickle" in path: |
| | with open(path, 'rb') as f: |
| | image_list = pickle.load(f) |
| | image_list = [ item for item in image_list if must_contain in item ] |
| | else: |
| | image_list = recursively_read(path, must_contain) |
| | return image_list |
| |
|
| |
|
| |
|
| |
|
| |
|
| | class RealFakeDataset(Dataset): |
| | def __init__(self, real_path, |
| | fake_path, |
| | data_mode, |
| | max_sample, |
| | arch, |
| | jpeg_quality=None, |
| | gaussian_sigma=None): |
| |
|
| | assert data_mode in ["wang2020", "ours"] |
| | self.jpeg_quality = jpeg_quality |
| | self.gaussian_sigma = gaussian_sigma |
| | |
| | |
| | if type(real_path) == str and type(fake_path) == str: |
| | real_list, fake_list = self.read_path(real_path, fake_path, data_mode, max_sample) |
| | else: |
| | real_list = [] |
| | fake_list = [] |
| | for real_p, fake_p in zip(real_path, fake_path): |
| | real_l, fake_l = self.read_path(real_p, fake_p, data_mode, max_sample) |
| | real_list += real_l |
| | fake_list += fake_l |
| |
|
| | self.total_list = real_list + fake_list |
| |
|
| |
|
| | |
| |
|
| | self.labels_dict = {} |
| | for i in real_list: |
| | self.labels_dict[i] = 0 |
| | for i in fake_list: |
| | self.labels_dict[i] = 1 |
| |
|
| | stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip" |
| | self.transform = transforms.Compose([ |
| | transforms.CenterCrop(224), |
| | transforms.ToTensor(), |
| | transforms.Normalize( mean=MEAN[stat_from], std=STD[stat_from] ), |
| | ]) |
| |
|
| |
|
| | def read_path(self, real_path, fake_path, data_mode, max_sample): |
| |
|
| | if data_mode == 'wang2020': |
| | real_list = get_list(real_path, must_contain='0_real') |
| | fake_list = get_list(fake_path, must_contain='1_fake') |
| | else: |
| | real_list = get_list(real_path) |
| | fake_list = get_list(fake_path) |
| |
|
| |
|
| | if max_sample is not None: |
| | if (max_sample > len(real_list)) or (max_sample > len(fake_list)): |
| | max_sample = 100 |
| | print("not enough images, max_sample falling to 100") |
| | random.shuffle(real_list) |
| | random.shuffle(fake_list) |
| | real_list = real_list[0:max_sample] |
| | fake_list = fake_list[0:max_sample] |
| |
|
| | assert len(real_list) == len(fake_list) |
| |
|
| | return real_list, fake_list |
| |
|
| |
|
| |
|
| | def __len__(self): |
| | return len(self.total_list) |
| |
|
| | def __getitem__(self, idx): |
| | |
| | img_path = self.total_list[idx] |
| |
|
| | label = self.labels_dict[img_path] |
| | img = Image.open(img_path).convert("RGB") |
| |
|
| | if self.gaussian_sigma is not None: |
| | img = gaussian_blur(img, self.gaussian_sigma) |
| | if self.jpeg_quality is not None: |
| | img = png2jpg(img, self.jpeg_quality) |
| |
|
| | img = self.transform(img) |
| | return img, label |
| |
|
| |
|
| |
|
| |
|
| |
|
| | if __name__ == '__main__': |
| |
|
| |
|
| | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
| | parser.add_argument('--real_path', type=str, default=None, help='dir name or a pickle') |
| | parser.add_argument('--fake_path', type=str, default=None, help='dir name or a pickle') |
| | parser.add_argument('--data_mode', type=str, default=None, help='wang2020 or ours') |
| | parser.add_argument('--max_sample', type=int, default=1000, help='only check this number of images for both fake/real') |
| |
|
| | parser.add_argument('--arch', type=str, default='res50') |
| | parser.add_argument('--ckpt', type=str, default='./pretrained_weights/fc_weights.pth') |
| |
|
| | parser.add_argument('--result_folder', type=str, default='result', help='') |
| | parser.add_argument('--batch_size', type=int, default=128) |
| |
|
| | parser.add_argument('--jpeg_quality', type=int, default=None, help="100, 90, 80, ... 30. Used to test robustness of our model. Not apply if None") |
| | parser.add_argument('--gaussian_sigma', type=int, default=None, help="0,1,2,3,4. Used to test robustness of our model. Not apply if None") |
| |
|
| |
|
| | opt = parser.parse_args() |
| |
|
| | |
| | if os.path.exists(opt.result_folder): |
| | shutil.rmtree(opt.result_folder) |
| | os.makedirs(opt.result_folder) |
| |
|
| | model = get_model(opt.arch) |
| | state_dict = torch.load(opt.ckpt, map_location='cpu') |
| | model.fc.load_state_dict(state_dict) |
| | print ("Model loaded..") |
| | model.eval() |
| | model.cuda() |
| |
|
| | if (opt.real_path == None) or (opt.fake_path == None) or (opt.data_mode == None): |
| | dataset_paths = DATASET_PATHS |
| | else: |
| | dataset_paths = [ dict(real_path=opt.real_path, fake_path=opt.fake_path, data_mode=opt.data_mode) ] |
| |
|
| |
|
| |
|
| | for dataset_path in (dataset_paths): |
| | set_seed() |
| |
|
| | dataset = RealFakeDataset( dataset_path['real_path'], |
| | dataset_path['fake_path'], |
| | dataset_path['data_mode'], |
| | opt.max_sample, |
| | opt.arch, |
| | jpeg_quality=opt.jpeg_quality, |
| | gaussian_sigma=opt.gaussian_sigma, |
| | ) |
| |
|
| | loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4) |
| | ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, loader, find_thres=True) |
| |
|
| | with open( os.path.join(opt.result_folder,'ap.txt'), 'a') as f: |
| | f.write(dataset_path['key']+': ' + str(round(ap*100, 2))+'\n' ) |
| |
|
| | with open( os.path.join(opt.result_folder,'acc0.txt'), 'a') as f: |
| | f.write(dataset_path['key']+': ' + str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n' ) |
| |
|
| |
|