Skip to content
Snippets Groups Projects
Commit f57bc667 authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Yoann Schneider
Browse files

Use a single padding method

parent 86163b14
No related branches found
No related tags found
1 merge request!167Use a single padding method
......@@ -6,10 +6,9 @@ import pickle
import cv2
import numpy as np
import torch
from torch import randint
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.utils import pad_image, pad_images, pad_sequences_1D, token_to_ind
from dan.utils import pad_images, pad_sequences_1D, token_to_ind
class OCRDatasetManager(DatasetManager):
......@@ -47,20 +46,6 @@ class OCRDatasetManager(DatasetManager):
dataset.charset = self.charset
dataset.tokens = self.tokens
dataset.convert_labels()
if (
"padding" in dataset.params["config"]
and dataset.params["config"]["padding"]["min_height"] == "max"
):
dataset.params["config"]["padding"]["min_height"] = max(
[s["img"].shape[0] for s in self.train_dataset.samples]
)
if (
"padding" in dataset.params["config"]
and dataset.params["config"]["padding"]["min_width"] == "max"
):
dataset.params["config"]["padding"]["min_width"] = max(
[s["img"].shape[1] for s in self.train_dataset.samples]
)
class OCRDataset(GenericDataset):
......@@ -116,34 +101,6 @@ class OCRDataset(GenericDataset):
[0, sample["img"].shape[0]],
[0, sample["img"].shape[1]],
]
# Padding constraints to handle model needs
if "padding" in self.params["config"] and self.params["config"]["padding"]:
if (
self.set_name == "train"
or not self.params["config"]["padding"]["train_only"]
):
min_pad = self.params["config"]["padding"]["min_pad"]
max_pad = self.params["config"]["padding"]["max_pad"]
pad_width = (
randint(min_pad, max_pad, (1,))
if min_pad is not None and max_pad is not None
else None
)
pad_height = (
randint(min_pad, max_pad, (1,))
if min_pad is not None and max_pad is not None
else None
)
sample["img"], sample["img_position"] = pad_image(
sample["img"],
new_width=self.params["config"]["padding"]["min_width"],
new_height=self.params["config"]["padding"]["min_height"],
pad_width=pad_width,
pad_height=pad_height,
padding_mode=self.params["config"]["padding"]["mode"],
return_position=True,
)
return sample
def convert_labels(self):
......@@ -178,11 +135,8 @@ class OCRCollateFunction:
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
labels = torch.tensor(labels).long()
padding_mode = (
self.config["padding_mode"] if "padding_mode" in self.config else "br"
)
imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
imgs = pad_images(imgs, padding_mode=padding_mode)
imgs = pad_images(imgs)
imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
formatted_batch_data = {
......
# -*- coding: utf-8 -*-
import cv2
import numpy as np
from torch import randint
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
......@@ -25,89 +24,29 @@ def pad_sequences_1D(data, padding_value):
return padded_data
def pad_images(data, padding_mode="br"):
def pad_images(data):
"""
data: list of numpy array
mode: "br"/"tl"/"random" (bottom-right, top-left, random)
Pad the images so that they are in the middle of the large padded image (tb-lr mode).
:param data: List of numpy arrays.
:return padded_data: A tensor containing all the padded images.
"""
x_lengths = [x.shape[0] for x in data]
y_lengths = [x.shape[1] for x in data]
longest_x = max(x_lengths)
longest_y = max(y_lengths)
longest_x = max([x.shape[0] for x in data])
longest_y = max([x.shape[1] for x in data])
padded_data = np.zeros((len(data), longest_x, longest_y, data[0].shape[2]))
for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
x_len, y_len = xy_len
if padding_mode == "br":
padded_data[i, :x_len, :y_len, ...] = data[i]
elif padding_mode == "tl":
padded_data[i, -x_len:, -y_len:, ...] = data[i]
elif padding_mode == "random":
xmax = longest_x - x_len
ymax = longest_y - y_len
xi = randint(0, xmax, (1,)) if xmax >= 1 else 0
yi = randint(0, ymax, (1,)) if ymax >= 1 else 0
padded_data[i, xi : xi + x_len, yi : yi + y_len, ...] = data[i]
else:
raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
for index, image in enumerate(data):
delta_x = longest_x - image.shape[0]
delta_y = longest_y - image.shape[1]
top, bottom = delta_x // 2, delta_x - (delta_x // 2)
left, right = delta_y // 2, delta_y - (delta_y // 2)
padded_data[
index,
top : padded_data.shape[1] - bottom,
left : padded_data.shape[2] - right,
:,
] = image
return padded_data
def pad_image(
image,
new_height=None,
new_width=None,
pad_width=None,
pad_height=None,
padding_mode="br",
return_position=False,
):
"""
data: list of numpy array
mode: "br"/"tl"/"random" (bottom-right, top-left, random)
"""
if pad_width is not None and new_width is not None:
raise NotImplementedError("pad_with and new_width are not compatible")
if pad_height is not None and new_height is not None:
raise NotImplementedError("pad_height and new_height are not compatible")
h, w, c = image.shape
pad_width = (
pad_width
if pad_width is not None
else max(0, new_width - w)
if new_width is not None
else 0
)
pad_height = (
pad_height
if pad_height is not None
else max(0, new_height - h)
if new_height is not None
else 0
)
if not (pad_width == 0 and pad_height == 0):
padded_image = np.zeros((h + pad_height, w + pad_width, c))
if padding_mode == "br":
hi, wi = 0, 0
elif padding_mode == "tl":
hi, wi = pad_height, pad_width
elif padding_mode == "random":
hi = randint(0, pad_height, (1,)) if pad_height >= 1 else 0
wi = randint(0, pad_width, (1,)) if pad_width >= 1 else 0
else:
raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
padded_image[hi : hi + h, wi : wi + w, ...] = image
output = padded_image
else:
hi, wi = 0, 0
output = image
if return_position:
return output, [[hi, hi + h], [wi, wi + w]]
return output
def read_image(filename, scale=1.0):
"""
Read image and rescale it
......@@ -122,13 +61,6 @@ def read_image(filename, scale=1.0):
return image
def round_floats(float_list, decimals=2):
"""
Round list of floats with fixed decimals
"""
return [np.around(num, decimals) for num in float_list]
# Charset / labels conversion
def token_to_ind(labels, str):
return [labels.index(c) for c in str]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment