From 88720b7434da9c0d26e9a1781d47f1789ad06a2f Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Thu, 20 Jul 2023 11:07:33 +0200
Subject: [PATCH] Support joined entities and entity separators

---
 dan/datasets/extract/__init__.py |  12 ++-
 dan/datasets/extract/extract.py  |  65 +++++++-----
 tests/test_extract.py            | 177 +++++++++++++++++++++++++++----
 3 files changed, 203 insertions(+), 51 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 93c2fe2f..10e3b45c 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -87,16 +87,18 @@ def add_extract_parser(subcommands) -> None:
         action="store_true",
         help="Extract text with their entities.",
     )
-    parser.add_argument(
-        "--only-entities",
-        action="store_true",
-        help="Remove all text that does not belong to the tokens.",
-    )
     parser.add_argument(
         "--allow-unknown-entities",
         action="store_true",
         help="Ignore entities that do not appear in the list of tokens.",
     )
+    parser.add_argument(
+        "--entity-separators",
+        type=str,
+        nargs="+",
+        help="Removes all text that does not appear in an entity or in the list of given characters. Do not give any arguments for keeping the whole text.",
+        required=False,
+    )
     parser.add_argument(
         "--tokens",
         type=pathlib.Path,
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 71160fda..55b120dc 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -2,6 +2,7 @@
 
 import random
 from collections import defaultdict
+from itertools import pairwise
 from pathlib import Path
 from typing import List, Optional, Union
 from uuid import UUID
@@ -36,8 +37,6 @@ IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
 
-EMPTY_CHARS = [" ", "\n", "\t", "\r"]
-
 
 class ArkindexExtractor:
     """
@@ -52,7 +51,7 @@ class ArkindexExtractor:
         output: Path = None,
         load_entities: bool = False,
         allow_unknown_entities: bool = False,
-        only_entities: bool = False,
+        entity_separators: list = [],
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
@@ -67,7 +66,7 @@ class ArkindexExtractor:
         self.output = output
         self.load_entities = load_entities
         self.allow_unknown_entities = allow_unknown_entities
-        self.only_entities = only_entities
+        self.entity_separators = entity_separators
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -110,20 +109,29 @@ class ArkindexExtractor:
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
+
+        def keep_char(char):
+            return keep_all_text or char in self.entity_separators
+
         text, text_offset = "", 0
+        # Keep all text by default if no separator was given
+        keep_all_text = not self.entity_separators
         for entity in entities:
             # Text before entity
-            if not self.only_entities:
-                text += full_text[text_offset : entity.offset]
+            text += "".join(filter(keep_char, full_text[text_offset : entity.offset]))
 
             entity_type: EntityType = self.tokens.get(entity.type)
+            # Unknown entities are not allowed
             if not entity_type and not self.allow_unknown_entities:
                 raise UnknownLabelError(entity.type)
-            if entity_type and not entity_type.end and not self.only_entities:
+            # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends.
+            if entity_type and not entity_type.end and keep_all_text:
                 raise NoEndTokenError(entity.type)
 
-            # Entity text (optionally with tokens)
-            if entity_type or not self.only_entities:
+            # Entity text:
+            # - with tokens if there is an entity_type,
+            # - without tokens if there is no entity_type but we want to keep the whole text.
+            if entity_type or keep_all_text:
                 text += insert_token(
                     full_text,
                     entity_type,
@@ -132,20 +140,29 @@ class ArkindexExtractor:
                 )
             text_offset = entity.offset + entity.length
 
-            # Keep the exact separator after entity (it can be a space, a line break etc.)
-            if self.only_entities:
-                separator = next(
-                    (char for char in full_text[text_offset:] if char in EMPTY_CHARS),
-                    "",
-                )
-                text += separator
+        # Remaining text after the last entity
+        text += "".join(filter(keep_char, full_text[text_offset:]))
+
+        if keep_all_text:
+            return text
+
+        # Add some clean up to avoid several separators between entities.
+        text, full_text = "", text
+        for char, next_char in pairwise(full_text):
+            # If several separators follow each other, keep only the last one.
+            if (
+                char not in self.entity_separators
+                or next_char not in self.entity_separators
+            ):
+                text += char
 
-        # Remaining text
-        if not self.only_entities:
-            text += full_text[text_offset:]
+        # Remaining char
+        remaining_char = full_text[-1]
+        if remaining_char not in self.entity_separators:
+            text += remaining_char
 
-        # Remove extra spaces
-        return text.strip("".join(EMPTY_CHARS))
+        # Remove separators at the beginning and end of text
+        return text.strip("".join(self.entity_separators))
 
     def extract_transcription(self, element: Element):
         """
@@ -160,7 +177,7 @@ class ArkindexExtractor:
 
         transcription = random.choice(transcriptions)
 
-        if self.load_entities or self.only_entities:
+        if self.load_entities:
             entities = get_transcription_entities(
                 transcription.id, self.entity_worker_version
             )
@@ -249,7 +266,7 @@ def run(
     output: Path,
     load_entities: bool,
     allow_unknown_entities: bool,
-    only_entities: bool,
+    entity_separators: list,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -302,7 +319,7 @@ def run(
         output=output,
         load_entities=load_entities,
         allow_unknown_entities=allow_unknown_entities,
-        only_entities=only_entities,
+        entity_separators=entity_separators,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 68b08a1d..e6f999ee 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -4,12 +4,24 @@ from typing import NamedTuple
 
 import pytest
 
+from dan.datasets.extract.exceptions import NoEndTokenError, UnknownLabelError
 from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
 # NamedTuple to mock actual database result
 Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
 
+TOKENS = {
+    "P": EntityType(start="â“Ÿ", end="â“…"),
+    "D": EntityType(start="â““", end="â’¹"),
+    "N": EntityType(start="ⓝ", end="Ⓝ"),
+    "I": EntityType(start="ⓘ", end="Ⓘ"),
+}
+
+
+def filter_tokens(keys):
+    return {key: value for key, value in TOKENS.items() if key in keys}
+
 
 @pytest.mark.parametrize(
     "text,offset,length,expected",
@@ -24,44 +36,165 @@ def test_insert_token(text, offset, length, expected):
     )
 
 
+def test_reconstruct_text_unknown_label_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = TOKENS
+    with pytest.raises(
+        UnknownLabelError, match="Label `X` is missing in the NER configuration."
+    ):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
+
+
+def test_reconstruct_text_no_end_token_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = {
+        "X": EntityType(start="ⓧ"),
+    }
+    with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
+
+
 @pytest.mark.parametrize(
-    "only_entities,expected",
+    "entity_separators,tokens,expected",
     (
-        (False, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ x Michou"),
-        (True, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # Whole text...
+        # ... + All tokens
+        ([], TOKENS, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # ... + 1rst and 2nd tokens
+        ([], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nMichou"),
+        # ... + 1rst and 3rd tokens
+        ([], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ x 16 janvier 1611\nⓝMichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([], filter_tokens(["D", "N"]), "n°1 x ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # Only entities...
+        # ... + All tokens
+        ([" ", "\n"], TOKENS, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
+        # ... + 1rst and 2nd tokens
+        ([" ", "\n"], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # ... + 1rst and 3rd tokens
+        ([" ", "\n"], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ\nⓝMichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([" ", "\n"], filter_tokens(["D", "N"]), "ⓓ16 janvier 1611Ⓓ\nⓝMichouⓃ"),
     ),
 )
-def test_reconstruct_text(only_entities, expected):
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
     arkindex_extractor = ArkindexExtractor(
-        allow_unknown_entities=True, only_entities=only_entities
+        allow_unknown_entities=True, entity_separators=entity_separators
+    )
+    arkindex_extractor.tokens = tokens
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "n°1 x 16 janvier 1611\nMichou" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=3,
+                type="P",
+                value="n°1",
+            ),
+            Entity(
+                offset=6 + len(text_before),
+                length=15,
+                type="D",
+                value="16 janvier 1611",
+            ),
+            Entity(
+                offset=22 + len(text_before),
+                length=6,
+                type="N",
+                value="Michou",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + expected
+        + (text_after if not entity_separators else "")
     )
+
+
+@pytest.mark.parametrize(
+    "entity_separators",
+    (
+        # Whole text
+        [],
+        # Only entities
+        [" ", "\n"],
+    ),
+)
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
+    arkindex_extractor.tokens = TOKENS
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "LouisXIV" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=5,
+                type="P",
+                value="Louis",
+            ),
+            Entity(
+                offset=5 + len(text_before),
+                length=3,
+                type="I",
+                value="XIV",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + "ⓟLouisⓅⓘXIVⒾ"
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_only_start_token(text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
     arkindex_extractor.tokens = {
-        "P": EntityType(start="â“Ÿ", end="â“…"),
-        "D": EntityType(start="â““", end="â’¹"),
+        "P": EntityType(start="â“Ÿ"),
+        "I": EntityType(start="ⓘ"),
     }
     assert (
         arkindex_extractor.reconstruct_text(
-            "n°1 x 16 janvier 1611 x Michou",
+            text_before + "LouisXIV" + text_after,
             [
                 Entity(
-                    offset=0,
-                    length=3,
+                    offset=0 + len(text_before),
+                    length=5,
                     type="P",
-                    value="n°1",
+                    value="Louis",
                 ),
                 Entity(
-                    offset=6,
-                    length=15,
-                    type="D",
-                    value="16 janvier 1611",
-                ),
-                Entity(
-                    offset=24,
-                    length=6,
-                    type="N",
-                    value="Michou",
+                    offset=5 + len(text_before),
+                    length=3,
+                    type="I",
+                    value="XIV",
                 ),
             ],
         )
-        == expected
+        == "ⓟLouisⓘXIV"
     )
-- 
GitLab