Skip to content
Snippets Groups Projects
utils.py 1.57 KiB
# -*- coding: utf-8 -*-
import ast
import logging
import time
from pathlib import Path
from urllib.parse import urljoin

import cv2
import imageio.v2 as iio
from arkindex_export.models import Element
from worker_generic_training_dataset.exceptions import ImageDownloadError

# Upper bound on the number of attempts `download_image` makes when fetching
# an image times out (the counter starts at 1, so this is the total attempts).
MAX_RETRIES: int = 5
logger: logging.Logger = logging.getLogger(__name__)


def bounding_box(polygon: list):
    """
    Compute the axis-aligned bounding box enclosing a polygon.

    :param polygon: Sequence of ``(x, y)`` points.
    :return: ``(x, y, width, height)`` where ``(x, y)`` is the minimum corner;
        every value is converted with ``int()`` (i.e. truncated).
    :raises ValueError: if the polygon is empty or its points are not pairs.
    """
    xs, ys = zip(*polygon)
    left, top = min(xs), min(ys)
    right, bottom = max(xs), max(ys)
    return int(left), int(top), int(right - left), int(bottom - top)


def build_image_url(element: Element) -> str:
    """
    Build the URL of the crop of `element`'s image covering its polygon's bounding box.

    The polygon is parsed with `ast.literal_eval` (presumably it is stored as a
    string literal such as "[[x1, y1], [x2, y2], ...]" — confirm against the
    export schema). The crop is expressed as the path
    `{x},{y},{width},{height}/full/0/default.jpg` (IIIF Image API
    region/size/rotation/quality.format) resolved against `element.image.url`,
    which is assumed to be the image service base URL.

    :param element: Element whose `polygon` and `image.url` are read.
    :return: Absolute URL of the cropped JPEG.
    """
    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
    # Ensure exactly one trailing slash: without one, urljoin would replace the
    # last path segment; blindly appending "/" to a URL that already ends with
    # one would produce "//" in the resulting crop URL.
    base_url = element.image.url.rstrip("/") + "/"
    return urljoin(base_url, f"{x},{y},{width},{height}/full/0/default.jpg")


def download_image(element: Element, folder: Path):
    """
    Download the image to `folder / {element.id}.jpg`

    The image is fetched from `build_image_url(element)`, i.e. cropped to the
    bounding box of the element's polygon. A `TimeoutError` is retried, waiting
    1 second between attempts, for at most `MAX_RETRIES` attempts in total; any
    other error aborts immediately.

    :param element: Element to download.
    :param folder: Destination directory (must already exist).
    :raises ImageDownloadError: if every attempt timed out, if fetching,
        decoding or colour conversion failed, or if the file could not be written.
    """
    destination = folder / f"{element.id}.jpg"

    last_timeout = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            image = iio.imread(build_image_url(element))
        except TimeoutError as e:
            last_timeout = e
            # Only wait when another attempt will actually follow.
            if attempt < MAX_RETRIES:
                logger.warning("Timeout, retry in 1 second.")
                time.sleep(1)
            continue
        except Exception as e:
            raise ImageDownloadError(element.id, e) from e
        break
    else:
        # Every attempt timed out (or MAX_RETRIES <= 0): keep the last timeout as cause.
        raise ImageDownloadError(
            element.id, Exception("Maximum retries reached.")
        ) from last_timeout

    try:
        if image.ndim == 3:
            # imageio decodes to RGB(A) while OpenCV writes BGR: swap channels.
            # Single-channel (grayscale) images need no conversion.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        written = cv2.imwrite(str(destination), image)
    except Exception as e:
        raise ImageDownloadError(element.id, e) from e

    # cv2.imwrite reports failure (e.g. missing folder) by returning False
    # rather than raising, so surface it explicitly.
    if not written:
        raise ImageDownloadError(
            element.id, Exception(f"Could not write image to {destination}.")
        )