| | from typing import List, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from diffusers.modular_pipelines import ( |
| | ComponentSpec, |
| | InputParam, |
| | ModularPipelineBlocks, |
| | OutputParam, |
| | PipelineState, |
| | ) |
| | from PIL import Image, ImageDraw |
| | from transformers import AutoProcessor, Florence2ForConditionalGeneration |
| |
|
| |
|
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Annotate image(s) with a Florence-2 vision model.

    Given one or more images, a Florence-2 task token (e.g.
    ``<REFERRING_EXPRESSION_SEGMENTATION>``) and an optional free-text
    prompt, this block runs the model and converts its predictions into one
    of several outputs: the raw annotation dict, a black-and-white mask, a
    mask overlaid on the original image, or bounding boxes drawn on the
    original image.
    """

    # Region-style tasks: the model takes no free-text prompt and the
    # natural visualization is bounding boxes.
    _REGION_TASKS = frozenset(
        {"<OD>", "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR_WITH_REGION>"}
    )
    # Text-only tasks: there is nothing to draw, so image output is skipped.
    _TEXT_TASKS = frozenset(
        {"<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>", "<OCR>"}
    )

    @property
    def expected_components(self):
        """Model and processor specs, both loaded from the Florence-2 community repo."""
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Pipeline inputs declared by this block."""
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
                metadata={"mellon": "image"},
            ),
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                metadata={"mellon": "dropdown"},
                description="""Annotation Task to perform on the image.
                Supported Tasks:

                <OD>
                <REFERRING_EXPRESSION_SEGMENTATION>
                <CAPTION>
                <DETAILED_CAPTION>
                <MORE_DETAILED_CAPTION>
                <DENSE_REGION_CAPTION>
                <REGION_PROPOSAL>
                <CAPTION_TO_PHRASE_GROUNDING>
                <OPEN_VOCABULARY_DETECTION>
                <OCR>
                <OCR_WITH_REGION>

                """,
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                metadata={"mellon": "textbox"},
                description="""Annotation Prompt to provide more context to the task.
                Can be used to detect or segment out specific elements in the image
                """,
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                default="mask_image",
                metadata={"mellon": "dropdown"},
                description="""Output type from annotation predictions. Available options are
                annotation:
                    - raw annotation predictions from the model based on task type.
                mask_image:
                    - black and white mask image for the given image based on the task type
                mask_overlay:
                    - white mask overlayed on the original image
                bounding_box:
                    - bounding boxes drawn on the original image
                """,
            ),
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                required=True,
                default=False,
                description="",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description="",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Outputs written back to pipeline state by this block."""
        return [
            OutputParam(
                "annotations",
                type_hint=dict,
                description="Annotations Predictions for input Image(s)",
            ),
            OutputParam(
                "images",
                type_hint=Image,
                description="Annotated input Image(s)",
                metadata={"mellon": "image"},
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        """Run the Florence-2 model on a batch of images.

        Args:
            components: Pipeline components holding ``image_annotator`` and
                ``image_annotator_processor``.
            images: List of PIL images.
            prompts: List of text prompts, one per image.
            task: Florence-2 task token prepended to every prompt.

        Returns:
            List of post-processed annotation dicts, one per image.
        """
        task_prompts = [task + prompt for prompt in prompts]

        inputs = components.image_annotator_processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(components.image_annotator.device, components.image_annotator.dtype)

        # Beam search with a generous token budget; Florence-2 emits its
        # structured annotations as special tokens, so they must be kept
        # for post-processing (skip_special_tokens=False below).
        generated_ids = components.image_annotator.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        annotations = components.image_annotator_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        outputs = []
        for image, annotation in zip(images, annotations):
            outputs.append(
                components.image_annotator_processor.post_process_generation(
                    annotation, task=task, image_size=(image.width, image.height)
                )
            )

        return outputs

    def _iter_polygon_point_sets(self, poly):
        """
        Yields lists of (x, y) points for all simple polygons found in `poly`.
        Supports formats:
          - [x1, y1, x2, y2, ...]
          - [[x, y], [x, y], ...]
          - [xs, ys]
          - dict {'x': xs, 'y': ys}
          - nested lists containing any of the above
        """
        if poly is None:
            return

        def is_num(v):
            return isinstance(v, (int, float, np.number))

        # dict form: {'x': [...], 'y': [...]}
        if isinstance(poly, dict) and "x" in poly and "y" in poly:
            xs, ys = poly["x"], poly["y"]
            if (
                isinstance(xs, (list, tuple))
                and isinstance(ys, (list, tuple))
                and len(xs) == len(ys)
            ):
                pts = list(zip(xs, ys))
                if len(pts) >= 3:  # a polygon needs at least 3 vertices
                    yield pts
            return

        if isinstance(poly, (list, tuple)):
            # flat form: [x1, y1, x2, y2, ...]
            if all(is_num(v) for v in poly):
                coords = list(poly)
                if len(coords) >= 6 and len(coords) % 2 == 0:
                    yield list(zip(coords[0::2], coords[1::2]))
                return

            # pair form: [[x, y], [x, y], ...]
            if all(
                isinstance(v, (list, tuple))
                and len(v) == 2
                and all(is_num(n) for n in v)
                for v in poly
            ):
                if len(poly) >= 3:
                    yield [tuple(v) for v in poly]
                return

            # parallel-list form: [xs, ys]
            if len(poly) == 2 and all(isinstance(v, (list, tuple)) for v in poly):
                xs, ys = poly
                try:
                    if len(xs) == len(ys) and len(xs) >= 3:
                        yield list(zip(xs, ys))
                        return
                except TypeError:
                    pass

            # nested form: recurse into each part
            for part in poly:
                yield from self._iter_polygon_point_sets(part)

    @staticmethod
    def _clamp_points(image, pts):
        """Clamp (x, y) pairs to the image bounds and flatten to
        [x0, y0, x1, y1, ...] ints suitable for PIL drawing calls."""
        flat = []
        for x, y in pts:
            xi = int(round(max(0, min(image.width - 1, x))))
            yi = int(round(max(0, min(image.height - 1, y))))
            flat.extend([xi, yi])
        return flat

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        """Render annotation polygons/boxes as masks.

        Args:
            images: List of PIL images.
            annotations: Per-image annotation dicts from ``get_annotations``.
            overlay: When True, draw the fill on a copy of the original
                image instead of a fresh black mask.
            fill: Fill color used in overlay mode; ignored for plain masks
                (a plain "L"-mode mask is always filled with 255).

        Returns:
            List of PIL images (masks or overlays), one per input image.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            mask_fill = fill
            if not overlay and isinstance(fill, str):
                # "L" mode masks take an integer value, not a color name.
                mask_fill = 255

            for _annotation in annotation.values():
                if "polygons" in _annotation:
                    for poly in _annotation["polygons"]:
                        for pts in self._iter_polygon_point_sets(poly):
                            if len(pts) < 3:
                                continue
                            draw.polygon(self._clamp_points(image, pts), fill=mask_fill)

                elif "bboxes" in _annotation:
                    for bbox in _annotation["bboxes"]:
                        flat = np.array(bbox).flatten().tolist()
                        if len(flat) == 4:
                            x0, y0, x1, y1 = flat
                            draw.rectangle(
                                (
                                    int(round(x0)),
                                    int(round(y0)),
                                    int(round(x1)),
                                    int(round(y1)),
                                ),
                                fill=mask_fill,
                            )

                elif "quad_boxes" in _annotation:
                    for quad in _annotation["quad_boxes"]:
                        for pts in self._iter_polygon_point_sets(quad):
                            if len(pts) < 3:
                                continue
                            draw.polygon(self._clamp_points(image, pts), fill=mask_fill)

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
        """Draw annotation bboxes/quad boxes (with labels) on image copies.

        Returns:
            List of PIL images with red boxes and labels, one per input.
        """
        outputs = []
        for image, annotation in zip(images, annotations):
            image_copy = image.copy()
            draw = ImageDraw.Draw(image_copy)
            for _annotation in annotation.values():
                bboxes = _annotation.get("bboxes", [])
                labels = _annotation.get("labels", [])

                # Some tasks report labels under "bboxes_labels" instead.
                if len(labels) == 0:
                    labels = _annotation.get("bboxes_labels", [])

                for i, bbox in enumerate(bboxes):
                    flat = np.array(bbox).flatten().tolist()

                    if len(flat) != 4:
                        continue

                    x0, y0, x1, y1 = flat
                    draw.rectangle(
                        (
                            int(round(x0)),
                            int(round(y0)),
                            int(round(x1)),
                            int(round(y1)),
                        ),
                        outline="red",
                        width=3,
                    )
                    label = labels[i] if i < len(labels) else ""
                    if label:
                        # Place the label just above the box, clamped to the image.
                        text_y = max(0, int(y0) - 20)
                        draw.text((int(x0), text_y), label, fill="red")

                # Quad boxes (e.g. <OCR_WITH_REGION>) are drawn as outlines
                # with the label at the polygon centroid.
                quad_boxes = _annotation.get("quad_boxes", [])
                qlabels = _annotation.get("labels", [])
                for i, quad in enumerate(quad_boxes):
                    for pts in self._iter_polygon_point_sets(quad):
                        if len(pts) < 3:
                            continue
                        flat = self._clamp_points(image, pts)
                        xs, ys = flat[0::2], flat[1::2]

                        try:
                            draw.polygon(flat, outline="red", width=3)
                        except TypeError:
                            # Older Pillow: polygon() has no width parameter.
                            draw.polygon(flat, outline="red")

                        label = qlabels[i] if i < len(qlabels) else ""
                        if label:
                            cx = int(round(sum(xs) / len(xs)))
                            cy = int(round(sum(ys) / len(ys)))
                            cx = max(0, min(image.width - 1, cx))
                            cy = max(0, min(image.height - 1, cy))
                            draw.text((cx, cy), label, fill="red")

            outputs.append(image_copy)

        return outputs

    def prepare_inputs(self, images, prompts):
        """Normalize images and prompts to equal-length lists.

        A single image or a single prompt string is wrapped into a list; a
        single prompt is broadcast across all images. A prompt list whose
        length differs from the image list raises ``ValueError``.
        """
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(prompts, str):
            prompts = [prompts]

        # Broadcast one prompt over many images so callers don't have to
        # repeat the same prompt manually.
        if len(prompts) == 1 and len(images) > 1:
            prompts = prompts * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        skip_image = False

        task = block_state.annotation_task
        if task in self._REGION_TASKS:
            # Region tasks take no text prompt and are rendered as boxes.
            block_state.annotation_prompt = ""
            block_state.annotation_output_type = "bounding_box"
        elif task in self._TEXT_TASKS:
            # Text-only tasks produce no drawable regions.
            block_state.annotation_prompt = ""
            skip_image = True

        images, annotation_task_prompt = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        fill = block_state.fill

        annotations = self.get_annotations(
            components, images, annotation_task_prompt, task
        )

        block_state.annotations = annotations
        block_state.images = None

        if not skip_image:
            output_type = block_state.annotation_output_type
            if output_type == "mask_image":
                block_state.images = self.prepare_mask(images, annotations)
            elif output_type == "mask_overlay":
                block_state.images = self.prepare_mask(
                    images, annotations, overlay=True, fill=fill
                )
            elif output_type == "bounding_box":
                block_state.images = self.prepare_bounding_boxes(images, annotations)

        self.set_block_state(state, block_state)

        return components, state
| |
|