From f0ec449d0bb7a10abfb57c6ffc632593fde9541e Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Wed, 19 Jul 2023 17:04:42 +0200
Subject: [PATCH] Allow keeping text around entities

---
 dan/datasets/extract/__init__.py   |  9 ++++-
 dan/datasets/extract/exceptions.py | 15 +++++++
 dan/datasets/extract/extract.py    | 65 ++++++++++++++++++++----------
 dan/datasets/extract/utils.py      |  4 +-
 tests/test_extract.py              | 57 +++++++++++---------------
 5 files changed, 92 insertions(+), 58 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 33be3c0e..232f4458 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -83,7 +83,14 @@ def add_extract_parser(subcommands) -> None:
 
     # Optional arguments.
     parser.add_argument(
-        "--load-entities", action="store_true", help="Extract text with their entities."
+        "--load-entities",
+        action="store_true",
+        help="Extract text with its entities.",
+    )
+    parser.add_argument(
+        "--only-entities",
+        action="store_true",
+        help="Extract text with its entities and remove all text that does not belong to tagged entities.",
     )
     parser.add_argument(
         "--allow-unknown-entities",
diff --git a/dan/datasets/extract/exceptions.py b/dan/datasets/extract/exceptions.py
index 22c47a6c..da8fba65 100644
--- a/dan/datasets/extract/exceptions.py
+++ b/dan/datasets/extract/exceptions.py
@@ -62,3 +62,18 @@ class UnknownLabelError(ProcessingError):
 
     def __str__(self) -> str:
         return f"Label `{self.label}` is missing in the NER configuration."
+
+
+class NoEndTokenError(ProcessingError):
+    """
+    Raised when the specified label has no end token and there may be additional text around the labels.
+    """
+
+    label: str
+
+    def __init__(self, label: str, *args: object) -> None:
+        super().__init__(*args)
+        self.label = label
+
+    def __str__(self) -> str:
+        return f"Label `{self.label}` has no end token."
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index a0cf5574..ea8043d5 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -18,11 +18,13 @@ from dan.datasets.extract.db import (
     get_transcriptions,
 )
 from dan.datasets.extract.exceptions import (
+    NoEndTokenError,
     NoTranscriptionError,
     ProcessingError,
     UnknownLabelError,
 )
 from dan.datasets.extract.utils import (
+    EntityType,
     Subset,
     download_image,
     insert_token,
@@ -35,6 +37,8 @@
 IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
+EMPTY_CHARS = [" ", "\n", "\t", "\r"]
+
 
 class ArkindexExtractor:
     """
@@ -49,6 +53,7 @@ class ArkindexExtractor:
         output: Path = None,
         load_entities: bool = False,
         allow_unknown_entities: bool = False,
+        only_entities: bool = False,
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: str = None,
@@ -63,6 +68,7 @@
         self.output = output
         self.load_entities = load_entities
         self.allow_unknown_entities = allow_unknown_entities
+        self.only_entities = only_entities
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -101,31 +107,46 @@
     def get_random_split(self):
         return next(self._assign_random_split())
 
-    def reconstruct_text(self, text: str, entities: List[Entity]):
+    def reconstruct_text(self, full_text: str, entities: List[Entity]):
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
 
+        text, text_offset = "", 0
+        for entity in entities:
+            # Text before entity
+            if not self.only_entities:
+                text += full_text[text_offset : entity.offset]
-        # Filter entities
-        for entity in entities.copy():
-            # Tokens known for this entity
-            if entity.type in self.tokens:
-                continue
-            # Tokens unknown for this entity
-            if not self.allow_unknown_entities:
+            entity_type: EntityType = self.tokens.get(entity.type)
+            if not entity_type and not self.allow_unknown_entities:
                 raise UnknownLabelError(entity.type)
-            entities.remove(entity)
-
-        # Only keep text of the filtered entities
-        return " ".join(
-            insert_token(
-                text,
-                entity_type=self.tokens[entity.type],
-                offset=entity.offset,
-                length=entity.length,
-            )
-            for entity in entities
-        )
+            if entity_type and not entity_type.end and not self.only_entities:
+                raise NoEndTokenError(entity.type)
+
+            # Entity text (optionally with tokens)
+            if entity_type or not self.only_entities:
+                text += insert_token(
+                    full_text,
+                    entity_type,
+                    offset=entity.offset,
+                    length=entity.length,
+                )
+            text_offset = entity.offset + entity.length
+
+            # Keep the exact separator after the entity (it can be a space, a line break, etc.)
+            if self.only_entities:
+                separator = next(
+                    (char for char in full_text[text_offset:] if char in EMPTY_CHARS),
+                    "",
+                )
+                text += separator
+
+        # Remaining text
+        if not self.only_entities:
+            text += full_text[text_offset:]
+
+        # Strip leading/trailing whitespace
+        return text.strip("".join(EMPTY_CHARS))
 
     def extract_transcription(self, element: Element):
         """
@@ -140,7 +161,7 @@
 
         transcription = random.choice(transcriptions)
 
-        if self.load_entities:
+        if self.load_entities or self.only_entities:
             entities = get_transcription_entities(
                 transcription.id, self.entity_worker_version
             )
@@ -229,6 +250,7 @@ def run(
     output: Path,
     load_entities: bool,
    allow_unknown_entities: bool,
+    only_entities: bool,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -281,6 +303,7 @@ def run(
         output=output,
         load_entities=load_entities,
         allow_unknown_entities=allow_unknown_entities,
+        only_entities=only_entities,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 3a125337..3194cccd 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -80,11 +80,11 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
     """
     return (
         # Starting token
-        entity_type.start
+        (entity_type.start if entity_type else "")
         # Entity
         + text[offset : offset + length]
         # End token
-        + entity_type.end
+        + (entity_type.end if entity_type else "")
     )
 
 
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 9ceed8dc..d299ea0b 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -20,14 +20,23 @@ def test_insert_token(text, offset, length, expected):
 
 
 @pytest.mark.parametrize(
-    "tokens,text,entities,expected",
+    "only_entities,expected",
     (
-        (
-            {
-                "P": EntityType(start="ⓟ", end="Ⓟ"),
-                "D": EntityType(start="ⓓ", end="Ⓓ"),
-            },
-            "n°1 x 16 janvier 1611",
+        (False, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ x Michou"),
+        (True, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+    ),
+)
+def test_reconstruct_text(only_entities, expected):
+    arkindex_extractor = ArkindexExtractor(
+        allow_unknown_entities=True, only_entities=only_entities
+    )
+    arkindex_extractor.tokens = {
+        "P": EntityType(start="ⓟ", end="Ⓟ"),
+        "D": EntityType(start="ⓓ", end="Ⓓ"),
+    }
+    assert (
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611 x Michou",
             [
                 Entity(
                     offset=0,
@@ -41,33 +50,13 @@ def test_insert_token(text, offset, length, expected):
                     type="D",
                     value="16 janvier 1611",
                 ),
-            ],
-            "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ",
-        ),
-        (
-            {
-                "P": EntityType(start="ⓟ", end="Ⓟ"),
-            },
-            "n°1 x 16 janvier 1611",
-            [
                 Entity(
-                    offset=0,
-                    length=3,
-                    type="P",
-                    value="n°1",
-                ),
-                Entity(
-                    offset=6,
-                    length=15,
-                    type="D",
-                    value="16 janvier 1611",
+                    offset=24,
+                    length=6,
+                    type="N",
+                    value="Michou",
                 ),
             ],
-            "ⓟn°1Ⓟ",
-        ),
-    ),
-)
-def test_reconstruct_text(tokens, text, entities, expected):
-    arkindex_extractor = ArkindexExtractor(allow_unknown_entities=True)
-    arkindex_extractor.tokens = tokens
-    assert arkindex_extractor.reconstruct_text(text, entities) == expected
+        )
+        == expected
+    )
--
GitLab
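Reviewer note: the sketch below is a minimal, self-contained reproduction of the reconstruct_text() behaviour introduced by this patch, using the inputs and expected outputs from the updated test. The EntityType/Entity NamedTuples are simplified stand-ins for the real classes in dan.datasets.extract, and the UnknownLabelError/NoEndTokenError checks are omitted for brevity, so treat it as an illustration rather than the module's actual code.

# demo_reconstruct_text.py -- illustrative only; the NamedTuples below are
# stand-ins for the real Entity/EntityType classes, and error handling
# (UnknownLabelError, NoEndTokenError) is left out.
from typing import Dict, List, NamedTuple, Optional

EMPTY_CHARS = [" ", "\n", "\t", "\r"]


class EntityType(NamedTuple):
    start: str
    end: str


class Entity(NamedTuple):
    offset: int
    length: int
    type: str
    value: str


def insert_token(
    text: str, entity_type: Optional[EntityType], offset: int, length: int
) -> str:
    # Wrap the entity substring with start/end tokens when its type is known.
    return (
        (entity_type.start if entity_type else "")
        + text[offset : offset + length]
        + (entity_type.end if entity_type else "")
    )


def reconstruct_text(
    full_text: str,
    entities: List[Entity],
    tokens: Dict[str, EntityType],
    only_entities: bool = False,
) -> str:
    text, text_offset = "", 0
    for entity in entities:
        # Keep the text preceding the entity unless only entities are wanted.
        if not only_entities:
            text += full_text[text_offset : entity.offset]

        entity_type = tokens.get(entity.type)
        # Append the (tagged) entity text; entities with unknown types are
        # dropped entirely in only-entities mode.
        if entity_type or not only_entities:
            text += insert_token(full_text, entity_type, entity.offset, entity.length)
        text_offset = entity.offset + entity.length

        # In only-entities mode, keep the first whitespace character found
        # after the entity as separator (space, line break, tab...).
        if only_entities:
            text += next(
                (char for char in full_text[text_offset:] if char in EMPTY_CHARS), ""
            )

    # Keep the trailing text unless only entities are wanted.
    if not only_entities:
        text += full_text[text_offset:]

    # Strip leading/trailing whitespace.
    return text.strip("".join(EMPTY_CHARS))


tokens = {"P": EntityType("ⓟ", "Ⓟ"), "D": EntityType("ⓓ", "Ⓓ")}
entities = [
    Entity(0, 3, "P", "n°1"),
    Entity(6, 15, "D", "16 janvier 1611"),
    Entity(24, 6, "N", "Michou"),  # "N" has no configured tokens
]
full_text = "n°1 x 16 janvier 1611 x Michou"
print(reconstruct_text(full_text, entities, tokens, only_entities=False))
# ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ x Michou
print(reconstruct_text(full_text, entities, tokens, only_entities=True))
# ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ

Note how only-entities mode silently drops the unknown "N" entity yet still keeps exactly one whitespace separator between the remaining tagged entities; that is the behaviour pinned down by the second expected string in the parametrized test.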