From 88720b7434da9c0d26e9a1781d47f1789ad06a2f Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Thu, 20 Jul 2023 11:07:33 +0200
Subject: [PATCH] Support joined entities and entity separators

---
 dan/datasets/extract/__init__.py |  12 ++-
 dan/datasets/extract/extract.py  |  65 +++++++-----
 tests/test_extract.py            | 177 +++++++++++++++++++++++++++----
 3 files changed, 203 insertions(+), 51 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 93c2fe2f..10e3b45c 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -87,16 +87,18 @@ def add_extract_parser(subcommands) -> None:
         action="store_true",
         help="Extract text with their entities.",
     )
-    parser.add_argument(
-        "--only-entities",
-        action="store_true",
-        help="Remove all text that does not belong to the tokens.",
-    )
     parser.add_argument(
         "--allow-unknown-entities",
         action="store_true",
         help="Ignore entities that do not appear in the list of tokens.",
     )
+    parser.add_argument(
+        "--entity-separators",
+        type=str,
+        nargs="+",
+        help="Remove all text that does not appear in an entity or in the list of given characters. Do not pass any argument to keep the whole text.",
+        required=False,
+    )
     parser.add_argument(
         "--tokens",
         type=pathlib.Path,
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 71160fda..55b120dc 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -2,6 +2,7 @@
 import random
 from collections import defaultdict
+from itertools import pairwise
 from pathlib import Path
 from typing import List, Optional, Union
 from uuid import UUID
 
@@ -36,8 +37,6 @@
 IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
-EMPTY_CHARS = [" ", "\n", "\t", "\r"]
-
 
 class ArkindexExtractor:
     """
@@ -52,7 +51,7 @@ class ArkindexExtractor:
         output: Path = None,
         load_entities: bool = False,
         allow_unknown_entities: bool = False,
-        only_entities: bool = False,
+        entity_separators: list = [],
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
@@ -67,7 +66,7 @@ class ArkindexExtractor:
         self.output = output
         self.load_entities = load_entities
         self.allow_unknown_entities = allow_unknown_entities
-        self.only_entities = only_entities
+        self.entity_separators = entity_separators
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -110,20 +109,29 @@ class ArkindexExtractor:
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
""" + + def keep_char(char): + return keep_all_text or char in self.entity_separators + text, text_offset = "", 0 + # Keep all text by default if no separator was given + keep_all_text = not self.entity_separators for entity in entities: # Text before entity - if not self.only_entities: - text += full_text[text_offset : entity.offset] + text += "".join(filter(keep_char, full_text[text_offset : entity.offset])) entity_type: EntityType = self.tokens.get(entity.type) + # Unknown entities are not allowed if not entity_type and not self.allow_unknown_entities: raise UnknownLabelError(entity.type) - if entity_type and not entity_type.end and not self.only_entities: + # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends. + if entity_type and not entity_type.end and keep_all_text: raise NoEndTokenError(entity.type) - # Entity text (optionally with tokens) - if entity_type or not self.only_entities: + # Entity text: + # - with tokens if there is an entity_type, + # - without tokens if there is no entity_type but we want to keep the whole text. + if entity_type or keep_all_text: text += insert_token( full_text, entity_type, @@ -132,20 +140,29 @@ class ArkindexExtractor: ) text_offset = entity.offset + entity.length - # Keep the exact separator after entity (it can be a space, a line break etc.) - if self.only_entities: - separator = next( - (char for char in full_text[text_offset:] if char in EMPTY_CHARS), - "", - ) - text += separator + # Remaining text after the last entity + text += "".join(filter(keep_char, full_text[text_offset:])) + + if keep_all_text: + return text + + # Add some clean up to avoid several separators between entities. + text, full_text = "", text + for char, next_char in pairwise(full_text): + # If several separators follow each other, keep only the last one. 
+            if (
+                char not in self.entity_separators
+                or next_char not in self.entity_separators
+            ):
+                text += char
 
-        # Remaining text
-        if not self.only_entities:
-            text += full_text[text_offset:]
+        # Handle the last character, which pairwise never yields as `char`
+        remaining_char = full_text[-1]
+        if remaining_char not in self.entity_separators:
+            text += remaining_char
 
-        # Remove extra spaces
-        return text.strip("".join(EMPTY_CHARS))
+        # Remove separators at the beginning and end of the text
+        return text.strip("".join(self.entity_separators))
 
     def extract_transcription(self, element: Element):
         """
@@ -160,7 +177,7 @@ class ArkindexExtractor:
 
         transcription = random.choice(transcriptions)
 
-        if self.load_entities or self.only_entities:
+        if self.load_entities:
             entities = get_transcription_entities(
                 transcription.id, self.entity_worker_version
             )
@@ -249,7 +266,7 @@ def run(
     output: Path,
     load_entities: bool,
     allow_unknown_entities: bool,
-    only_entities: bool,
+    entity_separators: list,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -302,7 +319,7 @@ def run(
         output=output,
         load_entities=load_entities,
         allow_unknown_entities=allow_unknown_entities,
-        only_entities=only_entities,
+        entity_separators=entity_separators,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 68b08a1d..e6f999ee 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -4,12 +4,24 @@
 from typing import NamedTuple
 
 import pytest
 
+from dan.datasets.extract.exceptions import NoEndTokenError, UnknownLabelError
 from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
 # NamedTuple to mock actual database result
 Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
 
+TOKENS = {
+    "P": EntityType(start="â“Ÿ", end="â“…"),
+    "D": EntityType(start="â““", end="â’¹"),
+    "N": EntityType(start="â“", end="â“"),
+    "I": EntityType(start="ⓘ", end="â’¾"),
+}
+
+
+def filter_tokens(keys):
+    return {key: value for key, value in TOKENS.items() if key in keys}
+
 
 @pytest.mark.parametrize(
     "text,offset,length,expected",
@@ -24,44 +36,165 @@ def test_insert_token(text, offset, length, expected):
     )
 
 
+def test_reconstruct_text_unknown_label_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = TOKENS
+    with pytest.raises(
+        UnknownLabelError, match="Label `X` is missing in the NER configuration."
+    ):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
+
+
+def test_reconstruct_text_no_end_token_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = {
+        "X": EntityType(start="ⓧ"),
+    }
+    with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
+            [
+                Entity(
+                    offset=0,
+                    length=3,
+                    type="X",
+                    value="n°1",
+                ),
+            ],
+        )
+
+
 @pytest.mark.parametrize(
-    "only_entities,expected",
+    "entity_separators,tokens,expected",
     (
-        (False, "â“Ÿn°1â“… x â““16 janvier 1611â’¹ x Michou"),
-        (True, "â“Ÿn°1â“… â““16 janvier 1611â’¹"),
+        # Whole text...
+        # ... + All tokens
+        ([], TOKENS, "â“Ÿn°1â“… x â““16 janvier 1611â’¹\nâ“MichouⓃ"),
+        # ... + 1st and 2nd tokens
+        ([], filter_tokens(["P", "D"]), "â“Ÿn°1â“… x â““16 janvier 1611â’¹\nMichou"),
+        # ... + 1st and 3rd tokens
+        ([], filter_tokens(["P", "N"]), "â“Ÿn°1â“… x 16 janvier 1611\nâ“MichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([], filter_tokens(["D", "N"]), "n°1 x â““16 janvier 1611â’¹\nâ“MichouⓃ"),
+        # Only entities...
+        # ... + All tokens
+        ([" ", "\n"], TOKENS, "â“Ÿn°1â“… â““16 janvier 1611â’¹\nâ“MichouⓃ"),
+        # ... + 1st and 2nd tokens
+        ([" ", "\n"], filter_tokens(["P", "D"]), "â“Ÿn°1â“… â““16 janvier 1611â’¹"),
+        # ... + 1st and 3rd tokens
+        ([" ", "\n"], filter_tokens(["P", "N"]), "â“Ÿn°1â“…\nâ“MichouⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([" ", "\n"], filter_tokens(["D", "N"]), "â““16 janvier 1611â’¹\nâ“MichouⓃ"),
     ),
 )
-def test_reconstruct_text(only_entities, expected):
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
     arkindex_extractor = ArkindexExtractor(
-        allow_unknown_entities=True, only_entities=only_entities
+        allow_unknown_entities=True, entity_separators=entity_separators
+    )
+    arkindex_extractor.tokens = tokens
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "n°1 x 16 janvier 1611\nMichou" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=3,
+                type="P",
+                value="n°1",
+            ),
+            Entity(
+                offset=6 + len(text_before),
+                length=15,
+                type="D",
+                value="16 janvier 1611",
+            ),
+            Entity(
+                offset=22 + len(text_before),
+                length=6,
+                type="N",
+                value="Michou",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + expected
+        + (text_after if not entity_separators else "")
     )
+
+
+@pytest.mark.parametrize(
+    "entity_separators",
+    (
+        # Whole text
+        [],
+        # Only entities
+        [" ", "\n"],
+    ),
+)
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
+    arkindex_extractor.tokens = TOKENS
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "LouisXIV" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=5,
+                type="P",
+                value="Louis",
+            ),
+            Entity(
+                offset=5 + len(text_before),
+                length=3,
+                type="I",
+                value="XIV",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + "â“ŸLouisⓅⓘXIVâ’¾"
+        + (text_after if not entity_separators else "")
    )
+
+
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_only_start_token(text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
     arkindex_extractor.tokens = {
-        "P": EntityType(start="â“Ÿ", end="â“…"),
-        "D": EntityType(start="â““", end="â’¹"),
+        "P": EntityType(start="â“Ÿ"),
+        "I": EntityType(start="ⓘ"),
     }
     assert (
         arkindex_extractor.reconstruct_text(
-            "n°1 x 16 janvier 1611 x Michou",
+            text_before + "LouisXIV" + text_after,
             [
                 Entity(
-                    offset=0,
-                    length=3,
+                    offset=0 + len(text_before),
+                    length=5,
                     type="P",
-                    value="n°1",
+                    value="Louis",
                 ),
                 Entity(
-                    offset=6,
-                    length=15,
-                    type="D",
-                    value="16 janvier 1611",
-                ),
-                Entity(
-                    offset=24,
-                    length=6,
-                    type="N",
-                    value="Michou",
+                    offset=5 + len(text_before),
+                    length=3,
+                    type="I",
+                    value="XIV",
                 ),
             ],
         )
-        == expected
+        == "â“ŸLouisⓘXIV"
     )
-- 
GitLab
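
For reference, a minimal standalone sketch of the separator clean-up that reconstruct_text performs when entity separators are given. The helper name collapse_separators is hypothetical (it is not part of the commit); the loop mirrors the pairwise logic added above, and itertools.pairwise requires Python 3.10 or later.

from itertools import pairwise


def collapse_separators(text: str, separators: list) -> str:
    # Keep a character unless it is a separator directly followed by
    # another separator: each run of separators collapses to its last one.
    out = ""
    for char, next_char in pairwise(text):
        if char not in separators or next_char not in separators:
            out += char
    # pairwise stops one character early, so handle the last one here.
    if text and text[-1] not in separators:
        out += text[-1]
    # Finally drop leading/trailing separators, as reconstruct_text does.
    return out.strip("".join(separators))


# The "  \n" run between the two tagged entities collapses to a single "\n":
assert (
    collapse_separators("â“Ÿn°1â“…  \nâ““16 janvier 1611â’¹", [" ", "\n"])
    == "â“Ÿn°1â“…\nâ““16 janvier 1611â’¹"
)

Keeping the last separator of each run (rather than the first) is what makes the test expectations above end entity groups with "\n" instead of a space when both occur between two entities.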