| import argparse |
| import os |
| import ruamel_yaml as yaml |
| import numpy as np |
| import json |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| import torch.nn.functional as F |
| from dataset.dataset_RSNA import RSNA2018_Dataset |
| from models.model_MeDSLIP import MeDSLIP |
| from models.tokenization_bert import BertTokenizer |
| from sklearn.metrics import roc_auc_score, precision_recall_curve, accuracy_score |
| from tqdm import tqdm |
|
|
# Observation vocabulary. An entry's position in this list is the class index
# used to slice the model output (main() looks up "pneumonia" with .index()),
# so entries must not be reordered.
original_class = [
    "normal", "clear", "sharp", "sharply", "unremarkable", "intact", "stable",
    "free", "effusion", "opacity", "pneumothorax", "edema", "atelectasis",
    "tube", "consolidation", "process", "abnormality", "enlarge", "tip", "low",
    "pneumonia", "line", "congestion", "catheter", "cardiomegaly", "fracture",
    "air", "tortuous", "lead", "disease", "calcification", "prominence",
    "device", "engorgement", "picc", "clip", "elevation", "expand", "nodule",
    "wire", "fluid", "degenerative", "pacemaker", "thicken", "marking", "scar",
    "hyperinflate", "blunt", "loss", "widen", "collapse", "density",
    "emphysema", "aerate", "mass", "crowd", "infiltrate", "obscure",
    "deformity", "hernia", "drainage", "distention", "shift", "stent",
    "pressure", "lesion", "finding", "borderline", "hardware", "dilation",
    "chf", "redistribution", "aspiration", "tail_abnorm_obs", "excluded_obs",
]
|
|
|
|
def get_tokenizer(tokenizer, target_text):
    """Encode every string in ``target_text`` to a fixed length of 64 tokens.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, **options)``; in this
            file it is a ``BertTokenizer`` instance.
        target_text: Any iterable of strings; it is materialized into a list
            before being handed to ``tokenizer``.

    Returns:
        Whatever ``tokenizer`` returns for the batch. With
        ``return_tensors="pt"`` this is expected to hold PyTorch tensors
        (callers move it with ``.to(device)``).
    """
    texts = list(target_text)
    # Pad/truncate everything to the same length so the batch stacks cleanly.
    return tokenizer(
        texts,
        max_length=64,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
|
|
|
|
def compute_AUCs(gt, pred, n_class):
    """Return the per-class area under the ROC curve.

    Args:
        gt: Tensor of binary ground-truth labels, one column per class
            (indexed as ``[:, c]``). It may live on any device; it is copied
            to the CPU first.
        pred: Tensor of scores with the same column layout as ``gt``
            (probabilities, confidences or hard decisions).
        n_class: Number of leading columns to score, i.e. columns
            ``0 .. n_class - 1``.

    Returns:
        A list of ``n_class`` AUROC values, in column order.

    NOTE(review): ``roc_auc_score`` is expected to raise when a column of
    ``gt`` contains only one class; callers currently do not guard against it.
    """
    labels = gt.cpu().numpy()
    scores = pred.cpu().numpy()
    return [roc_auc_score(labels[:, col], scores[:, col]) for col in range(n_class)]
|
|
|
|
def _load_checkpoint(model, checkpoint_path):
    """Load ``checkpoint["model"]`` into an *unwrapped* ``model`` (non-strict).

    A uniform ``"module."`` prefix (left by DataParallel/DDP when the
    checkpoint was saved) is stripped first, so the weights match the bare
    model regardless of how the checkpoint was produced. Because
    ``strict=False`` silently skips mismatched keys, the number of missing and
    unexpected keys is printed so a wrong checkpoint is noticed.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["model"]
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    result = model.load_state_dict(state_dict, strict=False)
    print("load checkpoint from %s" % checkpoint_path)
    print(
        "missing keys: %d, unexpected keys: %d"
        % (len(result.missing_keys), len(result.unexpected_keys))
    )


def _predict(model, dataloader, device):
    """Run ``model`` over ``dataloader``; return ``(labels, scores)`` stacked over batches.

    ``scores`` is ``1 - softmax`` over the last axis of the model output
    sliced at the "pneumonia" class index. Batches are collected in lists and
    concatenated once, which avoids quadratic re-copying.
    """
    class_idx = original_class.index("pneumonia")
    gt_batches, pred_batches = [], []
    loop = tqdm(dataloader)
    for i, sample in enumerate(loop):
        loop.set_description(f"Testing: {i+1}/{len(dataloader)}")
        images = sample["image"].to(device)
        gt_batches.append(sample["label"].to(device))
        with torch.no_grad():
            pred_class = model(images)
            # assumes model output is (batch, n_classes, k) so this slice is 2-D
            # (batch, k) -- TODO confirm. For 2-D input, dim=-1 is the same axis the
            # deprecated implicit-dim softmax picked (dim=1).
            pred_class = pred_class[:, class_idx, :]
            pred_batches.append(1 - F.softmax(pred_class, dim=-1))
    # .float() mirrors the original FloatTensor accumulator for the labels.
    return torch.cat(gt_batches, 0).float(), torch.cat(pred_batches, 0)


def _report_metrics(gt, pred):
    """Print AUROC, best F1, and accuracy at the F1-optimal threshold (column 0)."""
    AUROC = compute_AUCs(gt, pred, 1)
    print("The AUROC of {} is {}".format("pneumonia", AUROC[0]))
    gt_np = gt[:, 0].cpu().numpy()
    pred_np = pred[:, 0].cpu().numpy()
    precision, recall, thresholds = precision_recall_curve(gt_np, pred_np)
    numerator = 2 * recall * precision
    denom = recall + precision
    f1_scores = np.divide(
        numerator, denom, out=np.zeros_like(denom), where=(denom != 0)
    )
    # The final PR point has recall 0 (hence F1 0) and no threshold; argmax returns
    # the first maximum, so best_idx always indexes a real threshold.
    best_idx = int(np.argmax(f1_scores))
    max_f1 = f1_scores[best_idx]
    max_f1_thresh = thresholds[best_idx]
    # precision_recall_curve counts scores >= threshold as positive; use the same
    # rule here so accuracy is measured at exactly the operating point of max_f1.
    acc = accuracy_score(gt_np, pred_np >= max_f1_thresh)
    print("The average f1 is {F1_avg:.4f}".format(F1_avg=max_f1))
    print("The average ACC is {ACC_avg:.4f}".format(ACC_avg=acc))


def main(args, config):
    """Evaluate a MeDSLIP checkpoint on the RSNA pneumonia test set and print metrics.

    Args:
        args: Parsed CLI namespace; ``args.checkpoint`` and ``args.ddp`` are read.
        config: Mapping with ``test_file``, ``test_batch_size``,
            ``disease_book`` (path to a JSON object whose values are the disease
            prompts) and ``text_encoder``. It is also passed to ``MeDSLIP``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    print("Creating dataset")
    test_dataset = RSNA2018_Dataset(config["test_file"])
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        num_workers=8,
        pin_memory=True,
        sampler=None,
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )
    with open(config["disease_book"], "r") as f:
        json_book = json.load(f)
    disease_book = list(json_book.values())
    tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
    disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)

    print("Creating model")
    model = MeDSLIP(config, disease_book_tokenizer)
    # Load before wrapping: with DataParallel the expected keys gain a "module."
    # prefix and strict=False would otherwise silently load nothing.
    _load_checkpoint(model, args.checkpoint)
    if args.ddp:
        model = nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count()))
        )
    model = model.to(device)

    print("Start testing")
    model.eval()
    gt, pred = _predict(model, test_dataloader, device)
    _report_metrics(gt, pred)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_zero-Shot_Classification_RSNA/configs/MeDSLIP_config.yaml",
    )
    parser.add_argument("--checkpoint", default="MeDSLIP_resnet50.pth")
    # NOTE(review): --device is parsed but not read by main(), which picks the device itself.
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    parser.add_argument("--ddp", action="store_true", help="use ddp")

    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use trusted configs.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    # Presumably must be set before CUDA is first initialised in this process to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # Only touch CUDA when a GPU was requested AND one is available: main() already
    # falls back to CPU, so a CPU-only host must not crash here with the default --gpu 0.
    if args.gpu != "-1" and torch.cuda.is_available():
        torch.cuda.current_device()
        torch.cuda._initialized = True

    main(args, config)
|
|