From f57bc667bfaa3ccea27f561c6395c0e331fff18f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Fri, 16 Jun 2023 07:29:10 +0000
Subject: [PATCH] Use a single padding method

---
 dan/manager/ocr.py |  50 +---------------------
 dan/utils.py       | 102 ++++++++-------------------------------------
 2 files changed, 19 insertions(+), 133 deletions(-)

diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 8422f435..c60ce88c 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -6,10 +6,9 @@ import pickle
 import cv2
 import numpy as np
 import torch
-from torch import randint
 
 from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
-from dan.utils import pad_image, pad_images, pad_sequences_1D, token_to_ind
+from dan.utils import pad_images, pad_sequences_1D, token_to_ind
 
 
 class OCRDatasetManager(DatasetManager):
@@ -47,20 +46,6 @@ class OCRDatasetManager(DatasetManager):
         dataset.charset = self.charset
         dataset.tokens = self.tokens
         dataset.convert_labels()
-        if (
-            "padding" in dataset.params["config"]
-            and dataset.params["config"]["padding"]["min_height"] == "max"
-        ):
-            dataset.params["config"]["padding"]["min_height"] = max(
-                [s["img"].shape[0] for s in self.train_dataset.samples]
-            )
-        if (
-            "padding" in dataset.params["config"]
-            and dataset.params["config"]["padding"]["min_width"] == "max"
-        ):
-            dataset.params["config"]["padding"]["min_width"] = max(
-                [s["img"].shape[1] for s in self.train_dataset.samples]
-            )
 
 
 class OCRDataset(GenericDataset):
@@ -116,34 +101,6 @@ class OCRDataset(GenericDataset):
             [0, sample["img"].shape[0]],
             [0, sample["img"].shape[1]],
         ]
-        # Padding constraints to handle model needs
-        if "padding" in self.params["config"] and self.params["config"]["padding"]:
-            if (
-                self.set_name == "train"
-                or not self.params["config"]["padding"]["train_only"]
-            ):
-                min_pad = self.params["config"]["padding"]["min_pad"]
-                max_pad = self.params["config"]["padding"]["max_pad"]
-                pad_width = (
-                    randint(min_pad, max_pad, (1,))
-                    if min_pad is not None and max_pad is not None
-                    else None
-                )
-                pad_height = (
-                    randint(min_pad, max_pad, (1,))
-                    if min_pad is not None and max_pad is not None
-                    else None
-                )
-
-                sample["img"], sample["img_position"] = pad_image(
-                    sample["img"],
-                    new_width=self.params["config"]["padding"]["min_width"],
-                    new_height=self.params["config"]["padding"]["min_height"],
-                    pad_width=pad_width,
-                    pad_height=pad_height,
-                    padding_mode=self.params["config"]["padding"]["mode"],
-                    return_position=True,
-                )
         return sample
 
     def convert_labels(self):
@@ -178,11 +135,8 @@ class OCRCollateFunction:
         labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
         labels = torch.tensor(labels).long()
 
-        padding_mode = (
-            self.config["padding_mode"] if "padding_mode" in self.config else "br"
-        )
         imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
-        imgs = pad_images(imgs, padding_mode=padding_mode)
+        imgs = pad_images(imgs)
         imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
 
         formatted_batch_data = {
diff --git a/dan/utils.py b/dan/utils.py
index a10991ad..6ceede28 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 import cv2
 import numpy as np
-from torch import randint
 
 # Layout begin-token to end-token
 SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
@@ -25,89 +24,29 @@ def pad_sequences_1D(data, padding_value):
     return padded_data
 
 
-def pad_images(data, padding_mode="br"):
+def pad_images(data):
     """
-    data: list of numpy array
-    mode: "br"/"tl"/"random" (bottom-right, top-left, random)
+    Pad the images so that they are in the middle of the large padded image (tb-lr mode).
+    :param data: List of numpy arrays.
+    :return padded_data: A tensor containing all the padded images.
     """
-    x_lengths = [x.shape[0] for x in data]
-    y_lengths = [x.shape[1] for x in data]
-    longest_x = max(x_lengths)
-    longest_y = max(y_lengths)
+    longest_x = max([x.shape[0] for x in data])
+    longest_y = max([x.shape[1] for x in data])
     padded_data = np.zeros((len(data), longest_x, longest_y, data[0].shape[2]))
-    for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
-        x_len, y_len = xy_len
-        if padding_mode == "br":
-            padded_data[i, :x_len, :y_len, ...] = data[i]
-        elif padding_mode == "tl":
-            padded_data[i, -x_len:, -y_len:, ...] = data[i]
-        elif padding_mode == "random":
-            xmax = longest_x - x_len
-            ymax = longest_y - y_len
-            xi = randint(0, xmax, (1,)) if xmax >= 1 else 0
-            yi = randint(0, ymax, (1,)) if ymax >= 1 else 0
-            padded_data[i, xi : xi + x_len, yi : yi + y_len, ...] = data[i]
-        else:
-            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
+    for index, image in enumerate(data):
+        delta_x = longest_x - image.shape[0]
+        delta_y = longest_y - image.shape[1]
+        top, bottom = delta_x // 2, delta_x - (delta_x // 2)
+        left, right = delta_y // 2, delta_y - (delta_y // 2)
+        padded_data[
+            index,
+            top : padded_data.shape[1] - bottom,
+            left : padded_data.shape[2] - right,
+            :,
+        ] = image
     return padded_data
 
 
-def pad_image(
-    image,
-    new_height=None,
-    new_width=None,
-    pad_width=None,
-    pad_height=None,
-    padding_mode="br",
-    return_position=False,
-):
-    """
-    data: list of numpy array
-    mode: "br"/"tl"/"random" (bottom-right, top-left, random)
-    """
-    if pad_width is not None and new_width is not None:
-        raise NotImplementedError("pad_with and new_width are not compatible")
-    if pad_height is not None and new_height is not None:
-        raise NotImplementedError("pad_height and new_height are not compatible")
-
-    h, w, c = image.shape
-    pad_width = (
-        pad_width
-        if pad_width is not None
-        else max(0, new_width - w)
-        if new_width is not None
-        else 0
-    )
-    pad_height = (
-        pad_height
-        if pad_height is not None
-        else max(0, new_height - h)
-        if new_height is not None
-        else 0
-    )
-
-    if not (pad_width == 0 and pad_height == 0):
-        padded_image = np.zeros((h + pad_height, w + pad_width, c))
-        if padding_mode == "br":
-            hi, wi = 0, 0
-        elif padding_mode == "tl":
-            hi, wi = pad_height, pad_width
-        elif padding_mode == "random":
-            hi = randint(0, pad_height, (1,)) if pad_height >= 1 else 0
-            wi = randint(0, pad_width, (1,)) if pad_width >= 1 else 0
-        else:
-            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
-        padded_image[hi : hi + h, wi : wi + w, ...] = image
-        output = padded_image
-    else:
-        hi, wi = 0, 0
-        output = image
-
-    if return_position:
-        return output, [[hi, hi + h], [wi, wi + w]]
-    return output
-
-
 def read_image(filename, scale=1.0):
     """
     Read image and rescale it
@@ -122,13 +61,6 @@ def read_image(filename, scale=1.0):
     return image
 
 
-def round_floats(float_list, decimals=2):
-    """
-    Round list of floats with fixed decimals
-    """
-    return [np.around(num, decimals) for num in float_list]
-
-
 # Charset / labels conversion
 def token_to_ind(labels, str):
     return [labels.index(c) for c in str]
-- 
GitLab