From 2c63c01cabfb8df3c2da0dcb4713256f5e09b2d4 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 8 Aug 2023 13:49:58 +0000
Subject: [PATCH] Remove post processing as it's not used

---
 dan/manager/metrics.py         | 26 +++++-----
 dan/manager/training.py        | 13 +++--
 dan/ocr/document/train.py      |  1 +
 dan/post_processing.py         | 93 ----------------------------------
 dan/utils.py                   |  3 --
 docs/ref/post_processing.md    |  3 --
 docs/usage/train/parameters.md | 21 ++++----
 mkdocs.yml                     |  1 -
 tests/conftest.py              |  1 +
 9 files changed, 35 insertions(+), 127 deletions(-)
 delete mode 100644 dan/post_processing.py
 delete mode 100644 docs/ref/post_processing.md

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index 0b2472d9..015c0234 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -1,28 +1,26 @@
 # -*- coding: utf-8 -*-
 import re
+from operator import attrgetter
+from pathlib import Path
+from typing import Optional
 
 import editdistance
 import numpy as np
 
-from dan.post_processing import PostProcessingModuleSIMARA
-from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
+from dan.datasets.extract.utils import parse_tokens
 
 
 class MetricManager:
-    def __init__(self, metric_names, dataset_name):
+    def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
         self.dataset_name = dataset_name
 
-        if "simara" in dataset_name and "page" in dataset_name:
-            self.post_processing_module = PostProcessingModuleSIMARA
-            self.matching_tokens = SIMARA_MATCHING_TOKENS
-        else:
-            self.matching_tokens = dict()
-
-        self.layout_tokens = "".join(
-            list(self.matching_tokens.keys()) + list(self.matching_tokens.values())
-        )
-        if len(self.layout_tokens) == 0:
-            self.layout_tokens = None
+        self.layout_tokens = None
+        if tokens:
+            tokens = parse_tokens(tokens)
+            self.layout_tokens = "".join(
+                list(map(attrgetter("start"), tokens.values()))
+                + list(map(attrgetter("end"), tokens.values()))
+            )
         self.metric_names = metric_names
         self.epoch_metrics = None
 
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 8f4de234..735b7100 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -60,6 +60,7 @@ class GenericTrainingManager:
             if self.params["training_params"]["use_ddp"]
             else 1
         )
+        self.tokens = self.params["dataset_params"].get("tokens")
 
     def init_paths(self):
         """
@@ -617,7 +618,9 @@ class GenericTrainingManager:
                 ] = self.latest_epoch
             # init epoch metrics values
             self.metric_manager["train"] = MetricManager(
-                metric_names=metric_names, dataset_name=self.dataset_name
+                metric_names=metric_names,
+                dataset_name=self.dataset_name,
+                tokens=self.tokens,
             )
             with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
                 pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
@@ -738,7 +741,9 @@ class GenericTrainingManager:
 
         # initialize epoch metrics
         self.metric_manager[set_name] = MetricManager(
-            metric_names, dataset_name=self.dataset_name
+            metric_names=metric_names,
+            dataset_name=self.dataset_name,
+            tokens=self.tokens,
         )
         with tqdm(total=len(loader.dataset)) as pbar:
             pbar.set_description("Evaluation E{}".format(self.latest_epoch))
@@ -787,7 +792,9 @@ class GenericTrainingManager:
 
         # initialize epoch metrics
         self.metric_manager[custom_name] = MetricManager(
-            metric_names, self.dataset_name
+            metric_names=metric_names,
+            dataset_name=self.dataset_name,
+            tokens=self.tokens,
         )
 
         with tqdm(total=len(loader.dataset)) as pbar:
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index b6146a94..712c8a34 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -114,6 +114,7 @@ def get_config():
                 ],
                 "augmentation": True,
             },
+            "tokens": None,
         },
         "model_params": {
             "models": {
diff --git a/dan/post_processing.py b/dan/post_processing.py
deleted file mode 100644
index 180d021f..00000000
--- a/dan/post_processing.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-
-from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
-
-
-class PostProcessingModule:
-    """
-    Forward pass post processing
-    Add/remove layout tokens only to:
-     - respect token hierarchy
-     - complete/remove unpaired tokens
-    """
-
-    def __init__(self):
-        self.prediction = None
-        self.confidence = None
-
-    def post_processing(self):
-        raise NotImplementedError
-
-    def post_process(self, prediction, confidence_score=None):
-        """
-        Apply dataset-specific post-processing
-        """
-        self.prediction = list(prediction)
-        self.confidence = (
-            list(confidence_score) if confidence_score is not None else None
-        )
-        if self.confidence is not None:
-            assert len(self.prediction) == len(self.confidence)
-        return self.post_processing()
-
-    def insert_label(self, index, label):
-        """
-        Insert token at specific index. The associated confidence score is set to 0.
-        """
-        self.prediction.insert(index, label)
-        if self.confidence is not None:
-            self.confidence.insert(index, 0)
-
-    def del_label(self, index):
-        """
-        Remove the token at a specific index.
-        """
-        del self.prediction[index]
-        if self.confidence is not None:
-            del self.confidence[index]
-
-
-class PostProcessingModuleSIMARA(PostProcessingModule):
-    """
-    Specific post-processing for the SIMARA dataset at page level
-    """
-
-    def __init__(self):
-        super(PostProcessingModuleSIMARA, self).__init__()
-        self.matching_tokens = SIMARA_MATCHING_TOKENS
-        self.reverse_matching_tokens = dict()
-        for key in self.matching_tokens:
-            self.reverse_matching_tokens[self.matching_tokens[key]] = key
-
-    def post_processing(self):
-        ind = 0
-        begin_token = None
-        while ind != len(self.prediction):
-            char = self.prediction[ind]
-            # a tag must be closed before starting a new one
-            if char in self.matching_tokens.keys():
-                if begin_token is None:
-                    ind += 1
-                else:
-                    self.insert_label(ind, self.matching_tokens[begin_token])
-                    ind += 2
-                begin_token = char
-                continue
-            # an end token without prior corresponding begin token is removed
-            elif char in self.matching_tokens.values():
-                if begin_token == self.reverse_matching_tokens[char]:
-                    ind += 1
-                    begin_token = None
-                else:
-                    self.del_label(ind)
-                continue
-            else:
-                ind += 1
-        # a tag must be closed
-        if begin_token is not None:
-            self.insert_label(ind + 1, self.matching_tokens[begin_token])
-        res = "".join(self.prediction)
-        if self.confidence is not None:
-            return res, np.array(self.confidence)
-        return res
diff --git a/dan/utils.py b/dan/utils.py
index 84e73824..d58db8fa 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -4,9 +4,6 @@ from itertools import islice
 import torch
 import torchvision.io as torchvision
 
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
-
 
 class MLflowNotInstalled(Exception):
     """
