| | import ast |
| | import os |
| | import json |
| | from matplotlib.patches import Polygon |
| | from matplotlib.collections import PatchCollection |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import cv2 |
| | import inflect |
| |
|
# Shared inflect engine (used for plural_noun / number_to_words in convert_spec).
p = inflect.engine()

# Directory that visualizations are written to.
img_dir = "imgs"
# Marker separating the object list from the background prompt in a response.
bg_prompt_text = "Background prompt: "

# Canvas size; elsewhere size[0] is treated as the height and size[1] as the width.
size = box_scale = (512, 512)
size_h, size_w = size
print(f"Using box scale: {box_scale}")
| |
|
def parse_input(text=None, no_input=False):
    """Parse an LLM layout response into boxes and a background prompt.

    The response is expected to look like
    ``"Objects: [...]\\nBackground prompt: ..."``. Everything up to and
    including the first ``"Objects: "`` marker is discarded, and the rest is
    split on the module-level ``bg_prompt_text`` marker.

    Args:
        text: Response text. When falsy, the user is prompted for it on stdin
            (unless ``no_input`` is set).
        no_input: When True, never prompt; return ``None`` instead whenever the
            text or the background prompt is missing.

    Returns:
        ``(gen_boxes, bg_prompt)``, where ``gen_boxes`` is the Python literal
        parsed from the object section (``[]`` when the reply mentions
        "No objects") and ``bg_prompt`` is the stripped background prompt.
        Returns ``None`` when ``no_input`` is set and input is missing.

    Raises:
        ValueError: If ``bg_prompt_text`` occurs more than once.
        SyntaxError, ValueError: If the object section is not a valid Python
            literal and does not mention "No objects".
    """
    if not text:
        if no_input:
            return None
        text = input("Enter the response: ")
    if "Objects: " in text:
        # maxsplit=1 keeps any later "Objects: " (e.g. inside a name) intact.
        text = text.split("Objects: ", 1)[1]

    text_split = text.split(bg_prompt_text)
    if len(text_split) == 2:
        gen_boxes, bg_prompt = text_split
    elif len(text_split) == 1:
        if no_input:
            return None
        gen_boxes = text
        bg_prompt = ""
        # Keep asking until a non-empty background prompt is supplied.
        while not bg_prompt:
            bg_prompt = input("Enter the background prompt: ").strip()
            if bg_prompt_text in bg_prompt:
                bg_prompt = bg_prompt.split(bg_prompt_text)[1]
    else:
        raise ValueError(f"text: {text}")

    try:
        # Strip first: before Python 3.10 leading whitespace raises IndentationError.
        gen_boxes = ast.literal_eval(gen_boxes.strip())
    except (SyntaxError, ValueError):
        # SyntaxError: unparsable text; ValueError: parsable but not a literal.
        # Either may simply be a "No objects" reply.
        if "No objects" in gen_boxes:
            gen_boxes = []
        else:
            raise
    bg_prompt = bg_prompt.strip()

    return gen_boxes, bg_prompt
| |
|
def _unpack_box(gen_box):
    """Return ``(name, (x, y, w, h))`` for a dict-format or ``(name, bbox)`` box."""
    if isinstance(gen_box, dict):
        name, bbox = gen_box['name'], gen_box['bounding_box']
    else:
        name, bbox = gen_box
    bbox_x, bbox_y, bbox_w, bbox_h = bbox
    return name, (bbox_x, bbox_y, bbox_w, bbox_h)


