| | import os |
| | import numpy as np |
| | from tqdm import tqdm |
| | import torch |
| | from datasets import load_dataset, ClassLabel |
| | from datasets import Features, Array3D |
| | from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
| | from metrics import apply_metrics |
| |
|
| |
|
def process_label_ids(batch, remapper, label_column="label"):
    """Rewrite a batch's label ids in place through *remapper* (old id -> new id).

    Used with ``datasets.Dataset.map(batched=True)`` to align dataset label ids
    with the model's label ordering. Returns the (mutated) batch.
    """
    remapped = []
    for old_id in batch[label_column]:
        remapped.append(remapper[old_id])
    batch[label_column] = remapped
    return batch
| |
|
| |
|
| | CACHE_DIR = "/mnt/lerna/data/HFcache" if os.path.exists("/mnt/lerna/data/HFcache") else None |
| |
|
| |
|
def main(args):
    """Evaluate microsoft/dit-base-finetuned-rvlcdip on a dataset's test split.

    Loads the test split of ``args.dataset``, aligns its label ids with the
    model's label ordering when they differ, extracts pixel values, runs batched
    inference, and prints the metrics computed by ``apply_metrics``.
    """
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        # Drop a known-bad sample (index 33669) from the rvl_cdip test split.
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])
    # Smaller map batches for the (smaller) RVL-CDIP-N variant.
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")

    # Model labels use underscores; normalize spaces so they compare against dataset names.
    label2idx = {label.replace(" ", "_"): i for label, i in model.config.label2id.items()}
    # FIX: was dict(zip(enumerate(...))) which raises ValueError — enumerate
    # already yields (index, name) pairs, so dict(enumerate(...)) is the mapping.
    data_idx2label = dict(enumerate(dataset.features["label"].names))
    data_label2idx = {label: i for i, label in enumerate(dataset.features["label"].names)}
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    # Indices where the dataset's label ordering disagrees with the model's.
    diff = [i for i in range(len(data_label2idx)) if data_idx2label[i] != model_idx2label[i]]

    if diff:
        print(f"aligning labels {diff}")
        print(f"model labels: {model_idx2label}")
        print(f"data labels: {data_idx2label}")
        print(f"Remapping to {label2idx}")

        # old dataset label id -> model label id
        remapper = {}
        for k, v in label2idx.items():
            if k in data_label2idx:
                remapper[data_label2idx[k]] = v

        print(remapper)
        new_features = Features(
            {
                **{k: v for k, v in dataset.features.items() if k != "label"},
                "label": ClassLabel(num_classes=len(label2idx), names=list(label2idx.keys())),
            }
        )

        dataset = dataset.map(
            lambda example: process_label_ids(example, remapper),
            features=new_features,
            batched=True,
            batch_size=batch_size,
            desc="Aligning the labels",
        )

    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})

    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    BATCH_SIZE = 16
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=BATCH_SIZE)

    # Pre-allocated result buffers: one logit row and one reference label per sample.
    all_logits = np.zeros((len(encoded_dataset), len(label2idx)))
    all_references = np.zeros(len(encoded_dataset), dtype=int)

    count = 0
    for batch in tqdm(dataloader):
        with torch.no_grad():
            outputs = model(batch["pixel_values"])
        logits = outputs.logits
        # Slice by the actual batch length so the final (possibly partial) batch
        # fills exactly its own rows.
        n = len(batch["label"])
        all_logits[count : count + n] = logits.detach().cpu().numpy()
        all_references[count : count + n] = batch["label"].detach().cpu().numpy()
        count += n

    results = apply_metrics(all_references, all_logits)
    print(results)
| |
|
| |
|
if __name__ == "__main__":
    # CLI entry point: pick the dataset to evaluate and hand it to main().
    from argparse import ArgumentParser

    arg_parser = ArgumentParser("""DiT inference on dataset test set""")
    arg_parser.add_argument(
        "-d",
        dest="dataset",
        type=str,
        default="rvl_cdip",
        help="the dataset to be evaluated",
    )
    main(arg_parser.parse_args())
| |
|