diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index ad8155c3bbd3a59855a687543171ffffa46e669a..3e05ba0d1185de217b15a885d232ba2e7124ffca 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -46,6 +46,12 @@ def add_extract_parser(subcommands) -> None:
         type=pathlib.Path,
         help="Path where the data were exported from Arkindex.",
     )
+    parser.add_argument(
+        "--dataset-id",
+        type=UUID,
+        help="ID of the dataset to extract from Arkindex.",
+        required=True,
+    )
     parser.add_argument(
         "--element-type",
         nargs="+",
@@ -53,13 +59,6 @@
         help="Type of elements to retrieve.",
         required=True,
     )
-    parser.add_argument(
-        "--parent-element-type",
-        type=str,
-        help="Type of the parent element containing the data.",
-        required=False,
-        default="page",
-    )
     parser.add_argument(
         "--output",
         type=pathlib.Path,
@@ -67,25 +66,6 @@
         required=True,
     )
 
-    parser.add_argument(
-        "--train-folder",
-        type=UUID,
-        help="ID of the training folder to extract from Arkindex.",
-        required=True,
-    )
-    parser.add_argument(
-        "--val-folder",
-        type=UUID,
-        help="ID of the validation folder to extract from Arkindex.",
-        required=True,
-    )
-    parser.add_argument(
-        "--test-folder",
-        type=UUID,
-        help="ID of the testing folder to extract from Arkindex.",
-        required=True,
-    )
-
     # Optional arguments.
     parser.add_argument(
         "--entity-separators",
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index 0c8b5abc46dbe8ec069a33460e710ac46324cd85..86f2f3ba59d0523e4a487d9bb9759f5d60ca24fa 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -11,9 +11,10 @@ from uuid import UUID
 
 from tqdm import tqdm
 
-from arkindex_export import open_database
+from arkindex_export import Dataset, open_database
 from dan.datasets.extract.db import (
     Element,
+    get_dataset_elements,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -36,7 +37,9 @@ from dan.utils import LMTokenMapping, parse_tokens
 
 LANGUAGE_DIR = "language_model"  # Subpath to the language model directory.
 
 TRAIN_NAME = "train"
-SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
+VAL_NAME = "val"
+TEST_NAME = "test"
+SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
 
 logger = logging.getLogger(__name__)
@@ -48,9 +51,8 @@ class ArkindexExtractor:
 
     def __init__(
         self,
-        folders: list = [],
+        dataset_id: UUID | None = None,
         element_type: List[str] = [],
-        parent_element_type: str | None = None,
         output: Path | None = None,
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "⁇",
@@ -61,9 +63,8 @@ class ArkindexExtractor:
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
     ) -> None:
-        self.folders = folders
+        self.dataset_id = dataset_id
         self.element_type = element_type
-        self.parent_element_type = parent_element_type
         self.output = output
         self.entity_separators = entity_separators
         self.unknown_token = unknown_token
@@ -272,20 +273,24 @@ class ArkindexExtractor:
         )
 
     def run(self):
+        # Retrieve the Dataset and its splits from the cache
+        dataset = Dataset.get_by_id(self.dataset_id)
+        splits = dataset.sets.split(",")
+        assert set(splits).issubset(
+            set(SPLIT_NAMES)
+        ), f'Dataset must have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
+
         # Iterate over the subsets to find the page images and labels.
-        for folder_id, split in zip(self.folders, SPLIT_NAMES):
+        for split in splits:
             with tqdm(
-                get_elements(
-                    folder_id,
-                    [self.parent_element_type],
-                ),
-                desc=f"Extracting data from ({folder_id}) for split ({split})",
+                get_dataset_elements(dataset, split),
+                desc=f"Extracting data from ({self.dataset_id}) for split ({split})",
             ) as pbar:
                 # Iterate over the pages to create splits at page level.
                 for parent in pbar:
                     self.process_parent(
                         pbar=pbar,
-                        parent=parent,
+                        parent=parent.element,
                         split=split,
                     )
                     # Progress bar updates
@@ -303,15 +308,12 @@ class ArkindexExtractor:
 
 def run(
     database: Path,
+    dataset_id: UUID,
     element_type: List[str],
-    parent_element_type: str,
     output: Path,
     entity_separators: List[str],
     unknown_token: str,
     tokens: Path,
-    train_folder: UUID,
-    val_folder: UUID,
-    test_folder: UUID,
     transcription_worker_version: str | bool | None,
     entity_worker_version: str | bool | None,
     keep_spaces: bool,
@@ -321,15 +323,12 @@ def run(
     assert database.exists(), f"No file found @ {database}"
     open_database(path=database)
 
-    folders = [str(train_folder), str(val_folder), str(test_folder)]
-
     # Create directories
     Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
 
     ArkindexExtractor(
-        folders=folders,
+        dataset_id=dataset_id,
         element_type=element_type,
-        parent_element_type=parent_element_type,
         output=output,
         entity_separators=entity_separators,
         unknown_token=unknown_token,
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 5aeccbf0d78ae744d8d531d20adf18cc20dbcfa8..3b89902c2fb11de9e07e91d900e2dd665dc483ac 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -1,9 +1,10 @@
 # -*- coding: utf-8 -*-
-
 from typing import List
 
 from arkindex_export import Image
 from arkindex_export.models import (
+    Dataset,
+    DatasetElement,
     Element,
     Entity,
     EntityType,
@@ -13,6 +14,26 @@ from arkindex_export.models import (
 from arkindex_export.queries import list_children
 
 
+def get_dataset_elements(
+    dataset: Dataset,
+    split: str,
+):
+    """
+    Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus
+    """
+    query = (
+        DatasetElement.select(DatasetElement.element)
+        .join(Element)
+        .join(Image, on=(DatasetElement.element.image == Image.id))
+        .where(
+            DatasetElement.dataset == dataset,
+            DatasetElement.set_name == split,
+        )
+    )
+
+    return query
+
+
 def get_elements(
     parent_id: str,
     element_type: List[str],
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index de06a3b3d897acd8ed20757c746e0f5830de5b71..93a4a38c4150323ebbb9b83b8c4cb68b412e3865 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -6,7 +6,7 @@ There are a several steps to follow when training a DAN model.
 
 To extract the data, DAN uses an Arkindex export database in SQLite format. You will need to:
 
-1. Structure the data into folders (`train` / `val` / `test`) in [Arkindex](https://demo.arkindex.org/).
+1. Structure the data into splits (`train` / `val` / `test`) in a project dataset in [Arkindex](https://demo.arkindex.org/).
 1. [Export the project](https://doc.arkindex.org/howto/export/) in SQLite format.
 1. Extract the data with the [extract command](../usage/datasets/extract.md).
 1. Download images with the [download command](../usage/datasets/download.md).
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index 2ac15b6b90799675ede631ab5214e9e67fecde43..f3f5f05c4f6193adc5510dfe583a43717bbc00da 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -11,15 +11,12 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
 | Parameter | Description | Type | Default |
 | -------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- |
 | `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
+| `--dataset-id` | ID of the dataset to extract from Arkindex. | `uuid` | |
 | `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
-| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
 | `--output` | Folder where the data will be generated. | `pathlib.Path` | |
 | `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
 | `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `⁇` |
 | `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
-| `--train-folder` | ID of the training folder to extract from Arkindex. | `uuid` | |
-| `--val-folder` | ID of the validation folder to extract from Arkindex. | `uuid` | |
-| `--test-folder` | ID of the training folder to extract from Arkindex. | `uuid` | |
 | `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
 | `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | |
 | `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
@@ -61,9 +58,7 @@ To use the data from three folders as **training**, **validation** and **testing
 ```shell
 teklia-dan dataset extract \
     database.sqlite \
-    --train-folder train_folder_uuid \
-    --val-folder val_folder_uuid \
-    --test-folder test_folder_uuid \
+    --dataset-id dataset_uuid \
     --element-type page \
     --output data \
     --tokens tokens.yml
@@ -115,10 +110,7 @@ To extract HTR data from **annotations** and **text_zones** from each folder, bu
 ```shell
 teklia-dan dataset extract \
     database.sqlite \
-    --train-folder train_folder_uuid \
-    --val-folder val_folder_uuid \
-    --test-folder test_folder_uuid \
+    --dataset-id dataset_uuid \
     --element-type text_zone annotation \
-    --parent-element-type single_page \
     --output data
 ```
diff --git a/tests/conftest.py b/tests/conftest.py
index 2fe9bc4b872f76a7daffe34f36e5f57d7650ec37..2cb4c4754ff65c67a0928c9922efe322e9cc3b4e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,8 @@ from typing import List
 
 import pytest
 from arkindex_export import (
+    Dataset,
+    DatasetElement,
     Element,
     ElementPath,
     Entity,
@@ -19,6 +21,7 @@ from arkindex_export import (
     WorkerVersion,
     database,
 )
+from dan.datasets.extract.arkindex import SPLIT_NAMES
 
 from tests import FIXTURES
 
@@ -133,6 +136,8 @@ def mock_database(tmp_path_factory):
             WorkerRun,
             ImageServer,
             Image,
+            Dataset,
+            DatasetElement,
             Element,
             ElementPath,
             EntityType,
@@ -175,8 +180,26 @@ def mock_database(tmp_path_factory):
         type="worker",
     )
 
-    # Create folders
-    create_element(id="root")
+    # Create dataset
+    dataset = Dataset.create(
+        id="dataset", name="Dataset", state="complete", sets=",".join(SPLIT_NAMES)
+    )
+
+    # Create dataset elements
+    for split in SPLIT_NAMES:
+        element_path = (FIXTURES / "extraction" / "elements" / split).with_suffix(
+            ".json"
+        )
+        element_json = json.loads(element_path.read_text())
+
+        # Recursive function to create children
+        for child in element_json.get("children", []):
+            create_element(id=child)
+
+            # Linking the element to the dataset split
+            DatasetElement.create(
+                id=child, element_id=child, dataset=dataset, set_name=split
+            )
 
     # Create data for entities extraction tests
     # Create transcription
diff --git a/tests/test_db.py b/tests/test_db.py
index 60fa969bc7868959b22a4fa2aa95076b992bbf75..ac39d09ff5cbc16649707e7c323bfd19ad8e8c66 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -4,26 +4,55 @@ from operator import itemgetter
 
 import pytest
 
+from dan.datasets.extract.arkindex import TRAIN_NAME
 from dan.datasets.extract.db import (
+    Dataset,
+    DatasetElement,
     Element,
+    get_dataset_elements,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
 )
 
 
+def test_get_dataset_elements(mock_database):
+    """
+    Assert dataset elements retrieval output against verified results
+    """
+    dataset_elements = get_dataset_elements(
+        dataset=Dataset.select().get(),
+        split=TRAIN_NAME,
+    )
+
+    # ID verification
+    assert all(
+        isinstance(dataset_element, DatasetElement)
+        for dataset_element in dataset_elements
+    )
+    assert [dataset_element.element.id for dataset_element in dataset_elements] == [
+        "train-page_1",
+        "train-page_2",
+    ]
+
+
 def test_get_elements(mock_database):
     """
     Assert elements retrieval output against verified results
     """
     elements = get_elements(
-        parent_id="train",
-        element_type=["double_page"],
+        parent_id="train-page_1",
+        element_type=["text_line"],
     )
 
     # ID verification
     assert all(isinstance(element, Element) for element in elements)
-    assert [element.id for element in elements] == ["train-page_1", "train-page_2"]
+    assert [element.id for element in elements] == [
+        "train-page_1-line_1",
+        "train-page_1-line_2",
+        "train-page_1-line_3",
+        "train-page_1-line_4",
+    ]
 
 
 @pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 5f857044ea1756743e687bb7aef5a07b7a8f97cd..a5ed2cd9b08357c10217dabdb57b22030b542d6a 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -253,9 +253,8 @@ def test_extract(
     ]
 
     extractor = ArkindexExtractor(
-        folders=["train", "val", "test"],
+        dataset_id="dataset",
         element_type=["text_line"],
-        parent_element_type="double_page",
        output=output,
         # Keep the whole text
         entity_separators=None,
@@ -419,9 +418,7 @@ def test_extract(
 @pytest.mark.parametrize("allow_empty", (True, False))
 def test_empty_transcription(allow_empty, mock_database):
     extractor = ArkindexExtractor(
-        folders=["train", "val", "test"],
         element_type=["text_line"],
-        parent_element_type="double_page",
         output=None,
         entity_separators=None,
         tokens=None,
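
For context, a minimal sketch (not part of the diff) of how the reworked split handling introduced above is expected to be driven. The database path and dataset UUID are placeholders; `open_database`, `Dataset`, `get_dataset_elements` and the comma-separated `sets` column are used exactly as in the changes above.

```python
from pathlib import Path

from arkindex_export import Dataset, open_database

from dan.datasets.extract.db import get_dataset_elements

# Placeholder inputs: point these at a real SQLite export and dataset UUID.
database = Path("database.sqlite")
dataset_id = "11111111-1111-1111-1111-111111111111"

open_database(path=database)

# A Dataset row stores its split names as a comma-separated string, e.g. "train,val,test".
dataset = Dataset.get_by_id(dataset_id)
for split in dataset.sets.split(","):
    # get_dataset_elements() returns DatasetElement rows; the Arkindex element is on `.element`.
    elements = [item.element for item in get_dataset_elements(dataset, split)]
    print(f"{split}: {len(elements)} elements")
```

This mirrors what `ArkindexExtractor.run()` now does internally, minus the transcription and image handling.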