def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3):
    """Drop unusable boxes and optionally rescale the rest onto the canvas.

    Boxes may be ``{'name': ..., 'bounding_box': [x, y, w, h]}`` dicts or
    ``(name, [x, y, w, h])`` pairs. If any input box is a dict, every output
    box uses the dict format; otherwise outputs are ``(name, bbox)`` tuples.

    Args:
        gen_boxes: Boxes in pixel ``(x, y, w, h)`` coordinates.
        scale_boxes: If True, shift the layout so its left-most edge is at
            x = 0 and scale it so its total width spans ``size_w`` (scale
            capped at ``max_scale``). Then shift it vertically so the top-most
            box starts at y >= 0 and the bottom-most ends within ``size_h``
            (the bottom constraint wins). Boxes pushed above y = 0 are clipped,
            and dropped if nothing of them remains.
        ignore_background: Drop boxes that cover the whole canvas, or whose
            top-left corner lies beyond its right or bottom edge.
        max_scale: Upper bound on the scale factor.

    Returns:
        A list of boxes with trailing periods stripped from names and integer
        ``(x, y, w, h)`` tuples, or ``[]`` if nothing survives filtering.
    """
    if len(gen_boxes) == 0:
        return []

    # Output format follows the input: any dict box switches to dict output.
    box_dict_format = any(isinstance(gen_box, dict) for gen_box in gen_boxes)

    kept = []
    for gen_box in gen_boxes:
        name, (bbox_x, bbox_y, bbox_w, bbox_h) = _unpack_box(gen_box)
        if bbox_w <= 0 or bbox_h <= 0:
            # Empty box.
            continue
        if ignore_background and (
            (bbox_w >= size[1] and bbox_h >= size[0])
            or bbox_x > size[1]
            or bbox_y > size[0]
        ):
            # Full-canvas (background) box, or one starting outside the canvas.
            continue
        kept.append((name, (bbox_x, bbox_y, bbox_w, bbox_h)))

    if not kept:
        return []

    bbox_left_x_min = min(bbox[0] for _, bbox in kept)
    bbox_right_x_max = max(bbox[0] + bbox[2] for _, bbox in kept)
    bbox_top_y_min = min(bbox[1] for _, bbox in kept)
    bbox_bottom_y_max = max(bbox[1] + bbox[3] for _, bbox in kept)

    # Defensive: kept boxes have positive width, so this should not normally trigger.
    if bbox_right_x_max - bbox_left_x_min == 0:
        return []

    # Horizontal transform (applied only when scale_boxes is True).
    shift = -bbox_left_x_min
    scale = min(size_w / (bbox_right_x_max - bbox_left_x_min), max_scale)

    # The vertical offset is the same for every box, so compute it once.
    bbox_y_offset = 0
    if bbox_top_y_min * scale < 0:
        # Move the top-most box down to y = 0.
        bbox_y_offset = -bbox_top_y_min * scale
    if bbox_bottom_y_max * scale + bbox_y_offset >= size_h:
        # Move the bottom-most box up so it ends exactly at size_h,
        # accounting for any shift already applied above.
        bbox_y_offset = size_h - bbox_bottom_y_max * scale

    filtered_gen_boxes = []
    for name, (bbox_x, bbox_y, bbox_w, bbox_h) in kept:
        if scale_boxes:
            bbox_x = (bbox_x + shift) * scale
            bbox_y = bbox_y * scale + bbox_y_offset
            bbox_w, bbox_h = bbox_w * scale, bbox_h * scale
            if bbox_y < 0:
                # Clip at the top edge: the visible height shrinks by the overflow.
                bbox_h += bbox_y
                bbox_y = 0
                if bbox_h <= 0:
                    # Entirely above the canvas; drop it like other empty boxes.
                    continue

        name = name.rstrip(".")
        bounding_box = tuple(int(np.round(v)) for v in (bbox_x, bbox_y, bbox_w, bbox_h))
        if box_dict_format:
            filtered_gen_boxes.append({'name': name, 'bounding_box': bounding_box})
        else:
            filtered_gen_boxes.append((name, bounding_box))

    return filtered_gen_boxes
