From 137d8b438e1d069a955fca2be814ede2d1b57993 Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Fri, 15 Dec 2023 10:17:21 +0000
Subject: [PATCH] Support multiple datasets from Arkindex as input

---
 dan/datasets/download/images.py  | 15 ++++---
 dan/datasets/extract/__init__.py |  2 +
 dan/datasets/extract/arkindex.py | 75 ++++++++++++++++----------------
 dan/datasets/extract/db.py       |  2 +-
 docs/usage/datasets/download.md  |  3 +-
 tests/conftest.py                |  7 ++-
 tests/data/extraction/split.json | 16 +++++++
 tests/test_db.py                 |  4 +-
 tests/test_download.py           |  6 +--
 tests/test_extract.py            | 31 ++++++-------
 10 files changed, 91 insertions(+), 70 deletions(-)

diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py
index b492377f..c1702e51 100644
--- a/dan/datasets/download/images.py
+++ b/dan/datasets/download/images.py
@@ -62,11 +62,15 @@ class ImageDownloader:
         self.data: Dict = defaultdict(dict)
 
     def check_extraction(self, values: dict) -> str | None:
+        # Check dataset_id parameter
+        if values.get("dataset_id") is None:
+            return "Dataset ID not found"
+
+        # Check image parameters
         if not (image := values.get("image")):
             return "Image information not found"
-
         # Only support `iiif_url` with `polygon` for now
         if not image.get("iiif_url"):
             return "Image IIIF URL not found"
         if not image.get("polygon"):
             return "Image polygon not found"
@@ -113,15 +117,16 @@ class ImageDownloader:
         destination.mkdir(parents=True, exist_ok=True)
 
         for element_id, values in items.items():
-            image_path = (destination / element_id).with_suffix(
-                self.image_extension
-            )
+            filename = Path(element_id).with_suffix(self.image_extension)
 
             error = self.check_extraction(values)
             if error:
-                logger.warning(f"{image_path}: {error}")
+                logger.warning(f"{destination / filename}: {error}")
                 continue
 
+            image_path = destination / values["dataset_id"] / filename
+            image_path.parent.mkdir(parents=True, exist_ok=True)
+
             self.data[split][str(image_path)] = values["text"]
 
             # Create task for multithreading pool if image does not exist yet
diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 3e05ba0d..278ee892 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -48,9 +48,11 @@ def add_extract_parser(subcommands) -> None:
     )
     parser.add_argument(
         "--dataset-id",
+        nargs="+",
         type=UUID,
         help="ID of the dataset to extract from Arkindex.",
         required=True,
+        dest="dataset_ids",
     )
     parser.add_argument(
         "--element-type",
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index 86f2f3ba..8e7be4c4 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -11,9 +11,8 @@ from uuid import UUID
 
 from tqdm import tqdm
 
-from arkindex_export import Dataset, open_database
+from arkindex_export import Dataset, DatasetElement, Element, open_database
 from dan.datasets.extract.db import (
-    Element,
     get_dataset_elements,
     get_elements,
     get_transcription_entities,
@@ -51,7 +50,7 @@ class ArkindexExtractor:
 
     def __init__(
         self,
-        dataset_id: UUID | None = None,
+        dataset_ids: List[UUID] | None = None,
         element_type: List[str] = [],
         output: Path | None = None,
         entity_separators: List[str] = ["\n", " "],
@@ -63,7 +62,7 @@ class ArkindexExtractor:
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
     ) -> None:
-        self.dataset_id = dataset_id
+        self.dataset_ids = dataset_ids
         self.element_type = element_type
         self.output = output
         self.entity_separators = entity_separators
@@ -139,7 +138,7 @@ class ArkindexExtractor:
         )
         return text.strip()
 
-    def process_element(self, element: Element, split: str):
+    def process_element(self, dataset_parent: DatasetElement, element: Element):
         """
         Extract an element's data and save it to disk.
         The output path is directly related to the split of the element.
@@ -152,10 +151,11 @@ class ArkindexExtractor:
         text = self.format_text(
             text,
             # Do not replace unknown characters in train split
-            charset=self.charset if split != TRAIN_NAME else None,
+            charset=self.charset if dataset_parent.set_name != TRAIN_NAME else None,
         )
 
-        self.data[split][element.id] = {
+        self.data[dataset_parent.set_name][element.id] = {
+            "dataset_id": dataset_parent.dataset_id,
             "text": text,
             "image": {
                 "iiif_url": element.image.url,
@@ -165,17 +165,16 @@ class ArkindexExtractor:
 
         self.charset = self.charset.union(set(text))
 
-    def process_parent(self, pbar, parent: Element, split: str):
+    def process_parent(self, pbar, dataset_parent: DatasetElement):
         """
         Extract data from a parent element.
         """
-        base_description = (
-            f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
-        )
+        parent = dataset_parent.element
+        base_description = f"Extracting data from {parent.type} ({parent.id}) for split ({dataset_parent.set_name})"
         pbar.set_description(desc=base_description)
         if self.element_type == [parent.type]:
             try:
-                self.process_element(parent, split)
+                self.process_element(dataset_parent, parent)
             except ProcessingError as e:
                 logger.warning(f"Skipping {parent.id}: {str(e)}")
         # Extract children elements
@@ -190,7 +189,7 @@ class ArkindexExtractor:
             # Update description to update the children processing progress
             pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
             try:
-                self.process_element(element, split)
+                self.process_element(dataset_parent, element)
             except ProcessingError as e:
                 logger.warning(f"Skipping {element.id}: {str(e)}")
 
@@ -274,28 +273,30 @@ class ArkindexExtractor:
 
     def run(self):
         # Retrieve the Dataset and its splits from the cache
-        dataset = Dataset.get_by_id(self.dataset_id)
-        splits = dataset.sets.split(",")
-        assert set(splits).issubset(
-            set(SPLIT_NAMES)
-        ), f'Dataset must have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
-
-        # Iterate over the subsets to find the page images and labels.
-        for split in splits:
-            with tqdm(
-                get_dataset_elements(dataset, split),
-                desc=f"Extracting data from ({self.dataset_id}) for split ({split})",
-            ) as pbar:
-                # Iterate over the pages to create splits at page level.
-                for parent in pbar:
-                    self.process_parent(
-                        pbar=pbar,
-                        parent=parent.element,
-                        split=split,
-                    )
-                    # Progress bar updates
-                    pbar.update()
-                    pbar.refresh()
+        for dataset_id in self.dataset_ids:
+            dataset = Dataset.get_by_id(dataset_id)
+            splits = dataset.sets.split(",")
+            if not set(splits).issubset(set(SPLIT_NAMES)):
+                logger.warning(
+                    f'Dataset {dataset.name} ({dataset.id}) does not have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
+                )
+                continue
+
+            # Iterate over the subsets to find the page images and labels.
+            for split in splits:
+                with tqdm(
+                    get_dataset_elements(dataset, split),
+                    desc=f"Extracting data from ({dataset_id}) for split ({split})",
+                ) as pbar:
+                    # Iterate over the pages to create splits at page level.
+                    for parent in pbar:
+                        self.process_parent(
+                            pbar=pbar,
+                            dataset_parent=parent,
+                        )
+                        # Progress bar updates
+                        pbar.update()
+                        pbar.refresh()
 
         if not self.data:
             raise Exception(
@@ -308,7 +309,7 @@
 
 def run(
     database: Path,
-    dataset_id: UUID,
+    dataset_ids: List[UUID],
     element_type: List[str],
     output: Path,
     entity_separators: List[str],
@@ -327,7 +328,7 @@ def run(
     Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
 
     ArkindexExtractor(
-        dataset_id=dataset_id,
+        dataset_ids=dataset_ids,
         element_type=element_type,
         output=output,
         entity_separators=entity_separators,
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 3b89902c..25146799 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -22,7 +22,7 @@ def get_dataset_elements(
     Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus
     """
     query = (
-        DatasetElement.select(DatasetElement.element)
+        DatasetElement.select()
         .join(Element)
         .join(Image, on=(DatasetElement.element.image == Image.id))
         .where(
diff --git a/docs/usage/datasets/download.md b/docs/usage/datasets/download.md
index 77f3c7f9..221f6108 100644
--- a/docs/usage/datasets/download.md
+++ b/docs/usage/datasets/download.md
@@ -22,6 +22,7 @@ The `--output` directory should have a `split.json` JSON-formatted file with a s
 {
     "train": {
         "<element_id>": {
+            "dataset_id": "<dataset_id>",
             "image": {
                 "iiif_url": "https://<iiif_server>/iiif/2/<path>",
                 "polygon": [
@@ -32,7 +33,7 @@ The `--output` directory should have a `split.json` JSON-formatted file with a s
                     [37, 191]
                 ]
             },
-            "text": "ⓢCou⇤e⇤ ⓕBouis ⓑ⇤.12.14"
+            "text": "ⓢCoufet ⓕBouis ⓑ07.12.14"
         },
     },
     "val": {},
diff --git a/tests/conftest.py b/tests/conftest.py
index 2cb4c475..4ad5b89c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,7 +25,7 @@ from dan.datasets.extract.arkindex import SPLIT_NAMES
 from tests import FIXTURES
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def mock_database(tmp_path_factory):
     def create_transcription_entity(
         transcription: Transcription,
@@ -182,7 +182,10 @@ def mock_database(tmp_path_factory):
 
     # Create dataset
     dataset = Dataset.create(
-        id="dataset", name="Dataset", state="complete", sets=",".join(SPLIT_NAMES)
+        id="dataset_id",
+        name="Dataset",
+        state="complete",
+        sets=",".join(SPLIT_NAMES),
     )
 
     # Create dataset elements
diff --git a/tests/data/extraction/split.json b/tests/data/extraction/split.json
index 36a6ce49..e264f689 100644
--- a/tests/data/extraction/split.json
+++ b/tests/data/extraction/split.json
@@ -1,6 +1,7 @@
 {
     "test": {
         "test-page_1-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_1.jpg",
                 "polygon": [
@@ -29,6 +30,7 @@
             "text": "ⓢCou⇤e⇤ ⓕBouis ⓑ⇤.12.14"
         },
         "test-page_1-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_2.jpg",
                 "polygon": [
@@ -57,6 +59,7 @@
             "text": "ⓢ⇤outrain ⓕA⇤ol⇤⇤e ⓑ9.4.13"
         },
         "test-page_1-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_3.jpg",
                 "polygon": [
@@ -85,6 +88,7 @@
             "text": "ⓢ⇤abale ⓕ⇤ran⇤ais ⓑ26.3.11"
         },
         "test-page_2-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_1.jpg",
                 "polygon": [
@@ -113,6 +117,7 @@
             "text": "ⓢ⇤urosoy ⓕBouis ⓑ22⇤4⇤18"
         },
         "test-page_2-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_2.jpg",
                 "polygon": [
@@ -141,6 +146,7 @@
             "text": "ⓢColaiani ⓕAn⇤els ⓑ28.11.1⇤"
         },
         "test-page_2-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_3.jpg",
                 "polygon": [
@@ -171,6 +177,7 @@
     },
     "train": {
         "train-page_1-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_1.jpg",
                 "polygon": [
@@ -199,6 +206,7 @@
             "text": "ⓢCaillet ⓕMaurice ⓑ28.9.06"
         },
         "train-page_1-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_2.jpg",
                 "polygon": [
@@ -227,6 +235,7 @@
             "text": "ⓢReboul ⓕJean ⓑ30.9.02"
         },
         "train-page_1-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_3.jpg",
                 "polygon": [
@@ -255,6 +264,7 @@
             "text": "ⓢBareyre ⓕJean ⓑ28.3.11"
         },
         "train-page_1-line_4": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_4.jpg",
                 "polygon": [
@@ -283,6 +293,7 @@
             "text": "ⓢRoussy ⓕJean ⓑ4.11.14"
         },
         "train-page_2-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_1.jpg",
                 "polygon": [
@@ -311,6 +322,7 @@
             "text": "ⓢMarin ⓕMarcel ⓑ10.8.06"
         },
         "train-page_2-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_2.jpg",
                 "polygon": [
@@ -339,6 +351,7 @@
             "text": "ⓢAmical ⓕEloi ⓑ11.10.04"
         },
         "train-page_2-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_3.jpg",
                 "polygon": [
@@ -369,6 +382,7 @@
     },
     "val": {
         "val-page_1-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_1.jpg",
                 "polygon": [
@@ -397,6 +411,7 @@
             "text": "ⓢMonar⇤ ⓕBouis ⓑ29⇤⇤⇤04"
         },
         "val-page_1-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_2.jpg",
                 "polygon": [
@@ -425,6 +440,7 @@
             "text": "ⓢAstier ⓕArt⇤ur ⓑ11⇤2⇤13"
         },
         "val-page_1-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_3.jpg",
                 "polygon": [
diff --git a/tests/test_db.py b/tests/test_db.py
index ac39d09f..0da81d8e 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -4,11 +4,9 @@ from operator import itemgetter
 
 import pytest
 
+from arkindex_export import Dataset, DatasetElement, Element
 from dan.datasets.extract.arkindex import TRAIN_NAME
 from dan.datasets.extract.db import (
-    Dataset,
-    DatasetElement,
-    Element,
     get_dataset_elements,
     get_elements,
     get_transcription_entities,
diff --git a/tests/test_download.py b/tests/test_download.py
index d5348ec3..174e69a4 100644
--- a/tests/test_download.py
+++ b/tests/test_download.py
@@ -58,9 +58,9 @@ def test_download(split_content, monkeypatch, tmp_path):
 
     # Check files
     IMAGE_DIR = output / "images"
-    TEST_DIR = IMAGE_DIR / "test"
-    TRAIN_DIR = IMAGE_DIR / "train"
-    VAL_DIR = IMAGE_DIR / "val"
+    TEST_DIR = IMAGE_DIR / "test" / "dataset_id"
+    TRAIN_DIR = IMAGE_DIR / "train" / "dataset_id"
+    VAL_DIR = IMAGE_DIR / "val" / "dataset_id"
 
     expected_paths = [
         # Images of test folder
diff --git a/tests/test_extract.py b/tests/test_extract.py
index a5ed2cd9..46721185 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -8,7 +8,12 @@ from typing import NamedTuple
 
 import pytest
 
-from arkindex_export import Element, Transcription, TranscriptionEntity
+from arkindex_export import (
+    DatasetElement,
+    Element,
+    Transcription,
+    TranscriptionEntity,
+)
 from dan.datasets.extract.arkindex import ArkindexExtractor
 from dan.datasets.extract.db import get_transcription_entities
 from dan.datasets.extract.exceptions import (
@@ -85,28 +90,18 @@ def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
     output = tmp_path / "extraction"
     arkindex_extractor = ArkindexExtractor(output=output)
 
-    # Create an element with an invalid transcription
-    element = Element.create(
-        id="element_id",
-        name="1",
-        type="page",
-        polygon="[]",
-        created=0.0,
-        updated=0.0,
-    )
-    Transcription.create(
-        id="transcription_id",
-        text="Is this text valid⇤",
-        element=element,
-    )
+    # Retrieve a dataset element and update its transcription with an invalid one
+    dataset_element = DatasetElement.select().first()
+    element = dataset_element.element
+    Transcription.update({Transcription.text: "Is this text valid⇤"}).execute()
 
     with pytest.raises(
         UnknownTokenInText,
         match=re.escape(
-            "Unknown token found in the transcription text of element (element_id)"
+            f"Unknown token found in the transcription text of element ({element.id})"
         ),
     ):
-        arkindex_extractor.process_element(element, "val")
+        arkindex_extractor.process_element(dataset_element, element)
 
 
 @pytest.mark.parametrize(
@@ -253,7 +248,7 @@ def test_extract(
     ]
 
     extractor = ArkindexExtractor(
-        dataset_id="dataset",
+        dataset_ids=["dataset_id"],
         element_type=["text_line"],
        output=output,
        # Keep the whole text
-- 
GitLab
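
For reference, a minimal sketch of driving the updated extractor over several datasets at once, based only on the ArkindexExtractor constructor and run() method as they appear in this patch. The export path, dataset UUIDs and output directory below are placeholder assumptions, not values from the patch, and the open_database() call signature is assumed; the real entry point (dan.datasets.extract.run) also prepares the images/ and language_model/ output folders before running.

    # Hedged sketch: placeholder values, not part of the patch.
    from pathlib import Path
    from uuid import UUID

    from arkindex_export import open_database
    from dan.datasets.extract.arkindex import ArkindexExtractor

    # Open the SQLite export of the Arkindex corpus (path argument assumed).
    open_database("export.sqlite")

    ArkindexExtractor(
        dataset_ids=[
            # Placeholder dataset IDs; one entry per Arkindex dataset to extract.
            UUID("11111111-1111-1111-1111-111111111111"),
            UUID("22222222-2222-2222-2222-222222222222"),
        ],
        element_type=["text_line"],
        output=Path("data/extraction"),
    ).run()

On the command line, the same behaviour is exposed by passing several values to --dataset-id, which now uses nargs="+" and stores them as dataset_ids.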