| | from enum import Enum |
| | from PIL import Image |
| |
|
| |
|
def unravel_index(index, shape):
    """
    Convert a flat, row-major index into per-axis coordinates.

    Args:
        index: flat position; anything supporting ``%`` and ``//`` with the
            entries of ``shape``.
        shape: extent of every axis, outermost axis first.

    Returns:
        A tuple with one coordinate per axis, in the same order as ``shape``.
        An index beyond the total size is not rejected; its overflow is
        discarded by the modulo on the outermost axis.
    """
    coords = [0] * len(shape)
    remaining = index
    # Peel coordinates off starting with the innermost (fastest-varying) axis.
    for axis in range(len(shape) - 1, -1, -1):
        extent = shape[axis]
        coords[axis] = remaining % extent
        remaining = remaining // extent
    return tuple(coords)
| |
|
| |
|
class ExplicitEnum(Enum):
    """
    Enum base class whose failed value lookups list every accepted value.

    Subclasses also get ``options()`` to enumerate the accepted values.
    """

    @classmethod
    def _missing_(cls, value):
        # Enum calls this hook when ``cls(value)`` matches no member.
        allowed = list(cls._value2member_map_)
        message = f"{value} is not a valid {cls.__name__}, please select one of {allowed}"
        raise ValueError(message)

    @classmethod
    def options(cls):
        """Return the values of all members as a list, in definition order."""
        return [member_value for member_value in cls._value2member_map_]
| |
|
| |
|
class InferenceMethod(ExplicitEnum):
    """
    All the implemented inference methods.

    Every member has a ``scope`` (see the property below):

    * ``"sample"``: ``get_page_scope`` picks a single page (FIRST, SECOND, LAST);
    * ``"sample-grid"``: ``get_page_scope`` tiles all pages into one image (GRID);
    * ``"iter"``: ``get_page_scope`` returns all pages unchanged and
      ``apply_decision_strategy`` reduces the per-page logits to one class
      (MAX_CONFIDENCE, SOFT_VOTING, HARD_VOTING).
    """

    # Single-page selection.
    FIRST = "first"
    SECOND = "second"
    LAST = "last"

    # All pages combined into one grid image.
    GRID = "grid"

    # Per-page predictions merged by a decision strategy.
    MAX_CONFIDENCE = "max_confidence"
    SOFT_VOTING = "soft_voting"
    HARD_VOTING = "hard_voting"

    @property
    def scope(self):
        """Return ``"sample"``, ``"sample-grid"`` or ``"iter"`` for this method."""
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self is InferenceMethod.GRID:
            return "sample-grid"
        return "iter"

    def get_page_scope(self, pages):
        """
        Select the part of ``pages`` this method works on.

        Args:
            pages: indexable sequence of pages; for GRID the pages are handed
                to ``equal_image_grid``.

        Returns:
            ``pages`` itself for "iter" methods, one grid image for GRID
            (or the last page if building the grid fails), otherwise the
            selected single page. SECOND falls back to the first page when
            there is only one.
        """
        if self.scope == "iter":
            return pages
        if self is InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception:
                # Deliberate best effort: if the grid cannot be built, use the
                # last page instead of failing the whole inference.
                return pages[-1]
        if self is InferenceMethod.FIRST:
            return pages[0]
        if self is InferenceMethod.SECOND:
            return pages[1] if len(pages) > 1 else pages[0]
        if self is InferenceMethod.LAST:
            return pages[-1]

    def apply_decision_strategy(self, page_logits):
        """
        Reduce per-page logits to a single predicted class index.

        Args:
            page_logits: array/tensor of shape [NUM_PAGES x CLASSES] exposing
                ``argmax``, ``mean``, ``reshape``, ``tolist`` and ``shape``.

        Returns:
            The predicted class index. ``None`` for methods that are not a
            decision strategy (FIRST, SECOND, LAST, GRID).
        """
        if self is InferenceMethod.MAX_CONFIDENCE:
            # Flat argmax over all pages and classes, then split it back into
            # (page, class) coordinates.
            index = page_logits.argmax()
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self is InferenceMethod.HARD_VOTING:
            from collections import Counter

            # Majority vote over the per-page predictions. The previous
            # ``argmax(-1).max()`` returned the largest predicted class index,
            # not the most frequent one.
            votes = page_logits.argmax(-1).reshape(-1)
            labels = votes.tolist()
            counts = Counter(labels)
            # ``max`` keeps the first maximal position, so ties go to the label
            # predicted by the earliest page. Indexing ``votes`` keeps the
            # return type of the underlying array library.
            winner_pos = max(range(len(labels)), key=lambda pos: counts[labels[pos]])
            return votes[winner_pos]
        if self is InferenceMethod.SOFT_VOTING:
            # Average the logits over pages, then take the best class.
            return page_logits.mean(0).argmax(-1)
| |
|
| |
|
def equal_image_grid(images):
    """
    Tile PIL images into one near-square RGB grid image.

    Images with zero width or height are dropped. The remaining images are
    resized (bicubic, aspect ratio kept) to the narrowest width among them and
    pasted row by row, left to right, into cells as tall as the tallest
    resized image.

    Args:
        images: iterable of PIL images.

    Returns:
        A new "RGB" image of size ``(cols * cell_width, rows * cell_height)``.

    Raises:
        ValueError: if no image with non-zero width and height is left.
    """

    def compute_grid(n, max_cols=6):
        """Return ``(rows, cols)`` with ``rows * cols >= n`` and ``cols <= max_cols``."""
        side = int(n**0.5)
        cols = min(side, max_cols)
        rows = side
        # Widen by one column first, but never beyond the cap; the previous
        # code added the column unconditionally and could reach max_cols + 1.
        if rows * cols < n and cols < max_cols:
            cols += 1
        # Any remaining shortfall is absorbed by extra rows.
        while rows * cols < n:
            rows += 1
        return rows, cols

    # Filter before sizing the grid so dropped images do not leave empty cells.
    images = [im for im in images if (im.height > 0) and (im.width > 0)]
    if not images:
        raise ValueError("equal_image_grid needs at least one image with non-zero width and height")

    rows, cols = compute_grid(len(images))

    min_width = min(im.width for im in images)
    # Clamp the height to 1 px so extremely wide images never get a zero-height target.
    images = [
        im.resize((min_width, max(1, int(im.height * min_width / im.width))), resample=Image.BICUBIC)
        for im in images
    ]

    # Every image is now exactly min_width wide; cell height follows the tallest one.
    w = min_width
    h = max(img.size[1] for img in images)

    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
| |
|