# Skip to content
# Snippets Groups Projects
# utils.py 2.56 KiB
# -*- coding: utf-8 -*-
from itertools import islice

import torch
import torchvision.io as torchvision

# Layout begin-token to end-token
# Mapping from a layout opening token (key) to its matching closing token (value).
# NOTE(review): every key and value below reads as an empty string, so this
# literal collapses to {"": ""}. The token characters were presumably lost
# (e.g. special symbols stripped somewhere upstream) — verify against the
# reference copy of this file before relying on this mapping.
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}


class MLflowNotInstalled(Exception):
    """
    Error signalling that MLflow logging was asked for while the MLflow
    module is not installed.
    """


def pad_sequences_1D(data, padding_value):
    """
    Pad 1D sequences with padding_value so that they all share the same length.

    :param data: List of 1D sequences. Each item must support ``len`` and be
        convertible by ``torch.as_tensor`` (e.g. a list of ints or a 1D tensor).
    :param padding_value: Value written in the positions located after the end
        of each sequence.
    :return padded_data: A 2D tensor with one row per sequence and as many
        columns as the longest sequence. Sequences are left-aligned.
        An empty ``data`` yields a tensor of shape (0, 0).
    """
    x_lengths = [len(x) for x in data]
    # default=0 keeps an empty batch from failing with
    # "max() arg is an empty sequence".
    longest_x = max(x_lengths, default=0)
    # The buffer is created as int32 and then scaled by the padding value.
    padded_data = torch.ones((len(data), longest_x), dtype=torch.int32) * padding_value
    for i, (x, x_len) in enumerate(zip(data, x_lengths)):
        # as_tensor does not re-wrap an existing tensor, which avoids the
        # copy-construct warning torch.tensor emits in that case; the
        # assignment below copies the values into the buffer anyway.
        padded_data[i, :x_len] = torch.as_tensor(x)
    return padded_data


def pad_images(data):
    """
    Center every image inside a zero-filled canvas large enough for all of them
    (padding is split between top/bottom and left/right).

    When the padding along an axis is odd, the extra pixel goes to the
    bottom / right side.

    :param data: List of 3-dimensional images. Axis 1 is padded top/bottom and
        axis 2 left/right; the size of axis 0 is taken from the first image.
    :return padded_data: A tensor containing all the padded images, one per
        entry of ``data``.
    """
    max_height = max(img.shape[1] for img in data)
    max_width = max(img.shape[2] for img in data)
    canvas = torch.zeros((len(data), data[0].shape[0], max_height, max_width))

    for idx, img in enumerate(data):
        height, width = img.shape[1], img.shape[2]
        # Floor division puts the smaller half of the padding before the image.
        row_start = (max_height - height) // 2
        col_start = (max_width - width) // 2
        row_end = row_start + height
        col_end = col_start + width
        canvas[idx, :, row_start:row_end, col_start:col_end] = img

    return canvas


def read_image(path):
    """
    Load an image from disk as an RGB tensor scaled by 1/255.

    :param path: Path of the image to load.
    :return: The decoded image converted to torch's default dtype and divided by 255.
    """
    raw = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
    target_dtype = torch.get_default_dtype()
    converted = raw.to(dtype=target_dtype)
    return converted.div(255)


# Charset / labels conversion
def token_to_ind(labels, str):
    """
    Convert each element of ``str`` into its position in ``labels``.

    :param labels: Sequence of known tokens (must provide ``index``).
    :param str: Iterable of tokens to convert.
    :return: List with the first index of every token in ``labels``.
    :raises ValueError: If a token is not present in ``labels``.
    """
    lookup = labels.index
    return list(map(lookup, str))


def ind_to_token(labels, ind, oov_symbol=None):
    """
    Build a string by looking up every index of ``ind`` in ``labels``.

    :param labels: Sequence of tokens, addressed by index.
    :param ind: Iterable of indices to convert.
    :param oov_symbol: Optional replacement used for any index that is not
        smaller than ``len(labels)``. When None, such an index makes the
        lookup fail instead.
    :return: The concatenation of the looked-up tokens.
    """
    if oov_symbol is None:
        return "".join([labels[i] for i in ind])

    tokens = [labels[i] if i < len(labels) else oov_symbol for i in ind]
    return "".join(tokens)


def list_to_batches(iterable, n):
    """
    Yield successive tuples of at most ``n`` items taken from ``iterable``.

    Example: list_to_batches('ABCDEFG', 3) --> ABC DEF G

    :param iterable: Any iterable to split into batches.
    :param n: Maximum number of items per batch.
    :raises ValueError: If ``n`` is lower than one (raised when iteration starts).
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        # An empty tuple means the source is exhausted.
        if not chunk:
            return
        yield chunk