From 30613c99a032a71400988e153b743a4381303dc8 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Tue, 25 Jul 2023 12:35:06 +0000
Subject: [PATCH] Filter entities by name when extracting data from Arkindex

---
 dan/datasets/extract/__init__.py   |  24 +++-
 dan/datasets/extract/exceptions.py |   6 +-
 dan/datasets/extract/extract.py    |  79 +++++++---
 dan/datasets/extract/utils.py      |  14 +-
 docs/usage/datasets/extract.md     |  61 +++++----
 docs/usage/train/augmentation.md   |   3 +-
 docs/usage/train/jeanzay.md        |   3 +-
 tests/test_extract.py              | 201 +++++++++++++++++++++++++----
 8 files changed, 312 insertions(+), 79 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 54cca912..82be4807 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -40,6 +40,15 @@ def validate_probability(proba):
     return proba
 
 
+def validate_char(char):
+    if len(char) != 1:
+        raise argparse.ArgumentTypeError(
+            f"`{char}` (of length {len(char)}) is not a valid character. Must be a string of length 1."
+        )
+
+    return char
+
+
 def add_extract_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "extract",
@@ -83,7 +92,20 @@ def add_extract_parser(subcommands) -> None:
 
     # Optional arguments.
     parser.add_argument(
-        "--load-entities", action="store_true", help="Extract text with their entities."
+        "--load-entities",
+        action="store_true",
+        help="Extract text with their entities.",
+    )
+    parser.add_argument(
+        "--entity-separators",
+        type=validate_char,
+        nargs="+",
+        help="""
+        Removes all text that does not appear in an entity or in the list of given ordered characters.
+        If several separators follow each other, keep only the first to appear in the list.
+        Do not give any arguments to keep the whole text.
+        """,
+        required=False,
     )
     parser.add_argument(
         "--tokens",
diff --git a/dan/datasets/extract/exceptions.py b/dan/datasets/extract/exceptions.py
index 22c47a6c..93c8d1ae 100644
--- a/dan/datasets/extract/exceptions.py
+++ b/dan/datasets/extract/exceptions.py
@@ -49,9 +49,9 @@ class NoTranscriptionError(ElementProcessingError):
         return f"No transcriptions found on element ({self.element_id}) with this config. Skipping."
 
 
-class UnknownLabelError(ProcessingError):
+class NoEndTokenError(ProcessingError):
     """
-    Raised when the specified label is not known
+    Raised when the specified label has no end token and there is potentially additional text around the labels
     """
 
     label: str
@@ -61,4 +61,4 @@ class UnknownLabelError(ProcessingError):
         self.label = label
 
     def __str__(self) -> str:
-        return f"Label `{self.label}` is missing in the NER configuration."
+        return f"Label `{self.label}` has no end token."
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 258500ba..ab2539c0 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -17,9 +17,9 @@ from dan.datasets.extract.db import (
     get_transcriptions,
 )
 from dan.datasets.extract.exceptions import (
+    NoEndTokenError,
     NoTranscriptionError,
     ProcessingError,
-    UnknownLabelError,
 )
 from dan.datasets.extract.utils import (
     EntityType,
@@ -47,7 +47,8 @@ class ArkindexExtractor:
         element_type: list = [],
         parent_element_type: str = None,
         output: Path = None,
-        load_entities: bool = None,
+        load_entities: bool = False,
+        entity_separators: list = [],
         tokens: Path = None,
         use_existing_split: bool = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
@@ -61,6 +62,7 @@ class ArkindexExtractor:
         self.parent_element_type = parent_element_type
         self.output = output
         self.load_entities = load_entities
+        self.entity_separators = entity_separators
         self.tokens = parse_tokens(tokens) if self.load_entities else None
         self.use_existing_split = use_existing_split
         self.transcription_worker_version = transcription_worker_version
@@ -99,25 +101,68 @@ class ArkindexExtractor:
     def get_random_split(self):
         return next(self._assign_random_split())
 
-    def reconstruct_text(self, text: str, entities):
+    def _keep_char(self, char: str) -> bool:
+        # Keep all text by default if no separator was given
+        return not self.entity_separators or char in self.entity_separators
+
+    def reconstruct_text(self, full_text: str, entities) -> str:
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
-        count = 0
+        text, text_offset = "", 0
+        # Keep all text by default if no separator was given
         for entity in entities:
-            if entity.type not in self.tokens:
-                raise UnknownLabelError(entity.type)
-
-            entity_type: EntityType = self.tokens[entity.type]
-            text = insert_token(
-                text,
-                count,
-                entity_type,
-                offset=entity.offset,
-                length=entity.length,
+            # Text before entity
+            text += "".join(
+                filter(self._keep_char, full_text[text_offset : entity.offset])
             )
-            count += entity_type.offset
-        return text
+
+            entity_type: EntityType = self.tokens.get(entity.type)
+            if not entity_type:
+                logger.warning(
+                    f"Label `{entity.type}` is missing in the NER configuration."
+                )
+            # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends
+            elif not entity_type.end and not self.entity_separators:
+                raise NoEndTokenError(entity.type)
+
+            # Entity text:
+            # - with tokens if there is an entity_type
+            # - without tokens if there is no entity_type but we want to keep the whole text
+            if entity_type or not self.entity_separators:
+                text += insert_token(
+                    full_text,
+                    entity_type,
+                    offset=entity.offset,
+                    length=entity.length,
+                )
+            text_offset = entity.offset + entity.length
+
+        # Remaining text after the last entity
+        text += "".join(filter(self._keep_char, full_text[text_offset:]))
+
+        if not self.entity_separators:
+            return text
+
+        # Add some clean up to avoid several separators between entities
+        text, full_text = "", text
+        for char in full_text:
+            last_char = text[-1] if len(text) else ""
+
+            # Keep the current character if there are no two consecutive separators
+            if (
+                char not in self.entity_separators
+                or last_char not in self.entity_separators
+            ):
+                text += char
+            # If several separators follow each other, keep only one according to the given order
+            elif self.entity_separators.index(char) < self.entity_separators.index(
+                last_char
+            ):
+                text = text[:-1] + char
+
+        # Remove separators at the beginning and end of text
+        return text.strip("".join(self.entity_separators))
 
     def extract_transcription(self, element: Element):
         """
@@ -220,6 +265,7 @@ def run(
     parent_element_type: str,
     output: Path,
     load_entities: bool,
+    entity_separators: list,
     tokens: Path,
     use_existing_split: bool,
     train_folder: UUID,
@@ -271,6 +317,7 @@
         parent_element_type=parent_element_type,
         output=output,
         load_entities=load_entities,
+        entity_separators=entity_separators,
         tokens=tokens,
         use_existing_split=use_existing_split,
         transcription_worker_version=transcription_worker_version,
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 90bebb9d..3194cccd 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -74,23 +74,17 @@ def save_json(path: Path, data: dict):
         json.dump(data, outfile, indent=4)
 
 
-def insert_token(
-    text: str, count: int, entity_type: EntityType, offset: int, length: int
-) -> str:
+def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
     """
     Insert the given tokens at the right position in the text
     """
     return (
-        # Text before entity
-        text[: count + offset]
         # Starting token
-        + entity_type.start
+        (entity_type.start if entity_type else "")
         # Entity
-        + text[count + offset : count + offset + length]
+        + text[offset : offset + length]
         # End token
-        + entity_type.end
-        # Text after entity
-        + text[count + offset + length :]
+        + (entity_type.end if entity_type else "")
    )
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index 65cf1a95..4a0ba209 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -4,25 +4,26 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex export database (SQLite format).
 
 This will generate the images and the labels needed to train a DAN model.
 
-| Parameter                        | Description                                                                          | Type            | Default |
-| -------------------------------- | ------------------------------------------------------------------------------------ | --------------- | ------- |
-| `database` | Path to an Arkindex export database in SQLite format. | `Path` | |
-| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str` or `uuid` | |
-| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
-| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
-| `--output` | Folder where the data will be generated. | `Path` | |
-| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
-| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
-| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
-| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
-| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
-| `--test-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
-| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
-| `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | |
-| `--train-prob` | Training set split size | `float` | `0.7` |
-| `--val-prob` | Validation set split size | `float` | `0.15` |
-| `--max-width` | Images larger than this width will be resized to this width. | `int` | |
-| `--max-height` | Images larger than this height will be resized to this height. | `int` | |
+| Parameter                        | Description                                                                                                                                                                                                                            | Type            | Default                              |
+| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------- | ------------------------------------ |
+| `database` | Path to an Arkindex export database in SQLite format. | `Path` | |
+| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str` or `uuid` | |
+| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
+| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
+| `--output` | Folder where the data will be generated. | `Path` | |
+| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
+| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. | `str` | (see [dedicated section](#examples)) |
+| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
+| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
+| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
+| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
+| `--test-folder` | ID of the testing folder to import from Arkindex. | `uuid` | |
+| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--train-prob` | Training set split size | `float` | `0.7` |
+| `--val-prob` | Validation set split size | `float` | `0.15` |
+| `--max-width` | Images larger than this width will be resized to this width. | `int` | |
+| `--max-height` | Images larger than this height will be resized to this height. | `int` | |
 
 The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively.
@@ -54,7 +55,7 @@ CLASSEMENT:
 
 ### HTR and NER data from one source
 
-To extract HTR+NER data from **pages** from a folder, use the following command:
+To extract HTR+NER data from **pages** from a folder, you have to define an end token for each entity and use the following command:
 
 ```shell
 teklia-dan dataset extract \
@@ -70,7 +71,7 @@ with `tokens.yml` compliant with the format described before.
 
 ### HTR and NER data from multiple source
 
-To do the same but only use the data from three folders, the commands becomes:
+To do the same but only use the data from three folders, you have to define an end token for each entity and the command becomes:
 
 ```shell
 teklia-dan dataset extract \
@@ -84,7 +85,7 @@ teklia-dan dataset extract \
 
 ### HTR and NER data with an existing split
 
-To use the data from three folders as **training**, **validation** and **testing** dataset respectively, the commands becomes:
+To use the data from three folders as **training**, **validation** and **testing** datasets respectively, you have to define an end token for each entity and the command becomes:
 
 ```shell
 teklia-dan dataset extract \
@@ -101,7 +102,7 @@ teklia-dan dataset extract \
 
 ### HTR from multiple element types with some parent filtering
 
-To extract HTR data from **annotations** and **text_zones** from a folder, but only keep those that are children of **single_pages**, use the following command:
+To extract HTR data from **annotations** and **text_zones** from a folder, but only keep those that are children of **single_pages**, you have to define an end token for each entity and use the following command:
 
 ```shell
 teklia-dan dataset extract \
   database.sqlite \
   --element-type text_zone annotation \
   --parent-element-type single_page \
   --output data
 ```
+
+### NER data
+
+To extract NER data and keep line breaks and spaces between entities, use the following command:
+
+```shell
+teklia-dan dataset extract \
+    [...]
+    --load-entities \
+    --entity-separators $'\n' " " \
+    --tokens tokens.yml
+```
+
+If several separators follow each other, only one is kept: ideally a line break if there is one, otherwise a space. If you change the order of the `--entity-separators` values, a space is kept if there is one, otherwise a line break.
diff --git a/docs/usage/train/augmentation.md b/docs/usage/train/augmentation.md
index 3af48ea8..ab9f3870 100644
--- a/docs/usage/train/augmentation.md
+++ b/docs/usage/train/augmentation.md
@@ -16,7 +16,8 @@ This page lists data augmentation transforms used in DAN.
 
 ### PieceWise Affine
 
-:warning: This transform is temporarily removed from the pipeline until [this issue](https://github.com/albumentations-team/albumentations/issues/1442) is fixed.
+!!! warning
+    This transform is temporarily removed from the pipeline until [this issue](https://github.com/albumentations-team/albumentations/issues/1442) is fixed.
 
 |                              | PieceWise Affine |
 | ---------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
diff --git a/docs/usage/train/jeanzay.md b/docs/usage/train/jeanzay.md
index f4a6d8ce..f486b76d 100644
--- a/docs/usage/train/jeanzay.md
+++ b/docs/usage/train/jeanzay.md
@@ -4,7 +4,8 @@ See the [wiki](https://redmine.teklia.com/projects/research/wiki/Jean_Zay) for m
 
 ## Run a training job
 
-Warning: there is no HTTP connection during a job.
+!!! warning
+    There is no HTTP connection during a job.
 
 You can debug using an interactive job. The following command will get you a new terminal with 1 gpu for 1 hour: `srun --ntasks=1 --cpus-per-task=40 --gres=gpu:1 --time=01:00:00 --qos=qos_gpu-dev --pty bash -i`.
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 4ddc6fa2..a9f324dd 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -4,54 +4,207 @@ from typing import NamedTuple
 
 import pytest
 
+from dan.datasets.extract.exceptions import NoEndTokenError
 from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
 # NamedTuple to mock actual database result
 Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
 
+TOKENS = {
+    "P": EntityType(start="ⓟ", end="Ⓟ"),
+    "D": EntityType(start="ⓓ", end="Ⓓ"),
+    "N": EntityType(start="ⓝ", end="Ⓝ"),
+    "I": EntityType(start="ⓘ", end="Ⓘ"),
+}
+
+
+def filter_tokens(keys):
+    return {key: value for key, value in TOKENS.items() if key in keys}
+
 
 @pytest.mark.parametrize(
-    "text,count,offset,length,expected",
+    "text,offset,length,expected",
     (
-        ("n°1 16 janvier 1611", 0, 0, 3, "ⓘn°1Ⓘ 16 janvier 1611"),
-        ("ⓘn°1Ⓘ 16 janvier 1611", 2, 4, 15, "ⓘn°1Ⓘ ⓘ16 janvier 1611Ⓘ"),
+        ("n°1 16 janvier 1611", 0, 3, "ⓘn°1Ⓘ"),
+        ("ⓘn°1Ⓘ 16 janvier 1611", 6, 15, "ⓘ16 janvier 1611Ⓘ"),
     ),
 )
-def test_insert_token(text, count, offset, length, expected):
+def test_insert_token(text, offset, length, expected):
     assert (
-        insert_token(text, count, EntityType(start="ⓘ", end="Ⓘ"), offset, length)
-        == expected
+        insert_token(text, EntityType(start="ⓘ", end="Ⓘ"), offset, length) == expected
     )
 
 
-@pytest.mark.parametrize(
-    "text,entities,expected",
-    (
-        (
-            "n°1 16 janvier 1611",
+def test_reconstruct_text_no_end_token_error():
+    arkindex_extractor = ArkindexExtractor()
+    arkindex_extractor.tokens = {
+        "X": EntityType(start="ⓧ"),
+    }
+    with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
+        arkindex_extractor.reconstruct_text(
+            "n°1 x 16 janvier 1611",
             [
                 Entity(
                     offset=0,
                     length=3,
-                    type="P",
+                    type="X",
                     value="n°1",
                 ),
+            ],
+        )
+
+
+@pytest.mark.parametrize(
+    "entity_separators,tokens,expected",
+    (
+        # Whole text...
+        # ... + All tokens
+        ([], TOKENS, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
+        # ... + 1st and 2nd tokens
+        ([], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nMichel"),
+        # ... + 1st and 3rd tokens
+        ([], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ x 16 janvier 1611\nⓝMichelⓃ"),
+        # ... + 2nd and 3rd tokens
+        ([], filter_tokens(["D", "N"]), "n°1 x ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
+        # Only entities...
+        # ... + All tokens
+        (["\n", " "], TOKENS, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
+        # ... + 1st and 2nd tokens
+        (["\n", " "], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
+        # ... + 1st and 3rd tokens
+        (["\n", " "], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ\nⓝMichelⓃ"),
+        # ... + 2nd and 3rd tokens
+        (["\n", " "], filter_tokens(["D", "N"]), "ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
+    ),
+)
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
+    arkindex_extractor.tokens = tokens
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "n°1 x 16 janvier 1611\nMichel" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=3,
+                type="P",
+                value="n°1",
+            ),
+            Entity(
+                offset=6 + len(text_before),
+                length=15,
+                type="D",
+                value="16 janvier 1611",
+            ),
+            Entity(
+                offset=22 + len(text_before),
+                length=6,
+                type="N",
+                value="Michel",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + expected
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize(
+    "entity_separators",
+    (
+        # Whole text
+        [],
+        # Only entities
+        [" ", "\n"],
+    ),
+)
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
+    arkindex_extractor.tokens = TOKENS
+    assert arkindex_extractor.reconstruct_text(
+        text_before + "LouisXIV" + text_after,
+        [
+            Entity(
+                offset=0 + len(text_before),
+                length=5,
+                type="P",
+                value="Louis",
+            ),
+            Entity(
+                offset=5 + len(text_before),
+                length=3,
+                type="I",
+                value="XIV",
+            ),
+        ],
+    ) == (
+        (text_before if not entity_separators else "")
+        + "ⓟLouisⓅⓘXIVⒾ"
+        + (text_after if not entity_separators else "")
+    )
+
+
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_several_separators(text_before, text_after):
+    arkindex_extractor = ArkindexExtractor(entity_separators=["\n", " "])
+    arkindex_extractor.tokens = TOKENS
+    # Keep "\n" instead of " "
+    assert (
+        arkindex_extractor.reconstruct_text(
+            text_before + "King\nLouis XIV" + text_after,
+            [
                 Entity(
-                    offset=4,
-                    length=15,
+                    offset=0 + len(text_before),
+                    length=4,
                     type="D",
-                    value="16 janvier 1611",
+                    value="King",
+                ),
+                Entity(
+                    offset=11 + len(text_before),
+                    length=3,
+                    type="I",
+                    value="XIV",
                 ),
             ],
-            "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ",
-        ),
-    ),
-)
-def test_reconstruct_text(text, entities, expected):
-    arkindex_extractor = ArkindexExtractor()
+        )
+        == "ⓓKingⒹ\nⓘXIVⒾ"
+    )
+
+
+@pytest.mark.parametrize("joined", (True, False))
+@pytest.mark.parametrize("text_before", ("", "text before "))
+@pytest.mark.parametrize("text_after", ("", " text after"))
+def test_reconstruct_text_only_start_token(joined, text_before, text_after):
+    separator = " " if not joined else ""
+
+    arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
     arkindex_extractor.tokens = {
-        "P": EntityType(start="ⓟ", end="Ⓟ"),
-        "D": EntityType(start="ⓓ", end="Ⓓ"),
+        "P": EntityType(start="ⓟ"),
+        "I": EntityType(start="ⓘ"),
     }
-    assert arkindex_extractor.reconstruct_text(text, entities) == expected
+    assert (
+        arkindex_extractor.reconstruct_text(
+            text_before + "Louis" + separator + "XIV" + text_after,
+            [
+                Entity(
+                    offset=0 + len(text_before),
+                    length=5,
+                    type="P",
+                    value="Louis",
+                ),
+                Entity(
+                    offset=5 + len(separator) + len(text_before),
+                    length=3,
+                    type="I",
+                    value="XIV",
+                ),
+            ],
+        )
+        == "ⓟLouis" + separator + "ⓘXIV"
+    )
--
GitLab