Source code for concepts.benchmark.vision_language.shapes.shapes_detection_utils

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : shapes_detection_utils.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 10/18/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import cv2
import numpy as np
import torch

__all__ = ['detect_shapes', 'tensor_to_image', 'image_to_tensor']


[docs] def detect_shapes(image): output = list() patches = _split_image_into_objects(image) for x in patches: if _detect_object(x): output.append((_detect_size(x), _detect_color(x), _detect_shape(x))) else: output.append(None) return patches, output
[docs] def tensor_to_image(x: torch.Tensor) -> np.ndarray: x = x.detach().cpu().numpy() x = x.transpose(1, 2, 0) x = (x * 0.5 + 0.5) * 255 x = x.clip(0, 255).astype(np.uint8) return x
[docs] def image_to_tensor(x: np.ndarray) -> torch.Tensor: x = x.transpose(2, 0, 1) x = x / 255.0 * 2 - 1 x = torch.from_numpy(x).float() return x
def _split_image_into_objects(x): x = x.reshape((3, 10, 3, 10, 3)) x = x.transpose((0, 2, 1, 3, 4)) return [x[i, j] for i in range(3) for j in range(3)] def _detect_object(x): if (x > 5).sum(): return True return False def _detect_color(x): x = x.reshape(-1, 3) x = x[x.max(-1) > 5] x = x.mean(axis=0) c = x.argmax() return ['red', 'green', 'blue'][c] def _detect_size(x): x_shape = x.shape x = x.reshape(-1, 3) x = x.max(-1) > 5 x = x.reshape(x_shape[:2]) x_range = x.any(axis=1).nonzero() y_range = x.any(axis=0).nonzero() boundary = (x_range[0].max() - x_range[0].min() + 1, y_range[0].max() - y_range[0].min() + 1) if max(boundary) > 7: return 'big' return 'small' def _detect_shape(x): gray = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY) gray = cv2.resize(gray, (256, 256), cv2.INTER_NEAREST) _, threshold = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) contours, _ = cv2.findContours(threshold, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) contour = contours[0] approx = cv2.approxPolyDP(contour, 0.1 * cv2.arcLength(contour, True), True) if len(approx) == 3: return 'triangle' approx = cv2.approxPolyDP(contour, 0.02 * cv2.arcLength(contour, True), True) if len(approx) == 4: return 'square' else: return 'circle'