diff --git a/docs/ref/post_processing.md b/docs/ref/post_processing.md
deleted file mode 100644
index 2c92ad28..00000000
--- a/docs/ref/post_processing.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Post processing
-
-::: dan.post_processing
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index 731e970d..b031dab9 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -4,16 +4,17 @@ All hyperparameters are specified and editable in the training scripts `dan/ocr/
 
 ## Dataset parameters
 
-| Parameter                              | Description                                                                            | Type   | Default                                              |
-| -------------------------------------- | -------------------------------------------------------------------------------------- | ------ | ---------------------------------------------------- |
-| `dataset_name`                         | Name of the dataset.                                                                   | `str`  |                                                      |
-| `dataset_level`                        | Level of the dataset. Should be named after the element type.                          | `str`  |                                                      |
-| `dataset_variant`                      | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str`  |                                                      |
-| `dataset_path`                         | Path to the dataset.                                                                   | `str`  |                                                      |
-| `dataset_params.config.load_in_memory` | Load all images in CPU memory.                                                         | `bool` | `True`                                               |
-| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading.                                 | `int`  | `4`                                                  |
-| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images.                             | `list` | (see [dedicated section](#data-preprocessing))       |
-| `dataset_params.config.augmentation`   | Whether to use data augmentation on the training set.                                  | `bool` | `True` (see [dedicated section](#data-augmentation)) |
+| Parameter                              | Description                                                                                                           | Type           | Default                                              |
+| -------------------------------------- | --------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------------------- |
+| `dataset_name`                         | Name of the dataset.                                                                                                  | `str`          |                                                      |
+| `dataset_level`                        | Level of the dataset. Should be named after the element type.                                                         | `str`          |                                                      |
+| `dataset_variant`                      | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets.                                | `str`          |                                                      |
+| `dataset_path`                         | Path to the dataset.                                                                                                  | `str`          |                                                      |
+| `dataset_params.config.load_in_memory` | Load all images in CPU memory.                                                                                        | `bool`         | `True`                                               |
+| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading.                                                                | `int`          | `4`                                                  |
+| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images.                                                            | `list`         | (see [dedicated section](#data-preprocessing))       |
+| `dataset_params.config.augmentation`   | Whether to use data augmentation on the training set.                                                                 | `bool`         | `True` (see [dedicated section](#data-augmentation)) |
+| `dataset_params.tokens`                | Path to a NER tokens configuration file similar to [the one used for extraction](../datasets/extract.md#description). | `pathlib.Path` | None                                                 |
 
 !!! warning
     The variables `dataset_name`, `dataset_level`, `dataset_variant` and `dataset_path` must have values such that the data is located in `{dataset_path}/{dataset_name}_{dataset_level}{dataset_variant}`.
diff --git a/mkdocs.yml b/mkdocs.yml
index 79cc177c..253be65f 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -100,7 +100,6 @@ nav:
     - Decoders: ref/decoder.md
     - Models: ref/encoder.md
     - MLflow: ref/mlflow.md
-    - Post Processing: ref/post_processing.md
     - Schedulers: ref/schedulers.md
     - Transformations: ref/transforms.md
     - Utils: ref/utils.md
diff --git a/tests/conftest.py b/tests/conftest.py
index 2eda7829..c0bcfb8c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -60,6 +60,7 @@ def training_config():
                 ],
                 "augmentation": True,
             },
+            "tokens": None,
         },
         "model_params": {
             "models": {
-- 
GitLab