From b9351d5485b20caf2fa51384229fce5a58276467 Mon Sep 17 00:00:00 2001 From: manonBlanco <blanco@teklia.com> Date: Tue, 18 Jul 2023 17:09:09 +0200 Subject: [PATCH] Filter entities by name when extracting data from Arkindex --- dan/datasets/extract/__init__.py | 5 ++++ dan/datasets/extract/extract.py | 31 ++++++++++++------- dan/datasets/extract/utils.py | 12 ++------ docs/usage/datasets/extract.md | 1 + tests/test_extract.py | 51 ++++++++++++++++++++++---------- 5 files changed, 65 insertions(+), 35 deletions(-) diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py index 54cca912..bb9fb9dd 100644 --- a/dan/datasets/extract/__init__.py +++ b/dan/datasets/extract/__init__.py @@ -85,6 +85,11 @@ def add_extract_parser(subcommands) -> None: parser.add_argument( "--load-entities", action="store_true", help="Extract text with their entities." ) + parser.add_argument( + "--allow-unknown-entities", + action="store_true", + help="Ignore entities that do not appear in the list of tokens.", + ) parser.add_argument( "--tokens", type=pathlib.Path, diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py index 258500ba..bbededdc 100644 --- a/dan/datasets/extract/extract.py +++ b/dan/datasets/extract/extract.py @@ -22,7 +22,6 @@ from dan.datasets.extract.exceptions import ( UnknownLabelError, ) from dan.datasets.extract.utils import ( - EntityType, Subset, download_image, insert_token, @@ -47,7 +46,8 @@ class ArkindexExtractor: element_type: list = [], parent_element_type: str = None, output: Path = None, - load_entities: bool = None, + load_entities: bool = False, + allow_unknown_entities: bool = False, tokens: Path = None, use_existing_split: bool = None, transcription_worker_version: Optional[Union[str, bool]] = None, @@ -61,6 +61,7 @@ class ArkindexExtractor: self.parent_element_type = parent_element_type self.output = output self.load_entities = load_entities + self.allow_unknown_entities = allow_unknown_entities self.tokens = parse_tokens(tokens) if self.load_entities else None self.use_existing_split = use_existing_split self.transcription_worker_version = transcription_worker_version @@ -103,21 +104,27 @@ class ArkindexExtractor: """ Insert tokens delimiting the start/end of each entity on the transcription. 
""" - count = 0 - for entity in entities: - if entity.type not in self.tokens: + + # Filter entities + for entity in entities.copy(): + # Tokens known for this entity + if entity.type in self.tokens: + continue + # Tokens unknown for this entity + if not self.allow_unknown_entities: raise UnknownLabelError(entity.type) + entities.remove(entity) - entity_type: EntityType = self.tokens[entity.type] - text = insert_token( + # Only keep text of the filtered entities + return " ".join( + insert_token( text, - count, - entity_type, + entity_type=self.tokens[entity.type], offset=entity.offset, length=entity.length, ) - count += entity_type.offset - return text + for entity in entities + ) def extract_transcription(self, element: Element): """ @@ -220,6 +227,7 @@ def run( parent_element_type: str, output: Path, load_entities: bool, + allow_unknown_entities: bool, tokens: Path, use_existing_split: bool, train_folder: UUID, @@ -271,6 +279,7 @@ def run( parent_element_type=parent_element_type, output=output, load_entities=load_entities, + allow_unknown_entities=allow_unknown_entities, tokens=tokens, use_existing_split=use_existing_split, transcription_worker_version=transcription_worker_version, diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py index 90bebb9d..3a125337 100644 --- a/dan/datasets/extract/utils.py +++ b/dan/datasets/extract/utils.py @@ -74,23 +74,17 @@ def save_json(path: Path, data: dict): json.dump(data, outfile, indent=4) -def insert_token( - text: str, count: int, entity_type: EntityType, offset: int, length: int -) -> str: +def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str: """ Insert the given tokens at the right position in the text """ return ( - # Text before entity - text[: count + offset] # Starting token - + entity_type.start + entity_type.start # Entity - + text[count + offset : count + offset + length] + + text[offset : offset + length] # End token + entity_type.end - # Text after entity - + text[count + offset + length :] ) diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md index 65cf1a95..c8aa3416 100644 --- a/docs/usage/datasets/extract.md +++ b/docs/usage/datasets/extract.md @@ -12,6 +12,7 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind | `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` | | `--output` | Folder where the data will be generated. | `Path` | | | `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` | +| `--allow-unknown-entities` | Ignore entities that do not appear in the list of tokens. | `bool` | `False` | | `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | | | `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | | | `--train-folder` | ID of the training folder to import from Arkindex. 
| `uuid` | | diff --git a/tests/test_extract.py b/tests/test_extract.py index 4ddc6fa2..7f1b363e 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -12,24 +12,27 @@ Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str) @pytest.mark.parametrize( - "text,count,offset,length,expected", + "text,offset,length,expected", ( - ("n°1 16 janvier 1611", 0, 0, 3, "ⓘn°1Ⓘ 16 janvier 1611"), - ("ⓘn°1Ⓘ 16 janvier 1611", 2, 4, 15, "ⓘn°1Ⓘ ⓘ16 janvier 1611Ⓘ"), + ("n°1 16 janvier 1611", 0, 3, "ⓘn°1Ⓘ"), + ("ⓘn°1Ⓘ 16 janvier 1611", 6, 15, "ⓘ16 janvier 1611Ⓘ"), ), ) -def test_insert_token(text, count, offset, length, expected): +def test_insert_token(text, offset, length, expected): assert ( - insert_token(text, count, EntityType(start="ⓘ", end="Ⓘ"), offset, length) - == expected + insert_token(text, EntityType(start="ⓘ", end="Ⓘ"), offset, length) == expected ) @pytest.mark.parametrize( - "text,entities,expected", + "tokens,text,entities,expected", ( ( - "n°1 16 janvier 1611", + { + "P": EntityType(start="ⓟ", end="Ⓟ"), + "D": EntityType(start="ⓓ", end="Ⓓ"), + }, + "n°1 x 16 janvier 1611", [ Entity( offset=0, @@ -38,7 +41,7 @@ def test_insert_token(text, count, offset, length, expected): value="n°1", ), Entity( - offset=4, + offset=6, length=15, type="D", value="16 janvier 1611", @@ -46,12 +49,30 @@ def test_insert_token(text, count, offset, length, expected): ], "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ", ), + ( + { + "P": EntityType(start="ⓟ", end="Ⓟ"), + }, + "n°1 x 16 janvier 1611", + [ + Entity( + offset=0, + length=3, + type="P", + value="n°1", + ), + Entity( + offset=6, + length=15, + type="D", + value="16 janvier 1611", + ), + ], + "ⓟn°1Ⓟ", + ), ), ) -def test_reconstruct_text(text, entities, expected): - arkindex_extractor = ArkindexExtractor() - arkindex_extractor.tokens = { - "P": EntityType(start="ⓟ", end="Ⓟ"), - "D": EntityType(start="ⓓ", end="Ⓓ"), - } +def test_reconstruct_text(tokens, text, entities, expected): + arkindex_extractor = ArkindexExtractor(allow_unknown_entities=True) + arkindex_extractor.tokens = tokens assert arkindex_extractor.reconstruct_text(text, entities) == expected -- GitLab
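
For reviewers, here is a minimal, self-contained sketch of the behaviour this patch introduces. The Entity and EntityType named tuples mirror the ones used in tests/test_extract.py and dan/datasets/extract/utils.py, and the standalone reconstruct_text below is a simplified stand-in for the ArkindexExtractor method (it raises a plain ValueError instead of UnknownLabelError so it runs without the package installed); it is an illustration of the new filtering logic, not the actual module code.

from typing import NamedTuple


class EntityType(NamedTuple):
    start: str
    end: str


class Entity(NamedTuple):
    offset: int
    length: int
    type: str
    value: str


def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
    # Wrap the entity value, sliced out of the full transcription, with its start/end tokens.
    return entity_type.start + text[offset : offset + length] + entity_type.end


def reconstruct_text(text, entities, tokens, allow_unknown_entities=False):
    # Keep only entities whose type has a configured token pair; the real extractor
    # raises UnknownLabelError here when --allow-unknown-entities is not set.
    kept = []
    for entity in entities:
        if entity.type in tokens:
            kept.append(entity)
        elif not allow_unknown_entities:
            raise ValueError(f"Unknown label: {entity.type}")
    # Return only the tokenised entity values, joined with single spaces.
    return " ".join(
        insert_token(text, tokens[entity.type], entity.offset, entity.length)
        for entity in kept
    )


tokens = {"P": EntityType(start="ⓟ", end="Ⓟ")}
entities = [
    Entity(offset=0, length=3, type="P", value="n°1"),
    Entity(offset=6, length=15, type="D", value="16 janvier 1611"),
]
print(reconstruct_text("n°1 x 16 janvier 1611", entities, tokens, allow_unknown_entities=True))
# ⓟn°1Ⓟ  -- the "D" entity is dropped because no token is configured for it

On the command line this corresponds to extracting with --load-entities --allow-unknown-entities --tokens <tokens-file> (remaining arguments as documented in docs/usage/datasets/extract.md). As the updated test_reconstruct_text cases show, the rewritten method keeps only the tokenised entity values joined by spaces; text between entities is no longer preserved.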