Commit 2c63c01c authored by Yoann Schneider, committed by Mélodie Boillet

Remove post processing as it's not used

parent e2235375
Merge request !235: Remove post processing as it's not used
 # -*- coding: utf-8 -*-
 import re
+from operator import attrgetter
+from pathlib import Path
+from typing import Optional
 
 import editdistance
 import numpy as np
 
-from dan.post_processing import PostProcessingModuleSIMARA
-from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
+from dan.datasets.extract.utils import parse_tokens
 
 
 class MetricManager:
-    def __init__(self, metric_names, dataset_name):
+    def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
         self.dataset_name = dataset_name
-        if "simara" in dataset_name and "page" in dataset_name:
-            self.post_processing_module = PostProcessingModuleSIMARA
-            self.matching_tokens = SIMARA_MATCHING_TOKENS
-        else:
-            self.matching_tokens = dict()
-        self.layout_tokens = "".join(
-            list(self.matching_tokens.keys()) + list(self.matching_tokens.values())
-        )
-        if len(self.layout_tokens) == 0:
-            self.layout_tokens = None
+        self.layout_tokens = None
+        if tokens:
+            tokens = parse_tokens(tokens)
+            self.layout_tokens = "".join(
+                list(map(attrgetter("start"), tokens.values()))
+                + list(map(attrgetter("end"), tokens.values()))
+            )
         self.metric_names = metric_names
         self.epoch_metrics = None
......
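Note: instead of hard-coding SIMARA's matching tokens, layout tokens are now derived from an optional tokens file. A minimal sketch of the new behaviour, assuming parse_tokens returns a dict mapping entity names to objects exposing start/end attributes (consistent with the attrgetter calls above); the entity names and token characters below are illustrative, not taken from a real configuration:

    from operator import attrgetter
    from typing import NamedTuple

    class EntityType(NamedTuple):
        start: str
        end: str

    # Illustrative stand-in for the output of parse_tokens(tokens_path)
    tokens = {
        "surname": EntityType(start="ⓢ", end="Ⓢ"),
        "firstname": EntityType(start="ⓕ", end="Ⓕ"),
    }

    # Same concatenation as in MetricManager.__init__ above
    layout_tokens = "".join(
        list(map(attrgetter("start"), tokens.values()))
        + list(map(attrgetter("end"), tokens.values()))
    )
    print(layout_tokens)  # ⓢⓕⓈⒻ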
......
@@ -60,6 +60,7 @@ class GenericTrainingManager:
             if self.params["training_params"]["use_ddp"]
             else 1
         )
+        self.tokens = self.params["dataset_params"].get("tokens")
 
     def init_paths(self):
         """
......
@@ -617,7 +618,9 @@ class GenericTrainingManager:
             ] = self.latest_epoch
             # init epoch metrics values
             self.metric_manager["train"] = MetricManager(
-                metric_names=metric_names, dataset_name=self.dataset_name
+                metric_names=metric_names,
+                dataset_name=self.dataset_name,
+                tokens=self.tokens,
             )
             with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
                 pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
......
@@ -738,7 +741,9 @@ class GenericTrainingManager:
             # initialize epoch metrics
             self.metric_manager[set_name] = MetricManager(
-                metric_names, dataset_name=self.dataset_name
+                metric_names=metric_names,
+                dataset_name=self.dataset_name,
+                tokens=self.tokens,
             )
             with tqdm(total=len(loader.dataset)) as pbar:
                 pbar.set_description("Evaluation E{}".format(self.latest_epoch))
......
@@ -787,7 +792,9 @@ class GenericTrainingManager:
             # initialize epoch metrics
             self.metric_manager[custom_name] = MetricManager(
-                metric_names, self.dataset_name
+                metric_names=metric_names,
+                dataset_name=self.dataset_name,
+                tokens=self.tokens,
             )
             with tqdm(total=len(loader.dataset)) as pbar:
......
......
@@ -114,6 +114,7 @@ def get_config():
             ],
             "augmentation": True,
         },
+        "tokens": None,
     },
     "model_params": {
         "models": {
......
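The new "tokens" entry defaults to None, which disables layout-token handling entirely. A hypothetical excerpt showing how a training configuration would enable it (the file name is an example, not part of this commit):

    from pathlib import Path

    params = {
        "dataset_params": {
            # ... other dataset parameters ...
            "tokens": Path("tokens.yml"),  # None keeps layout tokens disabled
        },
    }

    # GenericTrainingManager reads it once and forwards it to every MetricManager:
    tokens = params["dataset_params"].get("tokens")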
-# -*- coding: utf-8 -*-
-import numpy as np
-
-from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
-
-
-class PostProcessingModule:
-    """
-    Forward pass post processing
-    Add/remove layout tokens only to:
-     - respect token hierarchy
-     - complete/remove unpaired tokens
-    """
-
-    def __init__(self):
-        self.prediction = None
-        self.confidence = None
-
-    def post_processing(self):
-        raise NotImplementedError
-
-    def post_process(self, prediction, confidence_score=None):
-        """
-        Apply dataset-specific post-processing
-        """
-        self.prediction = list(prediction)
-        self.confidence = (
-            list(confidence_score) if confidence_score is not None else None
-        )
-        if self.confidence is not None:
-            assert len(self.prediction) == len(self.confidence)
-        return self.post_processing()
-
-    def insert_label(self, index, label):
-        """
-        Insert token at specific index. The associated confidence score is set to 0.
-        """
-        self.prediction.insert(index, label)
-        if self.confidence is not None:
-            self.confidence.insert(index, 0)
-
-    def del_label(self, index):
-        """
-        Remove the token at a specific index.
-        """
-        del self.prediction[index]
-        if self.confidence is not None:
-            del self.confidence[index]
-
-
-class PostProcessingModuleSIMARA(PostProcessingModule):
-    """
-    Specific post-processing for the SIMARA dataset at page level
-    """
-
-    def __init__(self):
-        super(PostProcessingModuleSIMARA, self).__init__()
-        self.matching_tokens = SIMARA_MATCHING_TOKENS
-        self.reverse_matching_tokens = dict()
-        for key in self.matching_tokens:
-            self.reverse_matching_tokens[self.matching_tokens[key]] = key
-
-    def post_processing(self):
-        ind = 0
-        begin_token = None
-        while ind != len(self.prediction):
-            char = self.prediction[ind]
-            # a tag must be closed before starting a new one
-            if char in self.matching_tokens.keys():
-                if begin_token is None:
-                    ind += 1
-                else:
-                    self.insert_label(ind, self.matching_tokens[begin_token])
-                    ind += 2
-                begin_token = char
-                continue
-            # an end token without prior corresponding begin token is removed
-            elif char in self.matching_tokens.values():
-                if begin_token == self.reverse_matching_tokens[char]:
-                    ind += 1
-                    begin_token = None
-                else:
-                    self.del_label(ind)
-                continue
-            else:
-                ind += 1
-        # a tag must be closed
-        if begin_token is not None:
-            self.insert_label(ind + 1, self.matching_tokens[begin_token])
-        res = "".join(self.prediction)
-        if self.confidence is not None:
-            return res, np.array(self.confidence)
-        return res
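For the record, the removed module repaired unpaired layout tokens in predictions. A standalone sketch of the same balancing loop, reimplemented on a plain string with two illustrative begin/end pairs (not the dan API):

    matching = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ"}  # begin token -> end token
    reverse = {end: begin for begin, end in matching.items()}

    def balance(prediction: str) -> str:
        chars, ind, begin = list(prediction), 0, None
        while ind != len(chars):
            char = chars[ind]
            if char in matching:  # close the open tag before starting a new one
                if begin is None:
                    ind += 1
                else:
                    chars.insert(ind, matching[begin])
                    ind += 2
                begin = char
                continue
            if char in reverse:  # drop end tokens with no open begin token
                if begin == reverse[char]:
                    ind += 1
                    begin = None
                else:
                    del chars[ind]
                continue
            ind += 1
        if begin is not None:  # close a tag left open at the end
            chars.append(matching[begin])
        return "".join(chars)

    print(balance("ⓘletter ⓓ1690Ⓓ"))  # "ⓘletter Ⓘⓓ1690Ⓓ"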
......
@@ -4,9 +4,6 @@ from itertools import islice
 import torch
 import torchvision.io as torchvision
 
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
-
 
 class MLflowNotInstalled(Exception):
     """
......
-# Post processing
-
-::: dan.post_processing
......
@@ -4,16 +4,17 @@ All hyperparameters are specified and editable in the training scripts `dan/ocr/
 
 ## Dataset parameters
 
-| Parameter | Description | Type | Default |
-| --------- | ----------- | ---- | ------- |
-| `dataset_name` | Name of the dataset. | `str` | |
-| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | |
-| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | |
-| `dataset_path` | Path to the dataset. | `str` | |
-| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
-| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
-| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
-| `dataset_params.config.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
+| Parameter | Description | Type | Default |
+| --------- | ----------- | ---- | ------- |
+| `dataset_name` | Name of the dataset. | `str` | |
+| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | |
+| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | |
+| `dataset_path` | Path to the dataset. | `str` | |
+| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
+| `dataset_params.config.worker_per_gpu` | Number of parallel processes per GPU for data loading. | `int` | `4` |
+| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
+| `dataset_params.config.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
+| `dataset_params.tokens` | Path to a NER tokens configuration file similar to [the one used for extraction](../datasets/extract.md#description). | `pathlib.Path` | `None` |
 
 !!! warning
     The variables `dataset_name`, `dataset_level`, `dataset_variant` and `dataset_path` must have values such that the data is located in `{dataset_path}/{dataset_name}_{dataset_level}{dataset_variant}`.
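For example, with dataset_path="/data", dataset_name="simara", dataset_level="page" and dataset_variant="_sem", the data must live in /data/simara_page_sem. A plausible dataset_params.tokens file, following the entity/start/end layout documented for extraction (entity names and token characters here are illustrative):

    INTITULE:
      start: ⓘ
      end: Ⓘ
    DATE:
      start: ⓓ
      end: Ⓓ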
......
......
@@ -100,7 +100,6 @@ nav:
       - Decoders: ref/decoder.md
       - Models: ref/encoder.md
       - MLflow: ref/mlflow.md
-      - Post Processing: ref/post_processing.md
       - Schedulers: ref/schedulers.md
       - Transformations: ref/transforms.md
       - Utils: ref/utils.md
......
......
@@ -60,6 +60,7 @@ def training_config():
             ],
             "augmentation": True,
         },
+        "tokens": None,
     },
     "model_params": {
         "models": {
......