From 2c63c01cabfb8df3c2da0dcb4713256f5e09b2d4 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 8 Aug 2023 13:49:58 +0000
Subject: [PATCH] Remove post processing as it's not used

---
 dan/manager/metrics.py         | 26 +++++-----
 dan/manager/training.py        | 13 +++--
 dan/ocr/document/train.py      |  1 +
 dan/post_processing.py         | 93 ----------------------------------
 dan/utils.py                   |  3 --
 docs/ref/post_processing.md    |  3 --
 docs/usage/train/parameters.md | 21 ++++----
 mkdocs.yml                     |  1 -
 tests/conftest.py              |  1 +
 9 files changed, 35 insertions(+), 127 deletions(-)
 delete mode 100644 dan/post_processing.py
 delete mode 100644 docs/ref/post_processing.md

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index 0b2472d9..015c0234 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -1,28 +1,26 @@
 # -*- coding: utf-8 -*-
 import re
+from operator import attrgetter
+from pathlib import Path
+from typing import Optional
 
 import editdistance
 import numpy as np
 
-from dan.post_processing import PostProcessingModuleSIMARA
-from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
+from dan.datasets.extract.utils import parse_tokens
 
 
 class MetricManager:
-    def __init__(self, metric_names, dataset_name):
+    def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
         self.dataset_name = dataset_name
-        if "simara" in dataset_name and "page" in dataset_name:
-            self.post_processing_module = PostProcessingModuleSIMARA
-            self.matching_tokens = SIMARA_MATCHING_TOKENS
-        else:
-            self.matching_tokens = dict()
-
-        self.layout_tokens = "".join(
-            list(self.matching_tokens.keys()) + list(self.matching_tokens.values())
-        )
-        if len(self.layout_tokens) == 0:
-            self.layout_tokens = None
+        self.layout_tokens = None
+        if tokens:
+            tokens = parse_tokens(tokens)
+            self.layout_tokens = "".join(
+                list(map(attrgetter("start"), tokens.values()))
+                + list(map(attrgetter("end"), tokens.values()))
+            )
 
         self.metric_names = metric_names
         self.epoch_metrics = None
 
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 8f4de234..735b7100 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -60,6 +60,7 @@ class GenericTrainingManager:
             if self.params["training_params"]["use_ddp"]
             else 1
         )
+        self.tokens = self.params["dataset_params"].get("tokens")
 
     def init_paths(self):
         """
@@ -617,7 +618,9 @@ class GenericTrainingManager:
         ] = self.latest_epoch
         # init epoch metrics values
         self.metric_manager["train"] = MetricManager(
-            metric_names=metric_names, dataset_name=self.dataset_name
+            metric_names=metric_names,
+            dataset_name=self.dataset_name,
+            tokens=self.tokens,
         )
         with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
             pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
@@ -738,7 +741,9 @@ class GenericTrainingManager:
 
         # initialize epoch metrics
         self.metric_manager[set_name] = MetricManager(
-            metric_names, dataset_name=self.dataset_name
+            metric_names=metric_names,
+            dataset_name=self.dataset_name,
+            tokens=self.tokens,
         )
         with tqdm(total=len(loader.dataset)) as pbar:
             pbar.set_description("Evaluation E{}".format(self.latest_epoch))
@@ -787,7 +792,9 @@ class GenericTrainingManager:
 
         # initialize epoch metrics
         self.metric_manager[custom_name] = MetricManager(
-            metric_names, self.dataset_name
+            metric_names=metric_names,
+            dataset_name=self.dataset_name,
+            tokens=self.tokens,
         )
 
         with tqdm(total=len(loader.dataset)) as pbar:
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index b6146a94..712c8a34 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -114,6 +114,7 @@ def get_config():
             ],
             "augmentation": True,
         },
+        "tokens": None,
     },
     "model_params": {
         "models": {
diff --git a/dan/post_processing.py b/dan/post_processing.py
deleted file mode 100644
index 180d021f..00000000
--- a/dan/post_processing.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-
-from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
-
-
-class PostProcessingModule:
-    """
-    Forward pass post processing
-    Add/remove layout tokens only to:
-    - respect token hierarchy
-    - complete/remove unpaired tokens
-    """
-
-    def __init__(self):
-        self.prediction = None
-        self.confidence = None
-
-    def post_processing(self):
-        raise NotImplementedError
-
-    def post_process(self, prediction, confidence_score=None):
-        """
-        Apply dataset-specific post-processing
-        """
-        self.prediction = list(prediction)
-        self.confidence = (
-            list(confidence_score) if confidence_score is not None else None
-        )
-        if self.confidence is not None:
-            assert len(self.prediction) == len(self.confidence)
-        return self.post_processing()
-
-    def insert_label(self, index, label):
-        """
-        Insert token at specific index. The associated confidence score is set to 0.
-        """
-        self.prediction.insert(index, label)
-        if self.confidence is not None:
-            self.confidence.insert(index, 0)
-
-    def del_label(self, index):
-        """
-        Remove the token at a specific index.
-        """
-        del self.prediction[index]
-        if self.confidence is not None:
-            del self.confidence[index]
-
-
-class PostProcessingModuleSIMARA(PostProcessingModule):
-    """
-    Specific post-processing for the SIMARA dataset at page level
-    """
-
-    def __init__(self):
-        super(PostProcessingModuleSIMARA, self).__init__()
-        self.matching_tokens = SIMARA_MATCHING_TOKENS
-        self.reverse_matching_tokens = dict()
-        for key in self.matching_tokens:
-            self.reverse_matching_tokens[self.matching_tokens[key]] = key
-
-    def post_processing(self):
-        ind = 0
-        begin_token = None
-        while ind != len(self.prediction):
-            char = self.prediction[ind]
-            # a tag must be closed before starting a new one
-            if char in self.matching_tokens.keys():
-                if begin_token is None:
-                    ind += 1
-                else:
-                    self.insert_label(ind, self.matching_tokens[begin_token])
-                    ind += 2
-                begin_token = char
-                continue
-            # an end token without prior corresponding begin token is removed
-            elif char in self.matching_tokens.values():
-                if begin_token == self.reverse_matching_tokens[char]:
-                    ind += 1
-                    begin_token = None
-                else:
-                    self.del_label(ind)
-                continue
-            else:
-                ind += 1
-        # a tag must be closed
-        if begin_token is not None:
-            self.insert_label(ind + 1, self.matching_tokens[begin_token])
-        res = "".join(self.prediction)
-        if self.confidence is not None:
-            return res, np.array(self.confidence)
-        return res
diff --git a/dan/utils.py b/dan/utils.py
index 84e73824..d58db8fa 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -4,9 +4,6 @@ from itertools import islice
 import torch
 import torchvision.io as torchvision
 
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
-
 
 class MLflowNotInstalled(Exception):
     """
diff --git a/docs/ref/post_processing.md b/docs/ref/post_processing.md
deleted file mode 100644
index 2c92ad28..00000000
--- a/docs/ref/post_processing.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Post processing
-
-::: dan.post_processing
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index 731e970d..b031dab9 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -4,16 +4,17 @@ All hyperparameters are specified and editable in the training scripts `dan/ocr/
 
 ## Dataset parameters
 
-| Parameter                               | Description                                                                             | Type   | Default                                               |
-| --------------------------------------- | --------------------------------------------------------------------------------------- | ------ | ----------------------------------------------------- |
-| `dataset_name`                          | Name of the dataset.                                                                    | `str`  |                                                       |
-| `dataset_level`                         | Level of the dataset. Should be named after the element type.                           | `str`  |                                                       |
-| `dataset_variant`                       | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets.  | `str`  |                                                       |
-| `dataset_path`                          | Path to the dataset.                                                                    | `str`  |                                                       |
-| `dataset_params.config.load_in_memory`  | Load all images in CPU memory.                                                          | `bool` | `True`                                                |
-| `dataset_params.config.worker_per_gpu`  | Number of parallel processes per gpu for data loading.                                  | `int`  | `4`                                                   |
-| `dataset_params.config.preprocessings`  | List of pre-processing functions to apply to input images.                              | `list` | (see [dedicated section](#data-preprocessing))        |
-| `dataset_params.config.augmentation`    | Whether to use data augmentation on the training set.                                   | `bool` | `True` (see [dedicated section](#data-augmentation))  |
+| Parameter                               | Description                                                                                                            | Type           | Default                                               |
+| --------------------------------------- | ---------------------------------------------------------------------------------------------------------------------- | -------------- | ----------------------------------------------------- |
+| `dataset_name`                          | Name of the dataset.                                                                                                   | `str`          |                                                       |
+| `dataset_level`                         | Level of the dataset. Should be named after the element type.                                                          | `str`          |                                                       |
+| `dataset_variant`                       | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets.                                 | `str`          |                                                       |
+| `dataset_path`                          | Path to the dataset.                                                                                                   | `str`          |                                                       |
+| `dataset_params.config.load_in_memory`  | Load all images in CPU memory.                                                                                         | `bool`         | `True`                                                |
+| `dataset_params.config.worker_per_gpu`  | Number of parallel processes per gpu for data loading.                                                                 | `int`          | `4`                                                   |
+| `dataset_params.config.preprocessings`  | List of pre-processing functions to apply to input images.                                                             | `list`         | (see [dedicated section](#data-preprocessing))        |
+| `dataset_params.config.augmentation`    | Whether to use data augmentation on the training set.                                                                  | `bool`         | `True` (see [dedicated section](#data-augmentation)) |
+| `dataset_params.tokens`                 | Path to a NER tokens configuration file similar to [the one used for extraction](../datasets/extract.md#description). | `pathlib.Path` | None                                                  |
 
 !!! warning
     The variables `dataset_name`, `dataset_level`, `dataset_variant` and `dataset_path` must have values such that the data is located in `{dataset_path}/{dataset_name}_{dataset_level}{dataset_variant}`.
diff --git a/mkdocs.yml b/mkdocs.yml
index 79cc177c..253be65f 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -100,7 +100,6 @@ nav:
       - Decoders: ref/decoder.md
      - Models: ref/encoder.md
       - MLflow: ref/mlflow.md
-      - Post Processing: ref/post_processing.md
       - Schedulers: ref/schedulers.md
       - Transformations: ref/transforms.md
       - Utils: ref/utils.md
diff --git a/tests/conftest.py b/tests/conftest.py
index 2eda7829..c0bcfb8c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -60,6 +60,7 @@ def training_config():
             ],
             "augmentation": True,
         },
+        "tokens": None,
     },
     "model_params": {
         "models": {
-- 
GitLab
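
The patch replaces the hard-coded SIMARA token table with tokens parsed from a
configuration file at metric time. Below is a minimal sketch of how the new
MetricManager code builds `layout_tokens`, assuming `parse_tokens` returns a
dict mapping entity names to objects exposing `start` and `end` attributes
(which is what the `attrgetter` calls in the diff imply); `EntityType` and the
entity names below are illustrative stand-ins, not taken from a real dataset.

    # Stand-in for the result of parse_tokens(Path("tokens.yml")).
    from dataclasses import dataclass
    from operator import attrgetter

    @dataclass
    class EntityType:
        start: str  # token opening the entity, e.g. "ⓘ"
        end: str    # token closing it, e.g. "Ⓘ"

    tokens = {
        "intitule": EntityType(start="ⓘ", end="Ⓘ"),
        "date": EntityType(start="ⓓ", end="Ⓓ"),
    }

    # Same join as in the new MetricManager.__init__:
    # all start tokens first, then all end tokens.
    layout_tokens = "".join(
        list(map(attrgetter("start"), tokens.values()))
        + list(map(attrgetter("end"), tokens.values()))
    )
    assert layout_tokens == "ⓘⓓⒾⒹ"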
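On the configuration side, the patch threads the new key through
`GenericTrainingManager`, which reads it with
`self.params["dataset_params"].get("tokens")`, and defaults it to None in both
`get_config()` and the test fixture. A sketch of how a user would enable it,
assuming a tokens file written as described in the extraction documentation
("tokens.yml" is a hypothetical path):

    from pathlib import Path

    params = {
        "dataset_params": {
            # None (the new default) disables layout-token handling in metrics.
            "tokens": Path("tokens.yml"),
            # ... the other dataset parameters (config, etc.) are unchanged.
        },
        # ... model_params, training_params, etc.
    }

    # The manager uses .get(), so configs written before this patch,
    # which lack the key entirely, keep working:
    tokens = params["dataset_params"].get("tokens")
    assert tokens == Path("tokens.yml")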