Something went wrong on our end
utils.py 2.56 KiB
# -*- coding: utf-8 -*-
from itertools import islice
import torch
import torchvision.io as torchvision
# Maps each layout begin-token (a circled small letter) to the end-token that
# closes it (the same letter, circled capital).
SEM_MATCHING_TOKENS = {
    "ⓘ": "Ⓘ",
    "ⓓ": "Ⓓ",
    "ⓢ": "Ⓢ",
    "ⓒ": "Ⓒ",
    "ⓟ": "Ⓟ",
    "ⓐ": "Ⓐ",
}
class MLflowNotInstalled(Exception):
    """Signal that MLflow logging was asked for but the MLflow module is not installed."""
def pad_sequences_1D(data, padding_value):
    """
    Pad data with padding_value to get same length.

    Every sequence is right-padded with ``padding_value`` up to the length of
    the longest sequence, and the results are stacked into one tensor.

    :param data: Sequence of 1D integer sequences (lists, arrays or tensors).
    :param padding_value: Value written after the end of each shorter sequence.
    :return padded_data: ``torch.int32`` tensor of shape
        ``(len(data), longest_length)``. An empty ``data`` gives a ``(0, 0)``
        tensor instead of raising.
    """
    lengths = [len(seq) for seq in data]
    longest = max(lengths, default=0)
    # torch.full with an explicit dtype keeps the result int32 even for a float
    # padding_value; multiplying an int32 tensor by a Python float would
    # promote the whole tensor to floating point.
    padded_data = torch.full((len(data), longest), padding_value, dtype=torch.int32)
    for i, (seq, length) in enumerate(zip(data, lengths)):
        # as_tensor avoids a copy (and torch's warning) when seq is already a tensor;
        # values are cast to int32 on assignment.
        padded_data[i, :length] = torch.as_tensor(seq)
    return padded_data
def pad_images(data):
    """
    Center every image inside a zero-filled canvas as large as the biggest one (tb-lr mode).

    Dimension 1 of each image is padded top/bottom and dimension 2 left/right.
    When the slack is odd, the extra row/column goes to the bottom/right side.
    The leading dimension of the canvas (presumably channels) is taken from the
    first image, and the canvas uses torch's default floating dtype.

    NOTE(review): the docstring historically says numpy arrays, but read_image
    in this module returns tensors, so callers may pass tensors as well — confirm.

    :param data: List of numpy arrays.
    :return padded_data: A tensor containing all the padded images.
    """
    max_height = max(image.shape[1] for image in data)
    max_width = max(image.shape[2] for image in data)
    canvas = torch.zeros((len(data), data[0].shape[0], max_height, max_width))

    for index, image in enumerate(data):
        height, width = image.shape[1], image.shape[2]
        # Floor half of the slack goes before the image, the rest after it.
        top = (max_height - height) // 2
        left = (max_width - width) // 2
        canvas[index, :, top : top + height, left : left + width] = image

    return canvas
def read_image(path):
    """
    Load an image file as an RGB tensor in the default floating dtype.

    The file is decoded by ``torchvision.io.read_image`` in RGB mode, converted
    to ``torch.get_default_dtype()`` and divided by 255 (so 8-bit pixel values
    map into [0, 1]).

    :param path: Path of the image to load.
    :return: The decoded, rescaled image tensor.
    """
    rgb = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
    as_float = rgb.to(dtype=torch.get_default_dtype())
    return as_float / 255
# Charset / labels conversion
def token_to_ind(labels, str):
    """
    Encode text as the positions of its characters in ``labels``.

    :param labels: Charset supporting ``.index`` (e.g. a list of tokens or a string).
    :param str: Text to encode. The name shadows the builtin but is kept so
        keyword callers keep working.
    :return: List with, for each character, the index of its first occurrence
        in ``labels``.
    :raises ValueError: If a character does not appear in ``labels``.
    """
    indices = []
    for char in str:
        indices.append(labels.index(char))
    return indices
def ind_to_token(labels, ind, oov_symbol=None):
    """
    Decode a sequence of label indices back into a string.

    :param labels: Indexable charset; ``labels[i]`` is the token for index ``i``.
    :param ind: Iterable of integer indices.
    :param oov_symbol: If given, it replaces every index that is not a valid
        position in ``labels`` (negative, or ``>= len(labels)``). If ``None``,
        indices are used as-is, with normal Python indexing rules.
    :return: Concatenation of the decoded tokens.
    """
    if oov_symbol is None:
        return "".join([labels[i] for i in ind])
    size = len(labels)
    # Negative indices are out-of-vocabulary too: plain indexing would silently
    # wrap them around to the end of ``labels``.
    return "".join([labels[i] if 0 <= i < size else oov_symbol for i in ind])
def list_to_batches(iterable, n):
    """
    Lazily yield tuples of ``n`` consecutive items from ``iterable``.

    Every batch holds exactly ``n`` items except possibly the last one, e.g.
    ``list_to_batches("ABCDEFG", 3)`` yields ``ABC``, ``DEF`` then ``G``.

    :param iterable: Any iterable; it is consumed incrementally.
    :param n: Batch size, must be at least 1.
    :raises ValueError: If ``n`` is smaller than 1. Because this is a generator,
        the error surfaces on the first iteration, not at call time.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        yield chunk