From e0cb32b0f9a049501a0a21d7b4b6b5397d01363b Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Wed, 4 Oct 2023 09:51:31 +0000
Subject: [PATCH] Charset should only include training characters

---
 dan/datasets/extract/__init__.py              |  6 ++
 dan/datasets/extract/exceptions.py            |  9 +++
 dan/datasets/extract/extract.py               | 32 +++++++++-
 docs/usage/datasets/extract.md                |  1 +
 .../elements/test-page_1-line_1.json          |  2 +-
 .../elements/test-page_1-line_3.json          |  2 +-
 .../elements/test-page_2-line_1.json          |  2 +-
 .../elements/train-page_2-line_2.json         |  2 +-
 .../elements/train-page_2-line_3.json         |  4 +-
 .../elements/val-page_1-line_1.json           |  2 +-
 tests/test_extract.py                         | 59 ++++++++++++++-----
 11 files changed, 96 insertions(+), 25 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 6ea479c6..8a4def01 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -109,6 +109,12 @@ def add_extract_parser(subcommands) -> None:
         required=False,
         default=list(map(validate_char, ("\n", " "))),
     )
+    parser.add_argument(
+        "--unknown-token",
+        type=str,
+        help="Token to use to replace character in the validation/test sets that is not included in the training set.",
+        default="⁇",
+    )
     parser.add_argument(
         "--tokens",
         type=pathlib.Path,
diff --git a/dan/datasets/extract/exceptions.py b/dan/datasets/extract/exceptions.py
index 2155a6ca..74e1b332 100644
--- a/dan/datasets/extract/exceptions.py
+++ b/dan/datasets/extract/exceptions.py
@@ -44,6 +44,15 @@ class NoTranscriptionError(ElementProcessingError):
         return f"No transcriptions found on element ({self.element_id}) with this config. Skipping."
 
 
+class UnknownTokenInText(ElementProcessingError):
+    """
+    Raised when the unknown token is found in a transcription text
+    """
+
+    def __str__(self) -> str:
+        return f"Unknown token found in the transcription text of element ({self.element_id})"
+
+
 class NoEndTokenError(ProcessingError):
     """
     Raised when the specified label has no end token and there is potentially additional text around the labels
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 440986ec..e6847e0c 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -27,6 +27,7 @@ from dan.datasets.extract.exceptions import (
     NoEndTokenError,
     NoTranscriptionError,
     ProcessingError,
+    UnknownTokenInText,
 )
 from dan.datasets.extract.utils import (
     download_image,
@@ -44,7 +45,8 @@ from line_image_extractor.image_utils import (
 
 IMAGES_DIR = "images"  # Subpath to the images directory.
 
-SPLIT_NAMES = ["train", "val", "test"]
+TRAIN_NAME = "train"
+SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
 IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
 # IIIF 2.0 uses `full`
 IIIF_FULL_SIZE = "full"
@@ -64,6 +66,7 @@ class ArkindexExtractor:
         parent_element_type: str = None,
         output: Path = None,
         entity_separators: List[str] = ["\n", " "],
+        unknown_token: str = "⁇",
         tokens: Path = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
         entity_worker_version: Optional[Union[str, bool]] = None,
@@ -77,6 +80,7 @@ class ArkindexExtractor:
         self.parent_element_type = parent_element_type
         self.output = output
         self.entity_separators = entity_separators
+        self.unknown_token = unknown_token
         self.tokens = parse_tokens(tokens) if tokens else None
         self.transcription_worker_version = transcription_worker_version
         self.entity_worker_version = entity_worker_version
@@ -245,9 +249,20 @@ class ArkindexExtractor:
                 split=split, path=str(destination), url=download_url, exc=e
             )
 
-    def format_text(self, text: str):
+    def format_text(self, text: str, charset: Optional[set] = None):
         if not self.keep_spaces:
             text = remove_spaces(text)
+
+        # Replace unknown characters by the unknown token
+        if charset is not None:
+            unknown_charset = set(text) - charset
+            text = text.translate(
+                {
+                    ord(unknown_char): self.unknown_token
+                    for unknown_char in unknown_charset
+                }
+            )
+
         return text.strip()
 
     def process_element(
@@ -261,6 +276,9 @@ class ArkindexExtractor:
         """
         text = self.extract_transcription(element)
 
+        if self.unknown_token in text:
+            raise UnknownTokenInText(element_id=element.id)
+
         image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
             self.image_extension
         )
@@ -276,7 +294,13 @@ class ArkindexExtractor:
                 }
             )
 
