diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 6ea479c602b3557e90d2d256e598472d9574c3a5..8a4def011e469acbbbfb42543042499f051e2562 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -109,6 +109,12 @@ def add_extract_parser(subcommands) -> None:
         required=False,
         default=list(map(validate_char, ("\n", " "))),
     )
+    parser.add_argument(
+        "--unknown-token",
+        type=str,
+        help="Token used to replace characters in the validation/test sets that do not appear in the training set.",
+        default="⁇",
+    )
     parser.add_argument(
         "--tokens",
         type=pathlib.Path,
diff --git a/dan/datasets/extract/exceptions.py b/dan/datasets/extract/exceptions.py
index 2155a6ca2383cf7ad9b499d95ef7de2eed5daaad..74e1b332c0fd2d91748cc45bf76025769292e8ea 100644
--- a/dan/datasets/extract/exceptions.py
+++ b/dan/datasets/extract/exceptions.py
@@ -44,6 +44,15 @@ class NoTranscriptionError(ElementProcessingError):
         return f"No transcriptions found on element ({self.element_id}) with this config. Skipping."
 
 
+class UnknownTokenInText(ElementProcessingError):
+    """
+    Raised when the unknown token is found in a transcription text
+    """
+
+    def __str__(self) -> str:
+        return f"Unknown token found in the transcription text of element ({self.element_id})"
+
+
 class NoEndTokenError(ProcessingError):
     """
     Raised when the specified label has no end token and there is potentially additional text around the labels
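The new exception follows the existing `ElementProcessingError` pattern, carrying the offending element's ID into the message. A minimal, self-contained sketch of how it renders; the base class below is a stand-in for the real one defined earlier in `exceptions.py`, assumed to store `element_id`:

```python
# Stand-in for dan.datasets.extract.exceptions.ElementProcessingError,
# which is assumed to keep the element ID for use in the message.
class ElementProcessingError(Exception):
    def __init__(self, element_id: str) -> None:
        super().__init__()
        self.element_id = element_id


class UnknownTokenInText(ElementProcessingError):
    def __str__(self) -> str:
        return f"Unknown token found in the transcription text of element ({self.element_id})"


print(UnknownTokenInText(element_id="element_id"))
# Unknown token found in the transcription text of element (element_id)
```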
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 440986ec32bd1e1119a98be34aa02482013691ca..e6847e0c4b3a34996846a7fcd2e1bba0d94edb9f 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -27,6 +27,7 @@ from dan.datasets.extract.exceptions import (
     NoEndTokenError,
     NoTranscriptionError,
     ProcessingError,
+    UnknownTokenInText,
 )
 from dan.datasets.extract.utils import (
     download_image,
@@ -44,7 +45,8 @@ from line_image_extractor.image_utils import (
 IMAGES_DIR = "images"  # Subpath to the images directory.
 
-SPLIT_NAMES = ["train", "val", "test"]
+TRAIN_NAME = "train"
+SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
 
 IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
 # IIIF 2.0 uses `full`
 IIIF_FULL_SIZE = "full"
@@ -64,6 +66,7 @@ class ArkindexExtractor:
         parent_element_type: str = None,
         output: Path = None,
         entity_separators: List[str] = ["\n", " "],
+        unknown_token: str = "⁇",
         tokens: Path = None,
         transcription_worker_version: Optional[Union[str, bool]] = None,
         entity_worker_version: Optional[Union[str, bool]] = None,
@@ -77,6 +80,7 @@ class ArkindexExtractor:
         self.parent_element_type = parent_element_type
         self.output = output
         self.entity_separators = entity_separators
+        self.unknown_token = unknown_token
         self.tokens = parse_tokens(tokens) if tokens else None
         self.transcription_worker_version = transcription_worker_version
         self.entity_worker_version = entity_worker_version
@@ -245,9 +249,20 @@ class ArkindexExtractor:
             split=split, path=str(destination), url=download_url, exc=e
         )
 
-    def format_text(self, text: str):
+    def format_text(self, text: str, charset: Optional[set] = None):
         if not self.keep_spaces:
             text = remove_spaces(text)
+
+        # Replace unknown characters by the unknown token
+        if charset is not None:
+            unknown_charset = set(text) - charset
+            text = text.translate(
+                {
+                    ord(unknown_char): self.unknown_token
+                    for unknown_char in unknown_charset
+                }
+            )
+
         return text.strip()
 
     def process_element(
@@ -261,6 +276,9 @@ class ArkindexExtractor:
         """
         text = self.extract_transcription(element)
 
+        if self.unknown_token in text:
+            raise UnknownTokenInText(element_id=element.id)
+
         image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
             self.image_extension
         )
@@ -276,7 +294,13 @@ class ArkindexExtractor:
             }
         )
 
-        self.data[split][str(image_path)] = self.format_text(text)
+        text = self.format_text(
+            text,
+            # Do not replace unknown characters in train split
+            charset=self.charset if split != TRAIN_NAME else None,
+        )
+
+        self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
 
     def process_parent(
@@ -390,6 +414,7 @@ def run(
     parent_element_type: str,
     output: Path,
     entity_separators: List[str],
+    unknown_token: str,
     tokens: Path,
     train_folder: UUID,
     val_folder: UUID,
@@ -416,6 +441,7 @@ def run(
         parent_element_type=parent_element_type,
         output=output,
         entity_separators=entity_separators,
+        unknown_token=unknown_token,
         tokens=tokens,
         transcription_worker_version=transcription_worker_version,
         entity_worker_version=entity_worker_version,
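The heart of the patch is the new `charset` parameter of `format_text`: on non-train splits, every character that never occurs in the training charset is mapped to the unknown token through `str.translate`, whose mapping goes from code points to replacement strings. A standalone sketch of that logic with a made-up charset (none of these values come from a real dataset):

```python
# Sketch of the replacement format_text applies to val/test text.
# `charset` stands in for the characters accumulated from the train split.
unknown_token = "⁇"  # default of the new --unknown-token argument
charset = set("Cou e Mari.124")

text = "Coupez 7.12.14"
unknown_charset = set(text) - charset  # {"p", "z", "7"}
text = text.translate(
    {ord(unknown_char): unknown_token for unknown_char in unknown_charset}
)
assert text == "Cou⁇e⁇ ⁇.12.14"
```

Note the companion guard added to `process_element`: if the unknown token already occurs in the source transcription, the element is rejected with `UnknownTokenInText` rather than silently producing ambiguous labels.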
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index ae0c9059b21fc6bccc98de1e684a68bd946fdcf9..a7715d59e07d1e776b05ff72a61b8d034e88dc4f 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -17,6 +17,7 @@ If an image download fails for whatever reason, it won't appear in the transcrip
 | `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
 | `--output` | Folder where the data will be generated. | `Path` | |
 | `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. | `str` | `["\n", " "]` (see [dedicated section](#examples)) |
+| `--unknown-token` | Token used to replace characters in the validation/test sets that do not appear in the training set. | `str` | `⁇` |
 | `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `Path` | |
 | `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
 | `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
diff --git a/tests/data/extraction/elements/test-page_1-line_1.json b/tests/data/extraction/elements/test-page_1-line_1.json
index ef63f940e968e245027e2ab0d4dbbf4d4b0ecc6f..716bf50aa37442665d7ebc7c9246ac3b0bf01739 100644
--- a/tests/data/extraction/elements/test-page_1-line_1.json
+++ b/tests/data/extraction/elements/test-page_1-line_1.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "Louis",
+            "name": "Bouis",
             "type": "firstname",
             "offset": 7
         },
diff --git a/tests/data/extraction/elements/test-page_1-line_3.json b/tests/data/extraction/elements/test-page_1-line_3.json
index f196d3cf4d91072ff57aa9f51eaa7569caf7aaef..78fbe7865e47e062dc7b2b9682edc337b3294e53 100644
--- a/tests/data/extraction/elements/test-page_1-line_3.json
+++ b/tests/data/extraction/elements/test-page_1-line_3.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "François",
+            "name": "Français",
             "type": "firstname",
             "offset": 7
         },
diff --git a/tests/data/extraction/elements/test-page_2-line_1.json b/tests/data/extraction/elements/test-page_2-line_1.json
index a9b2498f08840eba97648647f279d3ee1b96fa69..9d5c131b90c8b1bd52cd6e365899c8f43afb070e 100644
--- a/tests/data/extraction/elements/test-page_2-line_1.json
+++ b/tests/data/extraction/elements/test-page_2-line_1.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "Louis",
+            "name": "Bouis",
             "type": "firstname",
             "offset": 8
         },
diff --git a/tests/data/extraction/elements/train-page_2-line_2.json b/tests/data/extraction/elements/train-page_2-line_2.json
index c54e23b61690f1623306e42f79861336d00cd192..bc6829dab843ea9b68a2bbdac8df91c575ebc96d 100644
--- a/tests/data/extraction/elements/train-page_2-line_2.json
+++ b/tests/data/extraction/elements/train-page_2-line_2.json
@@ -9,7 +9,7 @@
     ],
     "transcription_entities": [
         {
-            "name": "Roques",
+            "name": "Amical",
             "type": "surname",
             "offset": 0
         },
diff --git a/tests/data/extraction/elements/train-page_2-line_3.json b/tests/data/extraction/elements/train-page_2-line_3.json
index 200e3b3db491b812ddb11bd922af7adba8319cba..90432163e7449c6c228eb25d1024fc23eec988ca 100644
--- a/tests/data/extraction/elements/train-page_2-line_3.json
+++ b/tests/data/extraction/elements/train-page_2-line_3.json
@@ -9,12 +9,12 @@
     ],
     "transcription_entities": [
         {
-            "name": "Giros",
+            "name": "Biros",
             "type": "surname",
             "offset": 0
         },
         {
-            "name": "Paul",
+            "name": "Mael",
             "type": "firstname",
             "offset": 6
         },
diff --git a/tests/data/extraction/elements/val-page_1-line_1.json b/tests/data/extraction/elements/val-page_1-line_1.json
index 371f730d2929a234718152eebcc4778a78e900c3..d255644499a7858d72b4e636f0471b61abd19c93 100644
--- a/tests/data/extraction/elements/val-page_1-line_1.json
+++ b/tests/data/extraction/elements/val-page_1-line_1.json
@@ -14,7 +14,7 @@
             "offset": 0
         },
         {
-            "name": "Louis",
+            "name": "Bouis",
             "type": "firstname",
             "offset": 7
         },
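The fixture edits above are deliberate: `Louis` becomes `Bouis`, `François` becomes `Français`, and `Roques`, `Giros`, `Paul` become `Amical`, `Biros`, `Mael`, so that the validation and test transcriptions contain characters that never occur in the train split. This gives the new replacement logic something to exercise in the test expectations below.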
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 24ec2905aeb60176b6a83105abf81309a29e4e06..529810005df89fc9390fb2e84045327e99d90b90 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -11,7 +11,8 @@ from unittest.mock import patch
 import pytest
 from PIL import Image, ImageChops
 
-from dan.datasets.extract.exceptions import NoEndTokenError
+from arkindex_export import Element, Transcription
+from dan.datasets.extract.exceptions import NoEndTokenError, UnknownTokenInText
 from dan.datasets.extract.extract import IIIF_FULL_SIZE, ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token, remove_spaces
 from dan.utils import parse_tokens
@@ -259,6 +260,34 @@ def test_reconstruct_text_only_start_token(joined, text_before, text_after):
     )
 
 
+def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
+    output = tmp_path / "extraction"
+    arkindex_extractor = ArkindexExtractor(output=output)
+
+    # Create an element with an invalid transcription
+    element = Element.create(
+        id="element_id",
+        name="1",
+        type="page",
+        polygon="[]",
+        created=0.0,
+        updated=0.0,
+    )
+    Transcription.create(
+        id="transcription_id",
+        text="Is this text valid⁇",
+        element=element,
+    )
+
+    with pytest.raises(
+        UnknownTokenInText,
+        match=re.escape(
+            "Unknown token found in the transcription text of element (element_id)"
+        ),
+    ):
+        arkindex_extractor.process_element(element, "val")
+
+
 @pytest.mark.parametrize("load_entities", (True, False))
 @pytest.mark.parametrize("keep_spaces", (True, False))
 # Transcription and entities have the same worker version
@@ -343,12 +372,12 @@ def test_extract(
     # Check "labels.json"
     expected_labels = {
         "test": {
-            str(TEST_DIR / "test-page_1-line_1.jpg"): "ⓢCoupez ⓕLouis ⓑ7.12.14",
-            str(TEST_DIR / "test-page_1-line_2.jpg"): "ⓢPoutrain ⓕAdolphe ⓑ9.4.13",
-            str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢGabale ⓕFrançois ⓑ26.3.11",
-            str(TEST_DIR / "test-page_2-line_1.jpg"): "ⓢDurosoy ⓕLouis ⓑ22-4-18",
-            str(TEST_DIR / "test-page_2-line_2.jpg"): "ⓢColaiani ⓕAngels ⓑ28.11.17",
-            str(TEST_DIR / "test-page_2-line_3.jpg"): "ⓢRenouard ⓕMaurice ⓑ25.7.04",
+            str(TEST_DIR / "test-page_1-line_1.jpg"): "ⓢCou⁇e⁇ ⓕBouis ⓑ⁇.12.14",
+            str(TEST_DIR / "test-page_1-line_2.jpg"): "ⓢ⁇outrain ⓕA⁇ol⁇⁇e ⓑ9.4.13",
+            str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢ⁇abale ⓕ⁇ran⁇ais ⓑ26.3.11",
+            str(TEST_DIR / "test-page_2-line_1.jpg"): "ⓢ⁇urosoy ⓕBouis ⓑ22⁇4⁇18",
+            str(TEST_DIR / "test-page_2-line_2.jpg"): "ⓢColaiani ⓕAn⁇els ⓑ28.11.1⁇",
+            str(TEST_DIR / "test-page_2-line_3.jpg"): "ⓢRenouar⁇ ⓕMaurice ⓑ2⁇.⁇.04",
         },
         "train": {
             str(TRAIN_DIR / "train-page_1-line_1.jpg"): "ⓢCaillet ⓕMaurice ⓑ28.9.06",
@@ -356,13 +385,13 @@ def test_extract(
             str(TRAIN_DIR / "train-page_1-line_3.jpg"): "ⓢBareyre ⓕJean ⓑ28.3.11",
             str(TRAIN_DIR / "train-page_1-line_4.jpg"): "ⓢRoussy ⓕJean ⓑ4.11.14",
             str(TRAIN_DIR / "train-page_2-line_1.jpg"): "ⓢMarin ⓕMarcel ⓑ10.8.06",
-            str(TRAIN_DIR / "train-page_2-line_2.jpg"): "ⓢRoques ⓕEloi ⓑ11.10.04",
-            str(TRAIN_DIR / "train-page_2-line_3.jpg"): "ⓢGiros ⓕPaul ⓑ30.10.10",
+            str(TRAIN_DIR / "train-page_2-line_2.jpg"): "ⓢAmical ⓕEloi ⓑ11.10.04",
+            str(TRAIN_DIR / "train-page_2-line_3.jpg"): "ⓢBiros ⓕMael ⓑ30.10.10",
         },
         "val": {
-            str(VAL_DIR / "val-page_1-line_1.jpg"): "ⓢMonard ⓕLouis ⓑ29-7-04",
-            str(VAL_DIR / "val-page_1-line_2.jpg"): "ⓢAstier ⓕArthur ⓑ11-2-13",
-            str(VAL_DIR / "val-page_1-line_3.jpg"): "ⓢDe Vlieger ⓕJules ⓑ21-11-11",
+            str(VAL_DIR / "val-page_1-line_1.jpg"): "ⓢMonar⁇ ⓕBouis ⓑ29⁇⁇⁇04",
+            str(VAL_DIR / "val-page_1-line_2.jpg"): "ⓢAstier ⓕArt⁇ur ⓑ11⁇2⁇13",
+            str(VAL_DIR / "val-page_1-line_3.jpg"): "ⓢ⁇e ⁇lie⁇er ⓕJules ⓑ21⁇11⁇11",
         },
     }
@@ -393,12 +422,12 @@ def test_extract(
     # Check "charset.pkl"
     expected_charset = set()
-    for labels in expected_labels.values():
-        for label in labels.values():
-            expected_charset.update(set(label))
+    for label in expected_labels["train"].values():
+        expected_charset.update(set(label))
 
     if load_entities:
         expected_charset.update(tokens)
+    expected_charset.add("⁇")
 
     assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
 
     # Check cropped images
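Hand-checking one of the new expectations ties the pieces together. The updated test labels imply that `p`, `z` and `7` never occur in the train split, so building a charset from the train labels visible in the hunks above (the `train-page_1-line_2` entry falls between the hunks and is omitted here) and applying the same replacement reproduces the first test entry. A verification sketch, not part of the test suite:

```python
# Charset built from the train labels shown in the expected_labels hunk.
train_labels = [
    "ⓢCaillet ⓕMaurice ⓑ28.9.06",
    "ⓢBareyre ⓕJean ⓑ28.3.11",
    "ⓢRoussy ⓕJean ⓑ4.11.14",
    "ⓢMarin ⓕMarcel ⓑ10.8.06",
    "ⓢAmical ⓕEloi ⓑ11.10.04",
    "ⓢBiros ⓕMael ⓑ30.10.10",
]
charset = set().union(*map(set, train_labels))

# First test transcription after the fixture edits (Louis -> Bouis).
text = "ⓢCoupez ⓕBouis ⓑ7.12.14"
replaced = text.translate({ord(c): "⁇" for c in set(text) - charset})
assert replaced == "ⓢCou⁇e⁇ ⓕBouis ⓑ⁇.12.14"
```

This is also why the charset assertion changes in the last hunk: the charset is now accumulated from the train labels only, and the unknown token is added explicitly because validation/test labels may contain it even though no train label does.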