diff --git a/configs/eval.json b/configs/eval.json index ad3b43b5bc24162ed33eda2e95e0ba17c913a5bd..01072cef4b851fbf43449e2bf4c10ed1c6ad58d9 100644 --- a/configs/eval.json +++ b/configs/eval.json @@ -68,13 +68,17 @@ "train": [ "loss_ce", "cer", + "cer_no_token", "wer", - "wer_no_punct" + "wer_no_punct", + "wer_no_token" ], "eval": [ "cer", + "cer_no_token", "wer", - "wer_no_punct" + "wer_no_punct", + "wer_no_token" ] }, "validation": { diff --git a/configs/quickstart.json b/configs/quickstart.json index 6c8a00fa69e7ff07522a6aab51d4815e32fe0cb3..b9c12ea88ea69a46ae50cc320f2db4b2e2eeb260 100644 --- a/configs/quickstart.json +++ b/configs/quickstart.json @@ -77,13 +77,17 @@ "train": [ "loss_ce", "cer", + "cer_no_token", "wer", - "wer_no_punct" + "wer_no_punct", + "wer_no_token" ], "eval": [ "cer", + "cer_no_token", "wer", - "wer_no_punct" + "wer_no_punct", + "wer_no_token" ] }, "validation": { diff --git a/configs/tests.json b/configs/tests.json index 2267525da6db80cabd958005b29a445871fd9d14..f1c64402aeb257c6c02e9bdc25366aff16342bfe 100644 --- a/configs/tests.json +++ b/configs/tests.json @@ -68,13 +68,17 @@ "train": [ "loss_ce", "cer", + "cer_no_token", "wer", - "wer_no_punct" + "wer_no_punct", + "wer_no_token" ], "eval": [ "cer", + "cer_no_token", "wer", - "wer_no_punct" + "wer_no_punct", + "wer_no_token" ] }, "validation": { diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py index 6cc95a03bd7894a9fbb09a320c0f92e4156a48a1..d998ff1585c023de9540270871dc3ee64dcd1b9a 100644 --- a/dan/ocr/evaluate.py +++ b/dan/ocr/evaluate.py @@ -50,7 +50,7 @@ def eval(rank, config, mlflow_logging): model = Manager(config) model.load_model() - metrics = ["cer", "wer", "wer_no_punct", "time"] + metrics = ["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token", "time"] if config["dataset"]["tokens"] is not None: metrics.append("ner") diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py index f6b6ec74d50aca424411869a6fc2597aff4b6eb6..e9a3174e3c4147ed50fbb7ee4a1f277b4609b45a 100644 --- a/dan/ocr/manager/metrics.py +++ b/dan/ocr/manager/metrics.py @@ -19,6 +19,9 @@ REGEX_CONSECUTIVE_SPACES = re.compile(r" +") # Keep only one space character REGEX_ONLY_ONE_SPACE = re.compile(r"\s+") +# Mapping between computation tasks (CER, WER, NER) and their metric keyword +METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"} + class MetricManager: def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None): @@ -37,46 +40,46 @@ class MetricManager: self.metric_names: List[str] = metric_names self.epoch_metrics = defaultdict(list) - def edit_cer_from_string(self, gt: str, pred: str): + def format_string_for_cer(self, text: str, remove_token: bool = False): """ - Format and compute edit distance between two strings at character level + Format string for CER computation: remove layout tokens and extra spaces """ - gt = self.format_string_for_cer(gt) - pred = self.format_string_for_cer(pred) - return editdistance.eval(gt, pred) + if remove_token and self.remove_tokens is not None: + text = self.remove_tokens.sub("", text) - def nb_chars_cer_from_string(self, gt: str) -> int: - """ - Compute length after formatting of ground truth string - """ - return len(self.format_string_for_cer(gt)) + text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text) + return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip() - def format_string_for_wer(self, text: str, remove_punct: bool = False): + def format_string_for_wer( + self, text: str, remove_punct: bool = False, remove_token: bool = False + ): """ Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space """ if remove_punct: text = REGEX_PUNCTUATION.sub("", text) - if self.remove_tokens is not None: + if remove_token and self.remove_tokens is not None: text = self.remove_tokens.sub("", text) return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ") - def format_string_for_cer(self, text: str): - """ - Format string for CER computation: remove layout tokens and extra spaces - """ - if self.remove_tokens is not None: - text = self.remove_tokens.sub("", text) - - text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text) - return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip() - def format_string_for_ner(self, text: str): """ Format string for NER computation: only keep layout tokens """ return self.keep_tokens.sub("", text) + def _format_string(self, task: str, *args, **kwargs): + """ + Call the proper `format_string_for_*` method for the given task + """ + match task: + case "cer": + return self.format_string_for_cer(*args, **kwargs) + case "wer": + return self.format_string_for_wer(*args, **kwargs) + case "ner": + return self.format_string_for_ner(*args, **kwargs) + def update_metrics(self, batch_metrics): """ Add batch metrics to the metrics @@ -103,13 +106,13 @@ class MetricManager: display_values["sample_time"] = float(round(sample_time, 4)) display_values[metric_name] = value continue - case "cer": - num_name, denom_name = "edit_chars", "nb_chars" - case "wer" | "wer_no_punct": + case "cer" | "cer_no_token" | "wer" | "wer_no_punct" | "wer_no_token" | "ner": + keyword = METRICS_KEYWORD[metric_name[:3]] suffix = metric_name[3:] - num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix - case "ner": - num_name, denom_name = "edit_tokens", "nb_tokens" + num_name, denom_name = ( + "edit_" + keyword + suffix, + "nb_" + keyword + suffix, + ) case "loss" | "loss_ce": display_values[metric_name] = round( float( @@ -147,28 +150,30 @@ class MetricManager: gt, prediction = values["str_y"], values["str_x"] for metric_name in metric_names: match metric_name: - case "cer": - metrics["edit_chars"] = list( - map(self.edit_cer_from_string, gt, prediction) - ) - metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt)) - case "wer" | "wer_no_punct": + case "cer" | "cer_no_token" | "wer" | "wer_no_punct" | "wer_no_token" | "ner": + task = metric_name[:3] + keyword = METRICS_KEYWORD[task] suffix = metric_name[3:] - split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)])) + + # Add extra parameters for the format functions + extras = [] + if suffix == "_no_punct": + extras.append([{"remove_punct": True}]) + elif suffix == "_no_token": + extras.append([{"remove_token": True}]) + + # Run the format function for the desired computation (CER, WER or NER) + split_gt = list(map(self._format_string, [task], gt, *extras)) split_pred = list( - map(self.format_string_for_wer, prediction, [bool(suffix)]) + map(self._format_string, [task], prediction, *extras) ) - metrics["edit_words" + suffix] = list( - map(editdistance.eval, split_gt, split_pred) - ) - metrics["nb_words" + suffix] = list(map(len, split_gt)) - case "ner": - split_gt = list(map(self.format_string_for_ner, gt)) - split_pred = list(map(self.format_string_for_ner, prediction)) - metrics["edit_tokens"] = list( + + # Compute and store edit distance/length for the desired level + # (chars, words or tokens) as metrics + metrics["edit_" + keyword + suffix] = list( map(editdistance.eval, split_gt, split_pred) ) - metrics["nb_tokens"] = list(map(len, split_gt)) + metrics["nb_" + keyword + suffix] = list(map(len, split_gt)) case "loss" | "loss_ce": metrics[metric_name] = [ values[metric_name], diff --git a/docs/usage/train/config.md b/docs/usage/train/config.md index a435fdc6340f2e8a9ed3987021c0dfc00eb24724..ac7e1d60102bfcfc5bce97b6e24124590bdce402 100644 --- a/docs/usage/train/config.md +++ b/docs/usage/train/config.md @@ -37,34 +37,34 @@ To determine the value to use for `dataset.max_char_prediction`, you can use the ## Training parameters -| Name | Description | Type | Default | -| ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------- | ------------ | ----------------------------------------------------------------- | -| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` | -| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` | -| `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | -| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | -| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) | -| `training.output_folder` | Directory for checkpoint and results. | `str` | | -| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | -| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | -| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` | -| `training.device.ddp_port` | DDP port. | `int` | `20027` | -| `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` | -| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | | -| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | | -| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` | -| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` | -| `training.lr_schedulers` | Learning rate schedulers. | custom class | | -| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | -| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | -| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | -| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` | -| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` | -| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | -| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | -| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | -| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` | -| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` | +| Name | Description | Type | Default | +| ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------------------------------------- | +| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` | +| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` | +| `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | +| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | +| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) | +| `training.output_folder` | Directory for checkpoint and results. | `str` | | +| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | +| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | +| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` | +| `training.device.ddp_port` | DDP port. | `int` | `20027` | +| `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` | +| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | | +| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | | +| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` | +| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` | +| `training.lr_schedulers` | Learning rate schedulers. | custom class | | +| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | +| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | +| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | +| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | +| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | +| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | +| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | +| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | +| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` | +| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` | - To train on several GPUs, simply set the `training.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.nb_gpu` parameter. - During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations. diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index c0776167a003930ea1c1e6fc30b4250ea9f0075f..1f73b048b0b57134ec261f90537d4cb507331f5c 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -16,28 +16,40 @@ from tests import FIXTURES { "nb_chars": 43, "cer": 1.3023, + "nb_chars_no_token": 43, + "cer_no_token": 1.3023, "nb_words": 9, "wer": 1.0, "nb_words_no_punct": 9, "wer_no_punct": 1.0, + "nb_words_no_token": 9, + "wer_no_token": 1.0, "nb_samples": 2, }, { "nb_chars": 41, "cer": 1.2683, + "nb_chars_no_token": 41, + "cer_no_token": 1.2683, "nb_words": 9, "wer": 1.0, "nb_words_no_punct": 9, "wer_no_punct": 1.0, + "nb_words_no_token": 9, + "wer_no_token": 1.0, "nb_samples": 2, }, { "nb_chars": 49, "cer": 1.1224, + "nb_chars_no_token": 49, + "cer_no_token": 1.1224, "nb_words": 9, "wer": 1.0, "nb_words_no_punct": 9, "wer_no_punct": 1.0, + "nb_words_no_token": 9, + "wer_no_token": 1.0, "nb_samples": 2, }, ),