-        self.data[split][str(image_path)] = self.format_text(text)
+        text = self.format_text(
+            text,
+            # Do not replace unknown characters in train split
+            charset=self.charset if split != TRAIN_NAME else None,
+        )
+
+        self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
 
     def process_parent(
@@ -390,6 +414,7 @@ def run(
     parent_element_type: str,
     output: Path,
     entity_separators: List[str],
+    unknown_token: str,
     tokens: Path,
     train_folder: UUID,
     val_folder: UUID,
@@ -416,6 +441,7 @@ def run(
         parent_element_type=parent_element_type,
         output=output,
         entity_separators=entity_separators,
+        unknown_token=unknown_token,
         tokens=tokens,
         transcription_worker_version=transcription_worker_version,
         entity_worker_version=entity_worker_version,
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index ae0c9059..a7715d59 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -17,6 +17,7 @@ If an image download fails for whatever reason, it won't appear in the transcrip
 | `--parent-element-type`          | Type of the parent element containing the data.                                                                                                                                                                                      | `str`           | `page`                                             |
 | `--output`                       | Folder where the data will be generated.                                                                                                                                                                                             | `Path`          |                                                    |
 | `--entity-separators`            | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. | `str`           | `["\n", " "]` (see [dedicated section](#examples)) |
+| `--unknown-token`                | Token to use to replace character in the validation/test sets that is not included in the training set.                                                                                                                              | `str`           | `⁇`                                                |
 | `--tokens`                       | Mapping between starting tokens and end tokens to extract text with their entities.                                                                                                                                                  | `Path`          |                                                    |
 | `--train-folder`                 | ID of the training folder to import from Arkindex.                                                                                                                                                                                   | `uuid`          |                                                    |
 | `--val-folder`                   | ID of the validation folder to import from Arkindex.                                                                                                                                                                                 | `uuid`          |                                                    |
diff --git a/tests/data/extraction/elements/test-page_1-line_1.json b/tests/data/extraction/elements/test-page_1-line_1.json
index ef63f940..716bf50a 100644
--- a/tests/data/extraction/elements/test-page_1-line_1.json
+++ b/tests/data/extraction/elements/test-page_1-line_1.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "Louis",
+            "name": "Bouis",
             "type": "firstname",
             "offset": 7
         },
diff --git a/tests/data/extraction/elements/test-page_1-line_3.json b/tests/data/extraction/elements/test-page_1-line_3.json
index f196d3cf..78fbe786 100644
--- a/tests/data/extraction/elements/test-page_1-line_3.json
+++ b/tests/data/extraction/elements/test-page_1-line_3.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "François",
+            "name": "Français",
             "type": "firstname",
             "offset": 7
         },
diff --git a/tests/data/extraction/elements/test-page_2-line_1.json b/tests/data/extraction/elements/test-page_2-line_1.json
index a9b2498f..9d5c131b 100644
--- a/tests/data/extraction/elements/test-page_2-line_1.json
+++ b/tests/data/extraction/elements/test-page_2-line_1.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "Louis",
+            "name": "Bouis",
             "type": "firstname",
             "offset": 8
         },
diff --git a/tests/data/extraction/elements/train-page_2-line_2.json b/tests/data/extraction/elements/train-page_2-line_2.json
index c54e23b6..bc6829da 100644
--- a/tests/data/extraction/elements/train-page_2-line_2.json
+++ b/tests/data/extraction/elements/train-page_2-line_2.json
@@ -9,7 +9,7 @@
     ],
     "transcription_entities": [
         {
-            "name": "Roques",
+            "name": "Amical",
             "type": "surname",
             "offset": 0
         },
diff --git a/tests/data/extraction/elements/train-page_2-line_3.json b/tests/data/extraction/elements/train-page_2-line_3.json
index 200e3b3d..90432163 100644
--- a/tests/data/extraction/elements/train-page_2-line_3.json
+++ b/tests/data/extraction/elements/train-page_2-line_3.json
@@ -9,12 +9,12 @@
     ],
     "transcription_entities": [
         {
-            "name": "Giros",
+            "name": "Biros",
             "type": "surname",
             "offset": 0
         },
         {
-            "name": "Paul",
+            "name": "Mael",
             "type": "firstname",
             "offset": 6
         },
diff --git a/tests/data/extraction/elements/val-page_1-line_1.json b/tests/data/extraction/elements/val-page_1-line_1.json
index 371f730d..d2556444 100644
--- a/tests/data/extraction/elements/val-page_1-line_1.json
+++ b/tests/data/extraction/elements/val-page_1-line_1.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "Louis",
+            "name": "Bouis",
             "type": "firstname",
             "offset": 7
         },
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 24ec2905..52981000 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -11,7 +11,8 @@ from unittest.mock import patch
 import pytest
 from PIL import Image, ImageChops
 
-from dan.datasets.extract.exceptions import NoEndTokenError
+from arkindex_export import Element, Transcription
+from dan.datasets.extract.exceptions import NoEndTokenError, UnknownTokenInText
 from dan.datasets.extract.extract import IIIF_FULL_SIZE, ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token, remove_spaces
 from dan.utils import parse_tokens
@@ -259,6 +260,34 @@ def test_reconstruct_text_only_start_token(joined, text_before, text_after):
     )
 
 