| |
|
def draw_boxes(anns):
    """Outline and label each annotation's box on the current matplotlib axes.

    Args:
        anns: Iterable of dicts with a ``'bbox'`` entry ``[x, y, w, h]`` and
            either a ``'name'`` or a ``'category_id'`` used as the label text.
    """
    axes = plt.gca()
    axes.set_autoscale_on(False)
    patches, edge_colors = [], []
    for ann in anns:
        # One random light color per annotation (components in [0.4, 1.0)).
        edge_color = np.random.random((1, 3)) * 0.6 + 0.4
        x, y, w, h = ann['bbox']
        corners = np.array(
            [[x, y], [x, y + h], [x + w, y + h], [x + w, y]]
        ).reshape((4, 2))
        patches.append(Polygon(corners))
        edge_colors.append(edge_color)

        label = ann['name'] if 'name' in ann else str(ann['category_id'])
        axes.text(x, y, label, style='italic',
                  bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})

    collection = PatchCollection(patches, facecolor='none',
                                 edgecolors=edge_colors, linewidths=2)
    axes.add_collection(collection)
| |
|
| |
|
def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
    """Render boxes on a blank canvas, then display the figure or save it.

    Args:
        gen_boxes: Boxes as ``{'name', 'bounding_box'}`` dicts or
            ``(name, bbox)`` pairs; the format is taken from the first box.
        bg_prompt: Optional text drawn at the top-left of the canvas.
        ind: If not None, the figure is additionally saved as
            ``{img_dir}/boxes_{ind}.png``.
        show: If True, call ``plt.show()`` instead of saving.
    """
    if len(gen_boxes) == 0:
        return

    if isinstance(gen_boxes[0], dict):
        anns = [{'name': gen_box['name'], 'bbox': gen_box['bounding_box']}
                for gen_box in gen_boxes]
    else:
        anns = [{'name': gen_box[0], 'bbox': gen_box[1]} for gen_box in gen_boxes]

    # White background 4 px larger than the canvas (presumably so the frame
    # line drawn on the canvas edge is not cut off).
    background = np.ones((size[0] + 4, size[1] + 4, 3), dtype=np.uint8) * 255

    plt.imshow(background)
    plt.axis('off')

    # Fetch the axes unconditionally: the frame below needs them even when no
    # background prompt is given (previously an UnboundLocalError).
    ax = plt.gca()
    if bg_prompt is not None:
        ax.text(0, 0, bg_prompt, style='italic',
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})

    # Black frame around the size[1]-wide, size[0]-tall canvas.
    frame_w, frame_h = size[1], size[0]
    frame = np.array(
        [[0, 0], [0, frame_h], [frame_w, frame_h], [frame_w, 0]]
    ).reshape((4, 2))
    ax.add_collection(PatchCollection([Polygon(frame)], facecolor='none',
                                      edgecolors=[np.zeros((1, 3))], linewidths=2))

    draw_boxes(anns)
    if show:
        plt.show()
    else:
        print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
        # savefig does not create missing directories.
        os.makedirs(img_dir, exist_ok=True)
        if ind is not None:
            plt.savefig(f"{img_dir}/boxes_{ind}.png")
        plt.savefig(f"{img_dir}/boxes.png")
| |
|
| |
|
def show_masks(masks):
    """Overlay masks in random colors, save to ``{img_dir}/masks.png``, show, then clear.

    Args:
        masks: Iterable of 2-D arrays of shape ``size`` (presumably binary or
            boolean masks; TODO confirm against callers).
    """
    masks_to_show = np.zeros((*size, 3), dtype=np.float32)
    for mask in masks:
        # Random light RGB color per mask (components in [0.4, 1.0)).
        c = np.random.random((3,)) * 0.6 + 0.4
        masks_to_show += mask[..., None] * c[None, None, :]
    # Overlapping masks can sum above 1. imshow expects float RGB in [0, 1]
    # and would otherwise clip with a warning, so clip explicitly.
    np.clip(masks_to_show, 0, 1, out=masks_to_show)
    plt.imshow(masks_to_show)
    # savefig does not create missing directories.
    os.makedirs(img_dir, exist_ok=True)
    plt.savefig(f"{img_dir}/masks.png")
    plt.show()
    plt.clf()
