
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (2)
@@ -18,7 +18,6 @@ class DatasetManager:
     def __init__(self, params, device: str):
         self.params = params
         self.dataset_class = None
-        self.img_padding_value = params["config"]["padding_value"]
         self.my_collate_function = None
         # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
@@ -224,13 +223,6 @@ class GenericDataset(Dataset):
         if self.load_in_memory:
             self.apply_preprocessing(params["config"]["preprocessings"])
-        self.padding_value = params["config"]["padding_value"]
-        if self.padding_value == "mean":
-            if self.mean is None:
-                _, _ = self.compute_std_mean()
-            self.padding_value = self.mean
-            self.params["config"]["padding_value"] = self.padding_value
         self.curriculum_config = None

     def __len__(self):
......
@@ -25,14 +25,9 @@ class OCRDatasetManager(DatasetManager):
             params["charset"] if "charset" in params else self.get_merged_charsets()
         )
-        self.tokens = {
-            "pad": params["config"]["padding_token"],
-        }
+        self.tokens = {"pad": len(self.charset) + 2}
         self.tokens["end"] = len(self.charset)
         self.tokens["start"] = len(self.charset) + 1
-        self.tokens["pad"] = (
-            self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
-        )
         self.params["config"]["padding_token"] = self.tokens["pad"]

     def get_merged_charsets(self):
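The padding token is no longer read from `dataset_params.config.padding_token`: it is now always the index right after the end and start tokens, and is still written back to `params["config"]["padding_token"]` for the collate function. A minimal sketch of the resulting index layout, using a hypothetical three-character charset:

```python
charset = ["a", "b", "c"]  # hypothetical charset, character indices 0..2

tokens = {"pad": len(charset) + 2}   # pad   -> 5
tokens["end"] = len(charset)         # end   -> 3
tokens["start"] = len(charset) + 1   # start -> 4

# The prediction vocabulary therefore always has len(charset) + 3 entries:
# the characters first, then <end>, <start> and <pad>.
```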
@@ -77,9 +72,8 @@ class OCRDataset(GenericDataset):
         super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
         self.charset = None
         self.tokens = None
-        self.reduce_dims_factor = np.array(
-            [params["config"]["height_divisor"], params["config"]["width_divisor"], 1]
-        )
+        # Factor to reduce the height and width of the feature vector before feeding the decoder.
+        self.reduce_dims_factor = np.array([32, 8, 1])
         self.collate_function = OCRCollateFunction

     def __getitem__(self, idx):
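The height and width divisors are now fixed at 32 and 8 instead of being configurable. Assuming `reduce_dims_factor` is still used to derive the expected feature-map size from the image shape (as the new comment suggests), the effect can be sketched as follows; the image shape is made up for illustration:

```python
import numpy as np

img_shape = np.array([967, 640, 3])        # hypothetical (height, width, channels)
reduce_dims_factor = np.array([32, 8, 1])  # fixed divisors: height / 32, width / 8

# Approximate shape of the feature vector fed to the decoder.
reduced_shape = np.ceil(img_shape / reduce_dims_factor).astype(int)
print(reduced_shape)  # [31 80  3]
```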
@@ -143,7 +137,6 @@ class OCRDataset(GenericDataset):
             sample["img"], sample["img_position"] = pad_image(
                 sample["img"],
-                padding_value=self.padding_value,
                 new_width=self.params["config"]["padding"]["min_width"],
                 new_height=self.params["config"]["padding"]["min_height"],
                 pad_width=pad_width,
@@ -177,7 +170,6 @@ class OCRCollateFunction:
     """

     def __init__(self, config):
-        self.img_padding_value = float(config["padding_value"])
         self.label_padding_value = config["padding_token"]
         self.config = config
@@ -190,9 +182,7 @@ class OCRCollateFunction:
             self.config["padding_mode"] if "padding_mode" in self.config else "br"
         )
         imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
-        imgs = pad_images(
-            imgs, padding_value=self.img_padding_value, padding_mode=padding_mode
-        )
+        imgs = pad_images(imgs, padding_mode=padding_mode)
         imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
         formatted_batch_data = {
......
@@ -105,10 +105,6 @@ def get_config():
         "config": {
             "load_in_memory": True,  # Load all images in CPU memory
             "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading
-            "width_divisor": 8,  # Image width will be divided by 8
-            "height_divisor": 32,  # Image height will be divided by 32
-            "padding_value": 0,  # Image padding value
-            "padding_token": None,  # Label padding value
             "preprocessings": [
                 {
                     "type": "to_RGB",
......
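With these four keys removed, image padding is always zero and the padding token is always derived from the charset, so the dataset configuration shrinks accordingly. A sketch of a trimmed `config` block under those assumptions (values are illustrative):

```python
dataset_config = {
    "load_in_memory": True,  # Load all images in CPU memory
    "worker_per_gpu": 4,     # Num of parallel processes per gpu for data loading
    # width_divisor, height_divisor, padding_value and padding_token are gone:
    # images are zero-padded and the feature map is always reduced by (32, 8).
    "preprocessings": [
        {"type": "to_RGB"},
    ],
}
```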
@@ -25,7 +25,7 @@ def pad_sequences_1D(data, padding_value):
     return padded_data


-def pad_images(data, padding_value, padding_mode="br"):
+def pad_images(data, padding_mode="br"):
     """
     data: list of numpy array
     mode: "br"/"tl"/"random" (bottom-right, top-left, random)
@@ -34,9 +34,7 @@ def pad_images(data, padding_value, padding_mode="br"):
     y_lengths = [x.shape[1] for x in data]
     longest_x = max(x_lengths)
     longest_y = max(y_lengths)
-    padded_data = (
-        np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value
-    )
+    padded_data = np.zeros((len(data), longest_x, longest_y, data[0].shape[2]))
     for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
         x_len, y_len = xy_len
         if padding_mode == "br":
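Since `padding_value` is gone, batches are now always zero-padded up to the largest image in the batch. A simplified, self-contained sketch of the new behaviour (bottom-right mode only; the helper name and sample shapes are made up for illustration):

```python
import numpy as np

def pad_images_sketch(data, padding_mode="br"):
    """Zero-pad every image to the largest height/width in the batch."""
    longest_x = max(img.shape[0] for img in data)
    longest_y = max(img.shape[1] for img in data)
    padded = np.zeros((len(data), longest_x, longest_y, data[0].shape[2]))
    for i, img in enumerate(data):
        # "br": keep the image in the top-left corner, pad bottom and right.
        padded[i, : img.shape[0], : img.shape[1], :] = img
    return padded

batch = [np.ones((4, 6, 3)), np.ones((2, 3, 3))]  # two hypothetical RGB images
print(pad_images_sketch(batch).shape)  # (2, 4, 6, 3)
```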
@@ -56,7 +54,6 @@ def pad_images(data, padding_value, padding_mode="br"):
 def pad_image(
     image,
-    padding_value,
     new_height=None,
     new_width=None,
     pad_width=None,
@@ -90,7 +87,7 @@ def pad_image(
     )
     if not (pad_width == 0 and pad_height == 0):
-        padded_image = np.ones((h + pad_height, w + pad_width, c)) * padding_value
+        padded_image = np.zeros((h + pad_height, w + pad_width, c))
         if padding_mode == "br":
             hi, wi = 0, 0
         elif padding_mode == "tl":
......
@@ -14,10 +14,6 @@ All hyperparameters are specified and editable in the training scripts (meaning
 | `dataset_params.config.datasets` | Dataset dictionary with the dataset name as key and dataset path as value. | `dict` | |
 | `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `str` | `True` |
 | `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
-| `dataset_params.config.height_divisor` | Factor to reduce the width of the feature vector before feeding the decoder. | `int` | `8` |
-| `dataset_params.config.width_divisor` | Factor to reduce the height of the feature vector before feeding the decoder. | `int` | `32` |
-| `dataset_params.config.padding_value` | Image padding value. | `int` | `0` |
-| `dataset_params.config.padding_token` | Transcription padding value. | `int` | `None` |
 | `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
 | `dataset_params.config.augmentation` | Configuration for data augmentation. | `dict` | (see [dedicated section](#data-augmentation)) |
......
@@ -68,10 +68,6 @@ def training_config():
         },
         "config": {
             "load_in_memory": True,  # Load all images in CPU memory
-            "width_divisor": 8,  # Image width will be divided by 8
-            "height_divisor": 32,  # Image height will be divided by 32
-            "padding_value": 0,  # Image padding value
-            "padding_token": None,  # Label padding value
             "preprocessings": [
                 {
                     "type": "to_RGB",
......