+def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
+    output = tmp_path / "extraction"
+    arkindex_extractor = ArkindexExtractor(output=output)
+
+    # Create an element with an invalid transcription
+    element = Element.create(
+        id="element_id",
+        name="1",
+        type="page",
+        polygon="[]",
+        created=0.0,
+        updated=0.0,
+    )
+    Transcription.create(
+        id="transcription_id",
+        text="Is this text valid⁇",
+        element=element,
+    )
+
+    with pytest.raises(
+        UnknownTokenInText,
+        match=re.escape(
+            "Unknown token found in the transcription text of element (element_id)"
+        ),
+    ):
+        arkindex_extractor.process_element(element, "val")
+
+
 @pytest.mark.parametrize("load_entities", (True, False))
 @pytest.mark.parametrize("keep_spaces", (True, False))
 # Transcription and entities have the same worker version
@@ -343,12 +372,12 @@ def test_extract(
     # Check "labels.json"
     expected_labels = {
         "test": {
-            str(TEST_DIR / "test-page_1-line_1.jpg"): "â“¢Coupez  â“•Louis  â“‘7.12.14",
-            str(TEST_DIR / "test-page_1-line_2.jpg"): "â“¢Poutrain  â“•Adolphe  â“‘9.4.13",
-            str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢGabale  ⓕFrançois  ⓑ26.3.11",
-            str(TEST_DIR / "test-page_2-line_1.jpg"): "â“¢Durosoy  â“•Louis  â“‘22-4-18",
-            str(TEST_DIR / "test-page_2-line_2.jpg"): "â“¢Colaiani  â“•Angels  â“‘28.11.17",
-            str(TEST_DIR / "test-page_2-line_3.jpg"): "â“¢Renouard  â“•Maurice  â“‘25.7.04",
+            str(TEST_DIR / "test-page_1-line_1.jpg"): "ⓢCou⁇e⁇  ⓕBouis  ⓑ⁇.12.14",
+            str(TEST_DIR / "test-page_1-line_2.jpg"): "ⓢ⁇outrain  ⓕA⁇ol⁇⁇e  ⓑ9.4.13",
+            str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢ⁇abale  ⓕ⁇ran⁇ais  ⓑ26.3.11",
+            str(TEST_DIR / "test-page_2-line_1.jpg"): "ⓢ⁇urosoy  ⓕBouis  ⓑ22⁇4⁇18",
+            str(TEST_DIR / "test-page_2-line_2.jpg"): "ⓢColaiani  ⓕAn⁇els  ⓑ28.11.1⁇",
+            str(TEST_DIR / "test-page_2-line_3.jpg"): "ⓢRenouar⁇  ⓕMaurice  ⓑ2⁇.⁇.04",
         },
         "train": {
             str(TRAIN_DIR / "train-page_1-line_1.jpg"): "â“¢Caillet  â“•Maurice  â“‘28.9.06",
@@ -356,13 +385,13 @@ def test_extract(
             str(TRAIN_DIR / "train-page_1-line_3.jpg"): "â“¢Bareyre  â“•Jean  â“‘28.3.11",
             str(TRAIN_DIR / "train-page_1-line_4.jpg"): "â“¢Roussy  â“•Jean  â“‘4.11.14",
             str(TRAIN_DIR / "train-page_2-line_1.jpg"): "â“¢Marin  â“•Marcel  â“‘10.8.06",
-            str(TRAIN_DIR / "train-page_2-line_2.jpg"): "â“¢Roques  â“•Eloi  â“‘11.10.04",
-            str(TRAIN_DIR / "train-page_2-line_3.jpg"): "â“¢Giros  â“•Paul  â“‘30.10.10",
+            str(TRAIN_DIR / "train-page_2-line_2.jpg"): "â“¢Amical  â“•Eloi  â“‘11.10.04",
+            str(TRAIN_DIR / "train-page_2-line_3.jpg"): "â“¢Biros  â“•Mael  â“‘30.10.10",
         },
         "val": {
-            str(VAL_DIR / "val-page_1-line_1.jpg"): "â“¢Monard  â“•Louis  â“‘29-7-04",
-            str(VAL_DIR / "val-page_1-line_2.jpg"): "â“¢Astier  â“•Arthur  â“‘11-2-13",
-            str(VAL_DIR / "val-page_1-line_3.jpg"): "â“¢De Vlieger  â“•Jules  â“‘21-11-11",
+            str(VAL_DIR / "val-page_1-line_1.jpg"): "ⓢMonar⁇  ⓕBouis  ⓑ29⁇⁇⁇04",
+            str(VAL_DIR / "val-page_1-line_2.jpg"): "ⓢAstier  ⓕArt⁇ur  ⓑ11⁇2⁇13",
+            str(VAL_DIR / "val-page_1-line_3.jpg"): "ⓢ⁇e ⁇lie⁇er  ⓕJules  ⓑ21⁇11⁇11",
         },
     }
 
@@ -393,12 +422,12 @@ def test_extract(
 
     # Check "charset.pkl"
     expected_charset = set()
-    for labels in expected_labels.values():
-        for label in labels.values():
-            expected_charset.update(set(label))
+    for label in expected_labels["train"].values():
+        expected_charset.update(set(label))
 
     if load_entities:
         expected_charset.update(tokens)
+    expected_charset.add("⁇")
     assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
 
     # Check cropped images
-- 
GitLab