Commit 543d1e23 authored by Yoann Schneider

Merge branch 'evaluate-pure-htr' into 'main'

Evaluate pure HTR metric

Closes #236

See merge request !326
parents 2afec4ec 6cf7df04
@@ -68,13 +68,17 @@
         "train": [
             "loss_ce",
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ],
         "eval": [
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ]
     },
     "validation": {

@@ -77,13 +77,17 @@
         "train": [
             "loss_ce",
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ],
         "eval": [
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ]
     },
     "validation": {

@@ -68,13 +68,17 @@
         "train": [
             "loss_ce",
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ],
         "eval": [
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ]
     },
     "validation": {

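Together, these configuration updates register the new pure-HTR metrics: `cer_no_token` and `wer_no_token` are character and word error rates computed after the NER/layout tokens have been stripped from both the ground truth and the prediction (see the `MetricManager` changes below), so they measure recognition quality independently of the entity tags. As a quick reference, the resulting metric lists, written here as a Python dict purely for illustration:

```python
# Illustration only: the metric lists enabled by this merge request.
metrics_config = {
    "train": ["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"],
    "eval": ["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"],
}
```
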
@@ -50,7 +50,7 @@ def eval(rank, config, mlflow_logging):
     model = Manager(config)
     model.load_model()
-    metrics = ["cer", "wer", "wer_no_punct", "time"]
+    metrics = ["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token", "time"]
     if config["dataset"]["tokens"] is not None:
         metrics.append("ner")

@@ -19,6 +19,9 @@ REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
 # Keep only one space character
 REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
 
+# Mapping between computation tasks (CER, WER, NER) and their metric keyword
+METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
+
 
 class MetricManager:
     def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):

@@ -37,46 +40,46 @@ class MetricManager:
         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)
 
-    def edit_cer_from_string(self, gt: str, pred: str):
-        """
-        Format and compute edit distance between two strings at character level
-        """
-        gt = self.format_string_for_cer(gt)
-        pred = self.format_string_for_cer(pred)
-        return editdistance.eval(gt, pred)
-
-    def nb_chars_cer_from_string(self, gt: str) -> int:
-        """
-        Compute length after formatting of ground truth string
-        """
-        return len(self.format_string_for_cer(gt))
+    def format_string_for_cer(self, text: str, remove_token: bool = False):
+        """
+        Format string for CER computation: remove layout tokens and extra spaces
+        """
+        if remove_token and self.remove_tokens is not None:
+            text = self.remove_tokens.sub("", text)
+
+        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
+        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
 
-    def format_string_for_wer(self, text: str, remove_punct: bool = False):
+    def format_string_for_wer(
+        self, text: str, remove_punct: bool = False, remove_token: bool = False
+    ):
         """
         Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
         """
         if remove_punct:
             text = REGEX_PUNCTUATION.sub("", text)
-        if self.remove_tokens is not None:
+        if remove_token and self.remove_tokens is not None:
             text = self.remove_tokens.sub("", text)
         return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")
 
-    def format_string_for_cer(self, text: str):
-        """
-        Format string for CER computation: remove layout tokens and extra spaces
-        """
-        if self.remove_tokens is not None:
-            text = self.remove_tokens.sub("", text)
-        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
-        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
-
     def format_string_for_ner(self, text: str):
         """
         Format string for NER computation: only keep layout tokens
         """
         return self.keep_tokens.sub("", text)
 
+    def _format_string(self, task: str, *args, **kwargs):
+        """
+        Call the proper `format_string_for_*` method for the given task
+        """
+        match task:
+            case "cer":
+                return self.format_string_for_cer(*args, **kwargs)
+            case "wer":
+                return self.format_string_for_wer(*args, **kwargs)
+            case "ner":
+                return self.format_string_for_ner(*args, **kwargs)
+
     def update_metrics(self, batch_metrics):
         """
         Add batch metrics to the metrics

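To make the new `remove_token` flag concrete, here is a small self-contained sketch that mirrors `format_string_for_cer` outside the class. The token characters and the example strings are hypothetical; in DAN the token set comes from the `dataset.tokens` configuration and is compiled into `self.remove_tokens`:

```python
import re

import editdistance

# Hypothetical NER/layout token characters, for illustration only.
REMOVE_TOKENS = re.compile("[ⓕⓢ]")
REGEX_CONSECUTIVE_LINEBREAKS = re.compile(r"\n+")
REGEX_CONSECUTIVE_SPACES = re.compile(r" +")


def format_string_for_cer(text: str, remove_token: bool = False) -> str:
    # Mirrors MetricManager.format_string_for_cer: optionally strip the tokens,
    # then collapse consecutive line breaks and spaces.
    if remove_token:
        text = REMOVE_TOKENS.sub("", text)
    text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
    return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()


gt = "ⓕJohn ⓢDoe"
pred = "ⓢJohn ⓕDoe"  # text is correct, but the two tokens are swapped

# "cer" keeps the tokens, so the swapped tokens count as 2 character errors.
print(editdistance.eval(format_string_for_cer(gt), format_string_for_cer(pred)))  # 2

# "cer_no_token" strips them first and reflects pure HTR quality.
print(
    editdistance.eval(
        format_string_for_cer(gt, remove_token=True),
        format_string_for_cer(pred, remove_token=True),
    )
)  # 0
```

The same flag drives `wer_no_token` through `format_string_for_wer`.
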
@@ -103,13 +106,13 @@ class MetricManager:
                     display_values["sample_time"] = float(round(sample_time, 4))
                     display_values[metric_name] = value
                     continue
-                case "cer":
-                    num_name, denom_name = "edit_chars", "nb_chars"
-                case "wer" | "wer_no_punct":
+                case "cer" | "cer_no_token" | "wer" | "wer_no_punct" | "wer_no_token" | "ner":
+                    keyword = METRICS_KEYWORD[metric_name[:3]]
                     suffix = metric_name[3:]
-                    num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
-                case "ner":
-                    num_name, denom_name = "edit_tokens", "nb_tokens"
+                    num_name, denom_name = (
+                        "edit_" + keyword + suffix,
+                        "nb_" + keyword + suffix,
+                    )
                 case "loss" | "loss_ce":
                     display_values[metric_name] = round(
                         float(

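The consolidated branch above derives the accumulator keys directly from the metric name; a short illustration of that mapping (same logic, outside the class):

```python
# Illustration of the key derivation used by the consolidated `case` branch.
METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}

for metric_name in ["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token", "ner"]:
    keyword = METRICS_KEYWORD[metric_name[:3]]  # "chars", "words" or "tokens"
    suffix = metric_name[3:]                    # "", "_no_punct" or "_no_token"
    print(metric_name, "->", "edit_" + keyword + suffix, "/", "nb_" + keyword + suffix)

# Output (abridged):
# cer          -> edit_chars / nb_chars
# cer_no_token -> edit_chars_no_token / nb_chars_no_token
# wer_no_punct -> edit_words_no_punct / nb_words_no_punct
# ner          -> edit_tokens / nb_tokens
```
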
@@ -147,28 +150,30 @@ class MetricManager:
         gt, prediction = values["str_y"], values["str_x"]
         for metric_name in metric_names:
             match metric_name:
-                case "cer":
-                    metrics["edit_chars"] = list(
-                        map(self.edit_cer_from_string, gt, prediction)
-                    )
-                    metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
-                case "wer" | "wer_no_punct":
+                case "cer" | "cer_no_token" | "wer" | "wer_no_punct" | "wer_no_token" | "ner":
+                    task = metric_name[:3]
+                    keyword = METRICS_KEYWORD[task]
                     suffix = metric_name[3:]
-                    split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)]))
+
+                    # Add extra parameters for the format functions
+                    extras = []
+                    if suffix == "_no_punct":
+                        extras.append([{"remove_punct": True}])
+                    elif suffix == "_no_token":
+                        extras.append([{"remove_token": True}])
+
+                    # Run the format function for the desired computation (CER, WER or NER)
+                    split_gt = list(map(self._format_string, [task], gt, *extras))
                     split_pred = list(
-                        map(self.format_string_for_wer, prediction, [bool(suffix)])
+                        map(self._format_string, [task], prediction, *extras)
                     )
-                    metrics["edit_words" + suffix] = list(
+
+                    # Compute and store edit distance/length for the desired level
+                    # (chars, words or tokens) as metrics
+                    metrics["edit_" + keyword + suffix] = list(
                         map(editdistance.eval, split_gt, split_pred)
                     )
-                    metrics["nb_words" + suffix] = list(map(len, split_gt))
-                case "ner":
-                    split_gt = list(map(self.format_string_for_ner, gt))
-                    split_pred = list(map(self.format_string_for_ner, prediction))
-                    metrics["edit_tokens"] = list(
-                        map(editdistance.eval, split_gt, split_pred)
-                    )
-                    metrics["nb_tokens"] = list(map(len, split_gt))
+                    metrics["nb_" + keyword + suffix] = list(map(len, split_gt))
                 case "loss" | "loss_ce":
                     metrics[metric_name] = [
                         values[metric_name],

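The per-sample `edit_*` and `nb_*` lists are accumulated into `epoch_metrics`; the reported rate is then, as in the usual CER/WER definition, the summed edit distance divided by the summed reference length. A hedged sketch with illustrative numbers (the division itself happens outside the lines changed by this commit):

```python
# Sketch only: how an epoch-level rate is presumably obtained from the
# accumulated per-sample values (numbers are made up for illustration).
epoch_metrics = {
    "edit_words_no_token": [3, 1, 0],  # per-sample word edit distances
    "nb_words_no_token": [9, 8, 7],    # per-sample reference word counts
}
wer_no_token = sum(epoch_metrics["edit_words_no_token"]) / sum(epoch_metrics["nb_words_no_token"])
print(round(wer_no_token, 4))  # 0.1667
```
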
@@ -37,34 +37,34 @@ To determine the value to use for `dataset.max_char_prediction`, you can use the
 ## Training parameters
 
 | Name | Description | Type | Default |
 | ---- | ----------- | ---- | ------- |
 | `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
 | `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
 | `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
 | `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
 | `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
 | `training.output_folder` | Directory for checkpoint and results. | `str` | |
 | `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` |
 | `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` |
 | `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
 | `training.device.ddp_port` | DDP port. | `int` | `20027` |
 | `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` |
 | `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | |
 | `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | |
 | `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
 | `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` |
 | `training.lr_schedulers` | Learning rate schedulers. | custom class | |
 | `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
 | `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
 | `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
-| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
-| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
+| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` |
+| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` |
 | `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
 | `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
 | `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
 | `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` |
 | `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
 
 - To train on several GPUs, simply set the `training.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.nb_gpu` parameter.
 - During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations.

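To illustrate the difference between the three WER variants listed in the table above, a standalone sketch of the word-level formatting; the token characters and the punctuation pattern are hypothetical simplifications (the real regexes are defined at the top of the metrics module):

```python
import re

import editdistance

TOKENS = re.compile("[ⓕⓢ]")           # hypothetical NER/layout tokens
PUNCTUATION = re.compile(r"[.,;:!?]")  # simplified punctuation set
ONLY_ONE_SPACE = re.compile(r"\s+")


def format_string_for_wer(text: str, remove_punct: bool = False, remove_token: bool = False) -> list:
    # Mirrors MetricManager.format_string_for_wer: optional filtering, then split into words.
    if remove_punct:
        text = PUNCTUATION.sub("", text)
    if remove_token:
        text = TOKENS.sub("", text)
    return ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")


gt = "ⓕParis , ⓢFrance"
pred = "ⓕParis , ⓕFrance"  # one token error, the text itself is correct

# wer: tokens are kept, so the wrong token makes one of the three words wrong.
print(editdistance.eval(format_string_for_wer(gt), format_string_for_wer(pred)))  # 1

# wer_no_punct: the comma is dropped before splitting (2 words instead of 3).
print(format_string_for_wer(gt, remove_punct=True))  # ['ⓕParis', 'ⓢFrance']

# wer_no_token: tokens are stripped first, so the pure HTR word error count is 0.
print(
    editdistance.eval(
        format_string_for_wer(gt, remove_token=True),
        format_string_for_wer(pred, remove_token=True),
    )
)  # 0
```
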
@@ -16,28 +16,40 @@ from tests import FIXTURES
         {
             "nb_chars": 43,
             "cer": 1.3023,
+            "nb_chars_no_token": 43,
+            "cer_no_token": 1.3023,
             "nb_words": 9,
             "wer": 1.0,
             "nb_words_no_punct": 9,
             "wer_no_punct": 1.0,
+            "nb_words_no_token": 9,
+            "wer_no_token": 1.0,
             "nb_samples": 2,
         },
         {
             "nb_chars": 41,
             "cer": 1.2683,
+            "nb_chars_no_token": 41,
+            "cer_no_token": 1.2683,
             "nb_words": 9,
             "wer": 1.0,
             "nb_words_no_punct": 9,
             "wer_no_punct": 1.0,
+            "nb_words_no_token": 9,
+            "wer_no_token": 1.0,
             "nb_samples": 2,
         },
         {
             "nb_chars": 49,
             "cer": 1.1224,
+            "nb_chars_no_token": 49,
+            "cer_no_token": 1.1224,
             "nb_words": 9,
             "wer": 1.0,
             "nb_words_no_punct": 9,
             "wer_no_punct": 1.0,
+            "nb_words_no_token": 9,
+            "wer_no_token": 1.0,
             "nb_samples": 2,
         },
     ),

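A quick sanity check of the fixture values above, assuming the reported rate is the total edit distance divided by the total reference length. The fixtures also report `cer_no_token` equal to `cer` and `nb_chars_no_token` equal to `nb_chars`, which suggests these test strings contain no NER tokens to strip:

```python
# Rough arithmetic check of the first expected result (illustrative only).
nb_chars = 43
cer = 1.3023
edit_chars = round(cer * nb_chars)  # ≈ 56 character edits over the 2 samples
print(round(edit_chars / nb_chars, 4))  # 1.3023
```
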