Verified Commit 35d3b1fc authored by Mélodie Boillet

Apply d2760687

parent 9d99164b
@@ -46,8 +46,6 @@ class OCRDatasetManager:
         self.generator = torch.Generator()
         self.generator.manual_seed(0)
 
-        self.batch_size = self.get_batch_size_by_set()
-
         self.load_in_memory = (
             self.params["config"]["load_in_memory"]
             if "load_in_memory" in self.params["config"]
@@ -126,7 +124,7 @@ class OCRDatasetManager:
         """
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size["train"],
+            batch_size=self.params["batch_size"],
             shuffle=True if self.train_sampler is None else False,
             drop_last=False,
             batch_sampler=self.train_sampler,
@@ -141,7 +139,7 @@ class OCRDatasetManager:
         for key in self.valid_datasets.keys():
             self.valid_loaders[key] = DataLoader(
                 self.valid_datasets[key],
-                batch_size=self.batch_size["val"],
+                batch_size=1,
                 sampler=self.valid_samplers[key],
                 batch_sampler=self.valid_samplers[key],
                 shuffle=False,
@@ -197,7 +195,7 @@ class OCRDatasetManager:
             self.test_loaders[custom_name] = DataLoader(
                 self.test_datasets[custom_name],
-                batch_size=self.batch_size["test"],
+                batch_size=1,
                 sampler=self.test_samplers[custom_name],
                 shuffle=False,
                 num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
@@ -243,20 +241,6 @@ class OCRDatasetManager:
             "pad": len(self.charset) + 2,
         }
-
-    def get_batch_size_by_set(self):
-        """
-        Return batch size for each set
-        """
-        return {
-            "train": self.params["batch_size"],
-            "val": self.params["valid_batch_size"]
-            if "valid_batch_size" in self.params
-            else self.params["batch_size"],
-            "test": self.params["test_batch_size"]
-            if "test_batch_size" in self.params
-            else 1,
-        }
 
 
 class OCRCollateFunction:
     """
...
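Taken together, the changes in this first file amount to the loader behavior sketched below. This is a minimal, hypothetical summary — the function name `build_loaders` and the flat `params` dict are assumptions for illustration, not the repository's API: training keeps the configurable batch size, while evaluation loaders are pinned to 1.

```python
from torch.utils.data import DataLoader

def build_loaders(train_dataset, valid_dataset, params):
    """Hypothetical sketch of the post-commit loader setup."""
    train_loader = DataLoader(
        train_dataset,
        batch_size=params["batch_size"],  # still user-configurable for training
        shuffle=True,
        drop_last=False,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=1,  # hard-coded: one sample per batch during evaluation
        shuffle=False,
    )
    return train_loader, valid_loader
```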
@@ -84,14 +84,6 @@ class GenericTrainingManager:
         self.params["dataset_params"]["batch_size"] = self.params["training_params"][
             "batch_size"
         ]
-        if "valid_batch_size" in self.params["training_params"]:
-            self.params["dataset_params"]["valid_batch_size"] = self.params[
-                "training_params"
-            ]["valid_batch_size"]
-        if "test_batch_size" in self.params["training_params"]:
-            self.params["dataset_params"]["test_batch_size"] = self.params[
-                "training_params"
-            ]["test_batch_size"]
         self.params["dataset_params"]["num_gpu"] = self.params["training_params"][
             "nb_gpu"
         ]
...
@@ -170,7 +170,6 @@ def get_config():
         "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
         "interval_save_weights": None,  # None: keep best and last only
         "batch_size": 2,  # mini-batch size for training
-        "valid_batch_size": 4,  # mini-batch size for validation
         "use_ddp": False,  # Use DistributedDataParallel
         "ddp_port": "20027",
         "use_amp": True,  # Enable automatic mixed precision
...
@@ -162,7 +162,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` |
 | `training_params.interval_save_weights` | Step to save weights. Set to `None` to keep only best and last epochs. | `int` | `None` |
 | `training_params.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
-| `training_params.valid_batch_size` | Mini-batch size for the validation loop. | `int` | `4` |
 | `training_params.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
 | `training_params.ddp_port` | DDP port. | `int` | `20027` |
 | `training_params.use_amp` | Whether to enable automatic mixed precision. | `bool` | `True` |
@@ -184,6 +183,7 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
 | `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
 
+During the validation stage, the batch size is set to 1. This avoids problems caused by image sizes that can differ widely within a batch, which would lead to significant padding and degrade performance.
 
 ## MLFlow logging
...
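To make the padding argument in that new documentation line concrete, here is a small self-contained sketch — `pad_batch` and the image sizes are hypothetical, not code from this repository — showing that with `batch_size > 1`, every image must be padded to the largest height and width in its batch, whereas a batch size of 1 introduces no padding at all.

```python
import torch

def pad_batch(images):
    # Pad a list of C x H x W tensors to the largest H and W in the batch
    # (hypothetical helper, shown only to illustrate the padding overhead).
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    batch = torch.zeros(len(images), images[0].shape[0], max_h, max_w)
    for i, img in enumerate(images):
        batch[i, :, : img.shape[1], : img.shape[2]] = img
    return batch

# Two text-line images with the same height but very different widths.
small = torch.rand(3, 32, 128)
large = torch.rand(3, 32, 2048)

print(pad_batch([small, large]).shape)  # torch.Size([2, 3, 32, 2048])
# The small image's row is now ~94% padding; with batch_size=1 it would
# keep its native 32 x 128 size and require no padding.
```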
@@ -116,7 +116,6 @@ def training_config():
         "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
         "interval_save_weights": None,  # None: keep best and last only
         "batch_size": 2,  # mini-batch size for training
-        "valid_batch_size": 2,  # mini-batch size for validation
         "use_ddp": False,  # Use DistributedDataParallel
         "nb_gpu": 0,
         "optimizers": {
...