Skip to content
Snippets Groups Projects
utils.py 4.82 KiB
Newer Older
from itertools import tee

import cv2
import numpy as np
import torch
from torch.distributions.uniform import Uniform

# Layout begin-token to end-token
# NOTE(review): every key and value in this literal is an empty string, so the
# duplicate keys collapse and the dict evaluates to {"": ""}. Presumably the
# original token characters were lost (encoding/extraction issue) - restore
# them from the upstream source before relying on this mapping.
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}

class MLflowNotInstalled(Exception):
    """
    Error signalling that MLflow logging was requested while the MLflow module
    is not installed.
    """


def randint(low, high):
    """
    Draw a single integer with torch.randint, i.e. from the half-open range
    [low, high).

    torch is used (as for the other random helpers of this module) to preserve
    randomness among dataloader workers.

    :param low: Lowest value that can be drawn (inclusive)
    :param high: Upper bound of the draw (exclusive)
    :return: The drawn value as a Python int
    """
    sample = torch.randint(low, high, size=(1,))
    return sample.item()


def rand():
    """
    Draw a single float with torch.rand.

    torch is used to preserve randomness among dataloader workers.

    :return: The drawn value as a Python float
    """
    sample = torch.rand(1)
    return sample.item()


def rand_uniform(low, high):
    """
    Draw a single float from a torch Uniform(low, high) distribution.

    torch is used to preserve randomness among dataloader workers.

    :param low: Lower bound of the distribution
    :param high: Upper bound of the distribution
    :return: The sampled value as a Python float
    """
    distribution = Uniform(low, high)
    return distribution.sample().item()


def pad_sequences_1D(data, padding_value):
    """
    Stack 1D sequences into a single 2D array, right-padding the shorter ones
    with padding_value so that every row has the length of the longest sequence.

    :param data: Sequence of 1D sequences
    :param padding_value: Value written after the end of each shorter sequence
    :return: Array with one row per input sequence
    """
    lengths = list(map(len, data))
    width = max(lengths)
    # Start from an int32 array of ones scaled by the padding value, then
    # overwrite the leading part of each row with the actual sequence.
    batch = np.ones((len(data), width), dtype=np.int32) * padding_value
    for row, (sequence, length) in enumerate(zip(data, lengths)):
        batch[row, :length] = sequence[:length]
    return batch


def pad_images(data, padding_value, padding_mode="br"):
    """
    Pad a list of images to a common size and stack them in a single array.

    The target size is the largest extent found along each of the first two
    axes (height, width) over all images.

    :param data: List of numpy arrays with three axes; the size of the last
        axis of the first image is used for the whole batch
    :param padding_value: Value filling the area not covered by an image
    :param padding_mode: Where the padding is added: "br" (bottom-right, the
        image stays in the top-left corner), "tl" (top-left, the image stays
        in the bottom-right corner) or "random" (the image is placed at a
        random offset)
    :return: Array of shape (len(data), max height, max width, channels)
    :raises NotImplementedError: If padding_mode is not one of the modes above
    """
    heights = [img.shape[0] for img in data]
    widths = [img.shape[1] for img in data]
    max_height = max(heights)
    max_width = max(widths)
    padded_data = (
        np.ones((len(data), max_height, max_width, data[0].shape[2])) * padding_value
    )
    for i, (img, height, width) in enumerate(zip(data, heights, widths)):
        if padding_mode == "br":
            padded_data[i, :height, :width, ...] = img
        elif padding_mode == "tl":
            padded_data[i, -height:, -width:, ...] = img
        elif padding_mode == "random":
            free_rows = max_height - height
            free_cols = max_width - width
            # randint excludes its upper bound: the +1 makes the largest valid
            # offset reachable (previously it could never be drawn). The guard
            # is kept so no random number is consumed when there is no room.
            top = randint(0, free_rows + 1) if free_rows >= 1 else 0
            left = randint(0, free_cols + 1) if free_cols >= 1 else 0
            padded_data[i, top : top + height, left : left + width, ...] = img
        else:
            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
    return padded_data


