From 00d54e0cb330fd22fecfb111aee8af6d18bf885c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Mon, 9 Oct 2023 12:43:23 +0200 Subject: [PATCH] Deal with unknown token separately --- dan/datasets/extract/extract.py | 4 ++-- dan/datasets/extract/utils.py | 2 +- dan/ocr/decoder.py | 2 +- dan/utils.py | 1 - tests/test_extract.py | 4 ++-- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py index 0d7002b5..a7fa8d32 100644 --- a/dan/datasets/extract/extract.py +++ b/dan/datasets/extract/extract.py @@ -36,13 +36,13 @@ from dan.datasets.extract.utils import ( normalize_linebreaks, normalize_spaces, ) +from dan.utils import EntityType, LMTokenMapping, parse_tokens +from line_image_extractor.extractor import extract from line_image_extractor.image_utils import ( BoundingBox, Extraction, polygon_to_bbox, ) -from dan.utils import EntityType, LMTokenMapping, parse_tokens -from line_image_extractor.extractor import extract IMAGES_DIR = "images" # Subpath to the images directory. LANGUAGE_DIR = "language_model" # Subpath to the language model directory. 
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py index 5e867a39..a2184f07 100644 --- a/dan/datasets/extract/utils.py +++ b/dan/datasets/extract/utils.py @@ -81,6 +81,7 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) - + (entity_type.end if entity_type else "") ) + def normalize_linebreaks(text: str) -> str: """ Remove begin/ending linebreaks @@ -106,4 +107,3 @@ def get_bbox(polygon: List[List[int]]) -> str: x, y = min(all_x), min(all_y) width, height = max(all_x) - x, max(all_y) - y return ",".join(list(map(str, [int(x), int(y), int(width), int(height)]))) - diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index bd858c16..b4da94d0 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -505,8 +505,8 @@ class CTCLanguageDecoder: tokens=tokens_path, lm_weight=self.language_model_weight, blank_token=self.mapping.ctc.encoded, - unk_word=self.mapping.unknown.encoded, sil_token=self.mapping.space.encoded, + unk_word="⁇", nbest=1, ) # No GPU support diff --git a/dan/utils.py b/dan/utils.py index c65df263..69e7d82a 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -25,7 +25,6 @@ class LMTokenMapping(NamedTuple): space: Token = Token("⎵", " ") linebreak: Token = Token("↵", "\n") ctc: Token = Token("◌", "<ctc>") - unknown: Token = Token("⁇", "<unk>") @property def display(self): diff --git a/tests/test_extract.py b/tests/test_extract.py index 0a5e0955..cfd78846 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -470,8 +470,8 @@ def test_extract( ⓢ B a r e y r e ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 2 8 . 3 . 1 1 ⓢ R o u s s y ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 4 . 1 1 . 1 4 ⓢ M a r i n ⎵ ⎵ ⓕ M a r c e l ⎵ ⎵ ⓑ 1 0 . 8 . 0 6 -ⓢ R o q u e s ⎵ ⎵ ⓕ E l o i ⎵ ⎵ ⓑ 1 1 . 1 0 . 0 4 -ⓢ G i r o s ⎵ ⎵ ⓕ P a u l ⎵ ⎵ ⓑ 3 0 . 1 0 . 1 0""" +ⓢ A m i c a l ⎵ ⎵ ⓕ E l o i ⎵ ⎵ ⓑ 1 1 . 1 0 . 0 4 +ⓢ B i r o s ⎵ ⎵ ⓕ M a e l ⎵ ⎵ ⓑ 3 0 . 1 0 . 
1 0""" # Transcriptions with worker version are in lowercase if transcription_entities_worker_version: -- GitLab