| |
|
def convert_box(box, height, width):
    """Convert a pixel ``(x, y, w, h)`` box to normalized ``(x_min, y_min, x_max, y_max)``.

    x values and widths are divided by ``width``; y values and heights by ``height``.
    """
    x_min = box[0] / width
    y_min = box[1] / height
    x_max = x_min + box[2] / width
    y_max = y_min + box[3] / height
    return x_min, y_min, x_max, y_max
| |
|
def _strip_article(name):
    """Remove one leading "an " or "a " from ``name``; the rest is left untouched."""
    for article in ("an ", "a "):
        if name.startswith(article):
            return name[len(article):]
    return name


def convert_spec(spec, height, width, include_counts=True, verbose=False):
    """Build per-object and overall prompts with normalized boxes from a layout spec.

    Args:
        spec: Dict with ``'gen_boxes'`` (list of ``(name, (x, y, w, h))`` pixel
            boxes) and ``'bg_prompt'`` entries.
        height: Canvas height used to normalize y values.
        width: Canvas width used to normalize x values.
        include_counts: If True, prefix pluralized phrases with the spelled-out
            count (e.g. "two cats").
        verbose: If True, print the intermediate results.

    Returns:
        A tuple ``(so_prompt_phrase_word_box_list, overall_prompt,
        overall_phrases_words_bboxes)``:

        * one ``(prompt, name, last word of name, box)`` per box, where
          ``prompt`` is ``"{bg_prompt} with {name}"`` (just ``name`` without a
          background prompt) and ``box`` is the ``convert_box`` result;
        * the background prompt joined with the comma-separated phrases;
        * one ``(phrase, last word of phrase, boxes)`` per distinct name, where
          the phrase is pluralized when the name occurs more than once.
    """
    gen_boxes, bg_prompt = spec['gen_boxes'], spec['bg_prompt']

    # Sort by name so ordering is deterministic and same-name boxes are adjacent.
    gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])
    gen_boxes = [(name, convert_box(box, height=height, width=width)) for name, box in gen_boxes]

    if bg_prompt:
        so_prompt_phrase_word_box_list = [(f"{bg_prompt} with {name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]
    else:
        so_prompt_phrase_word_box_list = [(f"{name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]

    objects = [gen_box[0] for gen_box in gen_boxes]
    objects_unique, objects_count = np.unique(objects, return_counts=True)

    num_total_matched_boxes = 0
    overall_phrases_words_bboxes = []
    # tolist() yields plain Python ints; type-checked inflect versions may
    # reject numpy integers in number_to_words.
    for object_name, count in zip(objects_unique, objects_count.tolist()):
        bboxes = [box for name, box in gen_boxes if name == object_name]

        if count > 1:
            # Strip only a leading article; substring replacement would mangle
            # words such as "man" or "banana".
            phrase = p.plural_noun(_strip_article(object_name))
            if include_counts:
                phrase = p.number_to_words(count) + " " + phrase
        else:
            phrase = object_name

        word = phrase.split(' ')[-1]

        num_total_matched_boxes += len(bboxes)
        overall_phrases_words_bboxes.append((phrase, word, bboxes))

    assert num_total_matched_boxes == len(gen_boxes), f"{num_total_matched_boxes} != {len(gen_boxes)}"

    objects_str = ", ".join([phrase for phrase, _, _ in overall_phrases_words_bboxes])
    if objects_str:
        if bg_prompt:
            overall_prompt = f"{bg_prompt} with {objects_str}"
        else:
            overall_prompt = objects_str
    else:
        overall_prompt = bg_prompt

    if verbose:
        print("so_prompt_phrase_word_box_list:", so_prompt_phrase_word_box_list)
        print("overall_prompt:", overall_prompt)
        print("overall_phrases_words_bboxes:", overall_phrases_words_bboxes)

    return so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes
| |
|