def pad_image(
    image,
    padding_value,
    new_height=None,
    new_width=None,
    pad_width=None,
    pad_height=None,
    padding_mode="br",
    return_position=False,
):
    """
    Pad a single image, either by a given amount or up to a target size.

    :param image: Numpy array of shape (height, width, channels)
    :param padding_value: Value filling the added area
    :param new_height: Target height; the image is padded up to it and left
        untouched along this axis if already at least that tall.
        Incompatible with pad_height
    :param new_width: Target width, same semantics as new_height.
        Incompatible with pad_width
    :param pad_width: Number of columns to add. Incompatible with new_width
    :param pad_height: Number of rows to add. Incompatible with new_height
    :param padding_mode: Where the padding is added: "br" (bottom-right, the
        image stays in the top-left corner), "tl" (top-left, the image stays
        in the bottom-right corner) or "random" (the image is placed at a
        random offset)
    :param return_position: If True, also return the bounds of the original
        image inside the output as [[top, bottom], [left, right]]
    :return: The input image itself when no padding is needed, otherwise a new
        array built from np.ones(...) * padding_value holding the image;
        an (output, position) pair when return_position is True
    :raises NotImplementedError: If both an amount and a target size are given
        for the same axis, or if padding is needed and padding_mode is not one
        of the modes above
    """
    if pad_width is not None and new_width is not None:
        raise NotImplementedError("pad_width and new_width are not compatible")
    if pad_height is not None and new_height is not None:
        raise NotImplementedError("pad_height and new_height are not compatible")

    h, w, c = image.shape
    # An explicit amount wins; otherwise derive it from the target size
    # (never negative), defaulting to no padding at all.
    if pad_width is None:
        pad_width = 0 if new_width is None else max(0, new_width - w)
    if pad_height is None:
        pad_height = 0 if new_height is None else max(0, new_height - h)

    if pad_width == 0 and pad_height == 0:
        # Nothing to add: hand back the input array itself.
        hi, wi = 0, 0
        output = image
    else:
        output = np.ones((h + pad_height, w + pad_width, c)) * padding_value
        if padding_mode == "br":
            hi, wi = 0, 0
        elif padding_mode == "tl":
            hi, wi = pad_height, pad_width
        elif padding_mode == "random":
            # randint excludes its upper bound: the +1 makes the largest valid
            # offset reachable (previously it could never be drawn). The guard
            # is kept so no random number is consumed when there is no room.
            hi = randint(0, pad_height + 1) if pad_height >= 1 else 0
            wi = randint(0, pad_width + 1) if pad_width >= 1 else 0
        else:
            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
        output[hi : hi + h, wi : wi + w, ...] = image

    if return_position:
        return output, [[hi, hi + h], [wi, wi + w]]
    return output


def read_image(filename, scale=1.0):
    """
    Load an image from disk as RGB and optionally rescale it
    :param filename: Image path
    :param scale: Scaling factor before prediction
    :return: The RGB image, resized with area interpolation when scale is not 1.0
    """
    # NOTE(review): cv2.imread presumably yields None for an unreadable path,
    # in which case the colour conversion below fails - confirm against callers.
    bgr = cv2.imread(str(filename))
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if scale == 1.0:
        return image
    src_height, src_width = image.shape[0], image.shape[1]
    # cv2.resize takes the target size as (width, height)
    target_size = (int(src_width * scale), int(src_height * scale))
    return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)


def round_floats(float_list, decimals=2):
    """
    Round every value of a list with np.around

    :param float_list: Values to round
    :param decimals: Number of decimals to keep
    :return: New list holding the rounded values, in the same order
    """
    rounded = []
    for value in float_list:
        rounded.append(np.around(value, decimals))
    return rounded


def pairwise(iterable):
    """
    Iterate over the successive overlapping pairs of an iterable:
    pairwise('ABCDEFG') --> AB BC CD DE EF FG

    Not necessary when using 3.10. See https://docs.python.org/3/library/itertools.html#itertools.pairwise.
    """
    current, following = tee(iterable, 2)
    # Shift the second copy one item ahead (no-op on an empty iterable) so
    # that zipping both copies pairs each item with its successor.
    next(following, None)
    return zip(current, following)