| | import sys |
| | import numpy as np |
| | import pandas as pd |
| | from metrics import ece_logits, aurc_logits, multi_aurc_plot, apply_metrics |
| | from sklearn.metrics import f1_score |
| | from collections import OrderedDict |
| |
|
| | EXPERIMENT_ROOT = "/mnt/lerna/experiments" |
| |
|
| |
|
def softmax(x, axis=-1):
    """Numerically stable softmax over `axis`.

    Subtracts the per-slice maximum before exponentiating so very large
    logits cannot overflow `np.exp`; the shift cancels in the ratio.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)
| |
|
| |
|
def predictions_loader(predictions_path):
    """Load an ``.npz`` predictions dump and split it into its components.

    The stored 2-D array (``arr_0``) packs, per sample, the logits followed by
    trailing metadata columns: ``[..., prediction, label, dataset_idx]`` in the
    general case, or ``[..., label, dataset_idx]`` for single-page
    ``DiT-base-rvl_cdip_MP`` dumps (first/second/last page), where the
    prediction must be recomputed as the argmax over the logits.

    Args:
        predictions_path: path to the ``.npz`` file described above.

    Returns:
        Tuple ``(logits, labels, predictions, dataset_idx)``; ``labels`` and
        ``predictions`` are int arrays, ``dataset_idx`` keeps the stored dtype.
    """
    data = np.load(predictions_path)["arr_0"]
    dataset_idx = data[:, -1]
    if "DiT-base-rvl_cdip_MP" in predictions_path and any(x in predictions_path for x in ["first", "second", "last"]):
        # BUG FIX: labels were previously left as float in this branch while
        # the else-branch cast to int; cast consistently.
        labels = data[:, -2].astype(int)
        data = data[:, :-2]
        predictions = np.argmax(data, -1)
    else:
        labels = data[:, -2].astype(int)
        predictions = data[:, -3].astype(int)
        data = data[:, :-3]
    return data, labels, predictions, dataset_idx
| |
|
| |
|
def compare_errors():
    """Report how often one page strategy's errors are rescued by another.

    For each ordered pair of strategies (first/second/last page of a
    document) prints the accuracy obtained when a strategy's wrong
    prediction is replaced by the other strategy's correctness indicator,
    i.e. an oracle combination of the two.

    Correlation of the correctness vectors (computed interactively):

    from scipy.stats import pearsonr, spearmanr
    #idx = [x for x in strategy_correctness['first'] if x ==0]
    spearmanr(strategy_correctness['first'], strategy_correctness['second'])
    #SignificanceResult(statistic=0.5429413617297623, pvalue=0.0)
    spearmanr(strategy_correctness['first'], strategy_correctness['last'])
    #SignificanceResult(statistic=0.5005224326802595, pvalue=0.0)

    pearsonr(strategy_correctness['first'], strategy_correctness['second'])
    #PearsonRResult(statistic=0.5429413617297617, pvalue=0.0)
    pearsonr(strategy_correctness['first'], strategy_correctness['last'])
    #PearsonRResult(statistic=0.5005224326802583, pvalue=0.0)
    """

    def _rescued_accuracy(base, backup):
        # Accuracy of `base` when each of its misses is replaced by `backup`'s outcome.
        return np.mean([x if x == 1 else backup[i] for i, x in enumerate(base)])

    for dataset in ["rvl_cdip_n_mp"]:
        strategy_correctness = {}
        for strategy in ["first", "second", "last"]:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            _, labels, predictions, _ = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)

        first = strategy_correctness["first"]
        second = strategy_correctness["second"]
        last = strategy_correctness["last"]

        print("Base accuracy of first: ", np.mean(first))
        print(f"Accuracy of first when adding knowledge from second page: {_rescued_accuracy(first, second)}")
        print(f"Accuracy of first when adding knowledge from last page: {_rescued_accuracy(first, last)}")
        # Element-wise max of two 0/1 vectors == logical OR of the backups.
        print(
            f"Accuracy of first when adding knowledge from second/last page: {_rescued_accuracy(first, np.maximum(second, last))}"
        )

        print("Base accuracy of second: ", np.mean(second))
        print(f"Accuracy of second when adding knowledge from first page: {_rescued_accuracy(second, first)}")
        print(f"Accuracy of second when adding knowledge from last page: {_rescued_accuracy(second, last)}")

        print("Base accuracy of last: ", np.mean(last))
        print(f"Accuracy of last when adding knowledge from first page: {_rescued_accuracy(last, first)}")
        print(f"Accuracy of last when adding knowledge from second page: {_rescued_accuracy(last, second)}")
|
| |
|
def review_one(path):
    """Evaluate one predictions file and print/return its metrics.

    Args:
        path: path to an ``.npz`` predictions dump readable by
            :func:`predictions_loader`.

    Returns:
        ``(collect, res)`` where ``collect`` is an OrderedDict of scalar
        metrics (aurc, accuracy, f1, f1_macro, ece) and ``res`` is the raw
        result dict from ``aurc_logits`` (includes the risk-coverage cache).
        Returns ``None`` if loading fails (best-effort: the error is printed).
    """
    collect = OrderedDict()
    try:
        logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"something went wrong in inference loading {e}")
        return

    y_correct = (predictions == labels).astype(int)
    acc = np.mean(y_correct)
    # Confidence of the predicted class; vectorized gather instead of the
    # previous per-row Python loop over softmax(p)[predictions[i]].
    probs = softmax(logits, -1)
    p_hat = probs[np.arange(len(predictions)), predictions.astype(int)]

    res = aurc_logits(
        y_correct, p_hat, plot=False, get_cache=True, use_as_is=True
    )

    collect["aurc"] = res["aurc"]
    collect["accuracy"] = 100 * acc
    collect["f1"] = 100 * f1_score(labels, predictions, average="weighted")
    collect["f1_macro"] = 100 * f1_score(labels, predictions, average="macro")
    # NOTE(review): ece is fed the *incorrectness* mask (logical_not) —
    # presumably intentional for the metrics implementation; confirm there.
    collect["ece"] = ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)

    df = pd.DataFrame.from_dict([collect])

    print(df.to_latex())
    print(df.to_string())
    return collect, res
| |
|
| |
|
def experiments_review():
    """Run review_one for every inference strategy on each dataset and print summary tables."""
    STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
    for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
        metrics_by_strategy = {}
        aurc_values = []
        rc_caches = []
        for strategy in STRATEGIES:
            npz_path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            metrics_by_strategy[strategy], result = review_one(npz_path)
            aurc_values.append(result["aurc"])
            rc_caches.append(result["cache"])

        summary = pd.DataFrame.from_dict(metrics_by_strategy, orient="index")[
            ["accuracy", "f1", "f1_macro", "ece", "aurc"]
        ]
        print(summary.to_latex())
        print(summary.to_string())
        # Optional risk-coverage comparison plot, kept for interactive use:
        # subset = [0, 1, 2]
        # multi_aurc_plot(
        #     [x for i, x in enumerate(rc_caches) if i in subset],
        #     [x for i, x in enumerate(STRATEGIES) if i in subset],
        #     aurcs=[x for i, x in enumerate(aurc_values) if i in subset],
        # )
| |
|
| |
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser("""Deeper evaluation of different inference strategies to classify a document""")
    DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
    parser.add_argument(
        "predictions_path",
        type=str,
        default=DEFAULT,
        nargs="?",
        help="path to predictions",
    )

    args = parser.parse_args()
    if args.predictions_path == DEFAULT:
        # No explicit path supplied: run the full battery of stored experiments.
        experiments_review()
        compare_errors()
        # BUG FIX: was sys.exit(1), which signals failure after a successful run.
        sys.exit(0)

    # BUG FIX: message previously said "Running default experiment" on the
    # non-default (user-supplied path) branch.
    print(f"Running evaluation on {args.predictions_path}")
    review_one(args.predictions_path)
| |
|