Skip to content
Snippets Groups Projects
Commit 66fc27c0 authored by Mélodie's avatar Mélodie
Browse files

Apply f57bc667

parent 454b992e
No related branches found
No related tags found
No related merge requests found
......@@ -6,10 +6,9 @@ import pickle
import cv2
import numpy as np
import torch
from torch import randint
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.utils import pad_image, pad_images, pad_sequences_1D, token_to_ind
from dan.utils import pad_images, pad_sequences_1D, token_to_ind
class OCRDatasetManager(DatasetManager):
......@@ -47,20 +46,6 @@ class OCRDatasetManager(DatasetManager):
dataset.charset = self.charset
dataset.tokens = self.tokens
dataset.convert_labels()
if (
"padding" in dataset.params["config"]
and dataset.params["config"]["padding"]["min_height"] == "max"
):
dataset.params["config"]["padding"]["min_height"] = max(
[s["img"].shape[0] for s in self.train_dataset.samples]
)
if (
"padding" in dataset.params["config"]
and dataset.params["config"]["padding"]["min_width"] == "max"
):
dataset.params["config"]["padding"]["min_width"] = max(
[s["img"].shape[1] for s in self.train_dataset.samples]
)
class OCRDataset(GenericDataset):
......@@ -116,34 +101,6 @@ class OCRDataset(GenericDataset):
[0, sample["img"].shape[0]],
[0, sample["img"].shape[1]],
]
# Padding constraints to handle model needs
if "padding" in self.params["config"] and self.params["config"]["padding"]:
if (
self.set_name == "train"
or not self.params["config"]["padding"]["train_only"]
):
min_pad = self.params["config"]["padding"]["min_pad"]
max_pad = self.params["config"]["padding"]["max_pad"]
pad_width = (
randint(min_pad, max_pad, (1,))
if min_pad is not None and max_pad is not None
else None
)
pad_height = (
randint(min_pad, max_pad, (1,))
if min_pad is not None and max_pad is not None
else None
)
sample["img"], sample["img_position"] = pad_image(
sample["img"],
new_width=self.params["config"]["padding"]["min_width"],
new_height=self.params["config"]["padding"]["min_height"],
pad_width=pad_width,
pad_height=pad_height,
padding_mode=self.params["config"]["padding"]["mode"],
return_position=True,
)
return sample
def convert_labels(self):
......@@ -178,11 +135,8 @@ class OCRCollateFunction:
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
labels = torch.tensor(labels).long()
padding_mode = (
self.config["padding_mode"] if "padding_mode" in self.config else "br"
)
imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
imgs = pad_images(imgs, padding_mode=padding_mode)
imgs = pad_images(imgs)
imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
formatted_batch_data = {
......
# -*- coding: utf-8 -*-
import cv2
import numpy as np
from torch import randint
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
......@@ -25,89 +24,29 @@ def pad_sequences_1D(data, padding_value):
return padded_data
def pad_images(data):
    """
    Pad the images so that they are in the middle of the large padded image (tb-lr mode).

    Every image is copied into a zero-filled canvas whose height and width are
    the largest height and width found in ``data``. When the free space along
    an axis is odd, the extra pixel goes to the bottom / right side.

    :param data: List of numpy arrays indexed as (height, width, channels).
        The channel count of the first image is used for the whole batch.
    :return padded_data: A numpy array of shape
        (len(data), max_height, max_width, channels) containing all the padded images.
    """
    longest_x = max(image.shape[0] for image in data)
    longest_y = max(image.shape[1] for image in data)
    padded_data = np.zeros((len(data), longest_x, longest_y, data[0].shape[2]))

    for index, image in enumerate(data):
        height, width = image.shape[:2]
        # Floor division keeps the smaller half of the free space on the
        # top / left, so the image is centred (off by at most one pixel).
        top = (longest_x - height) // 2
        left = (longest_y - width) // 2
        padded_data[index, top : top + height, left : left + width, :] = image
    return padded_data
def pad_image(
    image,
    new_height=None,
    new_width=None,
    pad_width=None,
    pad_height=None,
    padding_mode="br",
    return_position=False,
):
    """
    Pad a single image with zeros.

    For each axis the amount of padding is given either directly
    (``pad_width`` / ``pad_height``) or as a minimum target size
    (``new_width`` / ``new_height``); both forms cannot be used for the same axis.

    :param image: Numpy array indexed as (height, width, channels).
    :param new_height: Minimum height of the output. Ignored when the image is already taller.
    :param new_width: Minimum width of the output. Ignored when the image is already wider.
    :param pad_width: Exact number of columns to add. A one-element tensor is accepted.
    :param pad_height: Exact number of rows to add. A one-element tensor is accepted.
    :param padding_mode: Where the padding goes: "br" (bottom-right), "tl" (top-left)
        or "random" (image placed at a random offset inside the padded canvas).
    :param return_position: When True, also return the location of the original image.
    :return: The padded image (the input object itself when no padding is needed).
        With ``return_position``, a tuple ``(output, [[top, bottom], [left, right]])``
        where the bounds are plain integers.
    :raises NotImplementedError: If both forms are given for one axis, or if
        padding is required and ``padding_mode`` is unknown.
    """
    if pad_width is not None and new_width is not None:
        raise NotImplementedError("pad_width and new_width are not compatible")
    if pad_height is not None and new_height is not None:
        raise NotImplementedError("pad_height and new_height are not compatible")

    h, w, c = image.shape

    # Normalise to plain ints: callers may pass one-element tensors from randint.
    if pad_width is not None:
        pad_width = int(pad_width)
    elif new_width is not None:
        pad_width = max(0, new_width - w)
    else:
        pad_width = 0

    if pad_height is not None:
        pad_height = int(pad_height)
    elif new_height is not None:
        pad_height = max(0, new_height - h)
    else:
        pad_height = 0

    hi, wi = 0, 0
    output = image
    if pad_width != 0 or pad_height != 0:
        if padding_mode == "br":
            hi, wi = 0, 0
        elif padding_mode == "tl":
            hi, wi = pad_height, pad_width
        elif padding_mode == "random":
            # randint's upper bound is exclusive, so each offset lies in
            # [0, pad - 1]; int() keeps tensors out of the returned position.
            hi = int(randint(0, pad_height, (1,))) if pad_height >= 1 else 0
            wi = int(randint(0, pad_width, (1,))) if pad_width >= 1 else 0
        else:
            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
        output = np.zeros((h + pad_height, w + pad_width, c))
        output[hi : hi + h, wi : wi + w, ...] = image

    if return_position:
        return output, [[hi, hi + h], [wi, wi + w]]
    return output
def read_image(filename, scale=1.0):
"""
Read image and rescale it
......@@ -122,13 +61,6 @@ def read_image(filename, scale=1.0):
return image
def round_floats(float_list, decimals=2):
    """
    Round every value of an iterable to a fixed number of decimals.

    :param float_list: Iterable of numbers to round.
    :param decimals: Number of decimal places passed to ``np.around``.
    :return: A new list holding the rounded values, in the input order.
    """
    rounded = []
    for value in float_list:
        rounded.append(np.around(value, decimals))
    return rounded
# Charset / labels conversion
def token_to_ind(labels, str):
    """
    Map every character of a string to its position in ``labels``.

    :param labels: Sequence providing ``index`` (e.g. the charset list).
    :param str: Text to encode, iterated character by character.
    :return: List of the index of each character in ``labels``.
        ``labels.index`` raises if a character is missing.
    """
    return list(map(labels.index, str))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment