Skip to content
Snippets Groups Projects
Commit d82ed3d6 authored by Manon Blanco's avatar Manon Blanco Committed by Mélodie Boillet
Browse files

Remove padding value and padding token parameters from training configuration

parent 435fbdd5
No related branches found
No related tags found
1 merge request!164Remove padding value and padding token parameters from training configuration
...@@ -18,7 +18,6 @@ class DatasetManager: ...@@ -18,7 +18,6 @@ class DatasetManager:
def __init__(self, params, device: str): def __init__(self, params, device: str):
self.params = params self.params = params
self.dataset_class = None self.dataset_class = None
self.img_padding_value = params["config"]["padding_value"]
self.my_collate_function = None self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
...@@ -224,13 +223,6 @@ class GenericDataset(Dataset): ...@@ -224,13 +223,6 @@ class GenericDataset(Dataset):
if self.load_in_memory: if self.load_in_memory:
self.apply_preprocessing(params["config"]["preprocessings"]) self.apply_preprocessing(params["config"]["preprocessings"])
self.padding_value = params["config"]["padding_value"]
if self.padding_value == "mean":
if self.mean is None:
_, _ = self.compute_std_mean()
self.padding_value = self.mean
self.params["config"]["padding_value"] = self.padding_value
self.curriculum_config = None self.curriculum_config = None
def __len__(self): def __len__(self):
......
...@@ -25,14 +25,9 @@ class OCRDatasetManager(DatasetManager): ...@@ -25,14 +25,9 @@ class OCRDatasetManager(DatasetManager):
params["charset"] if "charset" in params else self.get_merged_charsets() params["charset"] if "charset" in params else self.get_merged_charsets()
) )
self.tokens = { self.tokens = {"pad": len(self.charset) + 2}
"pad": params["config"]["padding_token"],
}
self.tokens["end"] = len(self.charset) self.tokens["end"] = len(self.charset)
self.tokens["start"] = len(self.charset) + 1 self.tokens["start"] = len(self.charset) + 1
self.tokens["pad"] = (
self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
)
self.params["config"]["padding_token"] = self.tokens["pad"] self.params["config"]["padding_token"] = self.tokens["pad"]
def get_merged_charsets(self): def get_merged_charsets(self):
...@@ -142,7 +137,6 @@ class OCRDataset(GenericDataset): ...@@ -142,7 +137,6 @@ class OCRDataset(GenericDataset):
sample["img"], sample["img_position"] = pad_image( sample["img"], sample["img_position"] = pad_image(
sample["img"], sample["img"],
padding_value=self.padding_value,
new_width=self.params["config"]["padding"]["min_width"], new_width=self.params["config"]["padding"]["min_width"],
new_height=self.params["config"]["padding"]["min_height"], new_height=self.params["config"]["padding"]["min_height"],
pad_width=pad_width, pad_width=pad_width,
...@@ -176,7 +170,6 @@ class OCRCollateFunction: ...@@ -176,7 +170,6 @@ class OCRCollateFunction:
""" """
def __init__(self, config): def __init__(self, config):
self.img_padding_value = float(config["padding_value"])
self.label_padding_value = config["padding_token"] self.label_padding_value = config["padding_token"]
self.config = config self.config = config
...@@ -189,9 +182,7 @@ class OCRCollateFunction: ...@@ -189,9 +182,7 @@ class OCRCollateFunction:
self.config["padding_mode"] if "padding_mode" in self.config else "br" self.config["padding_mode"] if "padding_mode" in self.config else "br"
) )
imgs = [batch_data[i]["img"] for i in range(len(batch_data))] imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
imgs = pad_images( imgs = pad_images(imgs, padding_mode=padding_mode)
imgs, padding_value=self.img_padding_value, padding_mode=padding_mode
)
imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2) imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
formatted_batch_data = { formatted_batch_data = {
......
...@@ -105,8 +105,6 @@ def get_config(): ...@@ -105,8 +105,6 @@ def get_config():
"config": { "config": {
"load_in_memory": True, # Load all images in CPU memory "load_in_memory": True, # Load all images in CPU memory
"worker_per_gpu": 4, # Num of parallel processes per gpu for data loading "worker_per_gpu": 4, # Num of parallel processes per gpu for data loading
"padding_value": 0, # Image padding value
"padding_token": None, # Label padding value
"preprocessings": [ "preprocessings": [
{ {
"type": "to_RGB", "type": "to_RGB",
......
...@@ -25,7 +25,7 @@ def pad_sequences_1D(data, padding_value): ...@@ -25,7 +25,7 @@ def pad_sequences_1D(data, padding_value):
return padded_data return padded_data
def pad_images(data, padding_value, padding_mode="br"): def pad_images(data, padding_mode="br"):
""" """
data: list of numpy array data: list of numpy array
mode: "br"/"tl"/"random" (bottom-right, top-left, random) mode: "br"/"tl"/"random" (bottom-right, top-left, random)
...@@ -34,9 +34,7 @@ def pad_images(data, padding_value, padding_mode="br"): ...@@ -34,9 +34,7 @@ def pad_images(data, padding_value, padding_mode="br"):
y_lengths = [x.shape[1] for x in data] y_lengths = [x.shape[1] for x in data]
longest_x = max(x_lengths) longest_x = max(x_lengths)
longest_y = max(y_lengths) longest_y = max(y_lengths)
padded_data = ( padded_data = np.zeros((len(data), longest_x, longest_y, data[0].shape[2]))
np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value
)
for i, xy_len in enumerate(zip(x_lengths, y_lengths)): for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
x_len, y_len = xy_len x_len, y_len = xy_len
if padding_mode == "br": if padding_mode == "br":
...@@ -56,7 +54,6 @@ def pad_images(data, padding_value, padding_mode="br"): ...@@ -56,7 +54,6 @@ def pad_images(data, padding_value, padding_mode="br"):
def pad_image( def pad_image(
image, image,
padding_value,
new_height=None, new_height=None,
new_width=None, new_width=None,
pad_width=None, pad_width=None,
...@@ -90,7 +87,7 @@ def pad_image( ...@@ -90,7 +87,7 @@ def pad_image(
) )
if not (pad_width == 0 and pad_height == 0): if not (pad_width == 0 and pad_height == 0):
padded_image = np.ones((h + pad_height, w + pad_width, c)) * padding_value padded_image = np.zeros((h + pad_height, w + pad_width, c))
if padding_mode == "br": if padding_mode == "br":
hi, wi = 0, 0 hi, wi = 0, 0
elif padding_mode == "tl": elif padding_mode == "tl":
......
...@@ -14,8 +14,6 @@ All hyperparameters are specified and editable in the training scripts (meaning ...@@ -14,8 +14,6 @@ All hyperparameters are specified and editable in the training scripts (meaning
| `dataset_params.config.datasets` | Dataset dictionary with the dataset name as key and dataset path as value. | `dict` | | | `dataset_params.config.datasets` | Dataset dictionary with the dataset name as key and dataset path as value. | `dict` | |
| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `str` | `True` | | `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `str` | `True` |
| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | | `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
| `dataset_params.config.padding_value` | Image padding value. | `int` | `0` |
| `dataset_params.config.padding_token` | Transcription padding value. | `int` | `None` |
| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | | `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
| `dataset_params.config.augmentation` | Configuration for data augmentation. | `dict` | (see [dedicated section](#data-augmentation)) | | `dataset_params.config.augmentation` | Configuration for data augmentation. | `dict` | (see [dedicated section](#data-augmentation)) |
......
...@@ -68,8 +68,6 @@ def training_config(): ...@@ -68,8 +68,6 @@ def training_config():
}, },
"config": { "config": {
"load_in_memory": True, # Load all images in CPU memory "load_in_memory": True, # Load all images in CPU memory
"padding_value": 0, # Image padding value
"padding_token": None, # Label padding value
"preprocessings": [ "preprocessings": [
{ {
"type": "to_RGB", "type": "to_RGB",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment