From 137d8b438e1d069a955fca2be814ede2d1b57993 Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Fri, 15 Dec 2023 10:17:21 +0000
Subject: [PATCH] Support multiple datasets from Arkindex as input

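The `--dataset-id` argument now accepts one or more dataset UUIDs
(`nargs="+"`, exposed to the extractor as `dataset_ids`). The extractor
iterates over every requested dataset, warns and skips any dataset whose
sets are not a subset of the train/val/test splits, and records the
originating dataset ID for each element in `split.json`. The image
downloader reads that ID and saves images under
`images/<split>/<dataset_id>/<element_id>` so that elements coming from
different datasets cannot overwrite each other.

Example invocation (a sketch only: the `teklia-dan dataset extract`
entrypoint name and the placeholder UUIDs are assumptions, and other
required options are omitted):

    # placeholder UUIDs; pass as many dataset IDs as needed
    teklia-dan dataset extract \
        --dataset-id <first_dataset_uuid> <second_dataset_uuid> \
        --element-type text_line \
        --output data/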
---
 dan/datasets/download/images.py  | 15 ++++---
 dan/datasets/extract/__init__.py |  2 +
 dan/datasets/extract/arkindex.py | 75 ++++++++++++++++----------------
 dan/datasets/extract/db.py       |  2 +-
 docs/usage/datasets/download.md  |  3 +-
 tests/conftest.py                |  7 ++-
 tests/data/extraction/split.json | 16 +++++++
 tests/test_db.py                 |  4 +-
 tests/test_download.py           |  6 +--
 tests/test_extract.py            | 31 ++++++-------
 10 files changed, 91 insertions(+), 70 deletions(-)

diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py
index b492377f..c1702e51 100644
--- a/dan/datasets/download/images.py
+++ b/dan/datasets/download/images.py
@@ -62,11 +62,15 @@ class ImageDownloader:
         self.data: Dict = defaultdict(dict)
 
     def check_extraction(self, values: dict) -> str | None:
+        # Check dataset_id parameter
+        if values.get("dataset_id") is None:
+            return "Dataset ID not found"
+
         # Check image parameters
         if not (image := values.get("image")):
             return "Image information not found"
 
-        # Only support `iiif_url` with `polygon` for now
+        # Only support iiif_url with polygon for now
         if not image.get("iiif_url"):
             return "Image IIIF URL not found"
         if not image.get("polygon"):
@@ -113,15 +117,16 @@ class ImageDownloader:
             destination.mkdir(parents=True, exist_ok=True)
 
             for element_id, values in items.items():
-                image_path = (destination / element_id).with_suffix(
-                    self.image_extension
-                )
+                filename = Path(element_id).with_suffix(self.image_extension)
 
                 error = self.check_extraction(values)
                 if error:
-                    logger.warning(f"{image_path}: {error}")
+                    logger.warning(f"{destination / filename}: {error}")
                     continue
 
+                image_path = destination / values["dataset_id"] / filename
+                image_path.parent.mkdir(parents=True, exist_ok=True)
+
                 self.data[split][str(image_path)] = values["text"]
 
                 # Create task for multithreading pool if image does not exist yet
diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 3e05ba0d..278ee892 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -48,9 +48,11 @@ def add_extract_parser(subcommands) -> None:
     )
     parser.add_argument(
         "--dataset-id",
+        nargs="+",
         type=UUID,
         help="ID of the dataset to extract from Arkindex.",
         required=True,
+        dest="dataset_ids",
     )
     parser.add_argument(
         "--element-type",
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index 86f2f3ba..8e7be4c4 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -11,9 +11,8 @@ from uuid import UUID
 
 from tqdm import tqdm
 
-from arkindex_export import Dataset, open_database
+from arkindex_export import Dataset, DatasetElement, Element, open_database
 from dan.datasets.extract.db import (
-    Element,
     get_dataset_elements,
     get_elements,
     get_transcription_entities,
@@ -51,7 +50,7 @@ class ArkindexExtractor:
 
     def __init__(
         self,
-        dataset_id: UUID | None = None,
+        dataset_ids: List[UUID] | None = None,
         element_type: List[str] = [],
         output: Path | None = None,
         entity_separators: List[str] = ["\n", " "],
@@ -63,7 +62,7 @@ class ArkindexExtractor:
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
     ) -> None:
-        self.dataset_id = dataset_id
+        self.dataset_ids = dataset_ids
         self.element_type = element_type
         self.output = output
         self.entity_separators = entity_separators
@@ -139,7 +138,7 @@ class ArkindexExtractor:
             )
         return text.strip()
 
-    def process_element(self, element: Element, split: str):
+    def process_element(self, dataset_parent: DatasetElement, element: Element):
         """
         Extract an element's data and save it to disk.
         The output path is directly related to the split of the element.
@@ -152,10 +151,11 @@ class ArkindexExtractor:
         text = self.format_text(
             text,
             # Do not replace unknown characters in train split
-            charset=self.charset if split != TRAIN_NAME else None,
+            charset=self.charset if dataset_parent.set_name != TRAIN_NAME else None,
         )
 
-        self.data[split][element.id] = {
+        self.data[dataset_parent.set_name][element.id] = {
+            "dataset_id": dataset_parent.dataset_id,
             "text": text,
             "image": {
                 "iiif_url": element.image.url,
@@ -165,17 +165,16 @@ class ArkindexExtractor:
 
         self.charset = self.charset.union(set(text))
 
-    def process_parent(self, pbar, parent: Element, split: str):
+    def process_parent(self, pbar, dataset_parent: DatasetElement):
         """
         Extract data from a parent element.
         """
-        base_description = (
-            f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
-        )
+        parent = dataset_parent.element
+        base_description = f"Extracting data from {parent.type} ({parent.id}) for split ({dataset_parent.set_name})"
         pbar.set_description(desc=base_description)
         if self.element_type == [parent.type]:
             try:
-                self.process_element(parent, split)
+                self.process_element(dataset_parent, parent)
             except ProcessingError as e:
                 logger.warning(f"Skipping {parent.id}: {str(e)}")
         # Extract children elements
@@ -190,7 +189,7 @@ class ArkindexExtractor:
                 # Update description to update the children processing progress
                 pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
                 try:
-                    self.process_element(element, split)
+                    self.process_element(dataset_parent, element)
                 except ProcessingError as e:
                     logger.warning(f"Skipping {element.id}: {str(e)}")
 
@@ -274,28 +273,30 @@ class ArkindexExtractor:
 
     def run(self):
         # Retrieve the Dataset and its splits from the cache
-        dataset = Dataset.get_by_id(self.dataset_id)
-        splits = dataset.sets.split(",")
-        assert set(splits).issubset(
-            set(SPLIT_NAMES)
-        ), f'Dataset must have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
-
-        # Iterate over the subsets to find the page images and labels.
-        for split in splits:
-            with tqdm(
-                get_dataset_elements(dataset, split),
-                desc=f"Extracting data from ({self.dataset_id}) for split ({split})",
-            ) as pbar:
-                # Iterate over the pages to create splits at page level.
-                for parent in pbar:
-                    self.process_parent(
-                        pbar=pbar,
-                        parent=parent.element,
-                        split=split,
-                    )
-                    # Progress bar updates
-                    pbar.update()
-                    pbar.refresh()
+        for dataset_id in self.dataset_ids:
+            dataset = Dataset.get_by_id(dataset_id)
+            splits = dataset.sets.split(",")
+            if not set(splits).issubset(set(SPLIT_NAMES)):
+                logger.warning(
+                    f'Dataset {dataset.name} ({dataset.id}) does not have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
+                )
+                continue
+
+            # Iterate over the subsets to find the page images and labels.
+            for split in splits:
+                with tqdm(
+                    get_dataset_elements(dataset, split),
+                    desc=f"Extracting data from ({dataset_id}) for split ({split})",
+                ) as pbar:
+                    # Iterate over the pages to create splits at page level.
+                    for parent in pbar:
+                        self.process_parent(
+                            pbar=pbar,
+                            dataset_parent=parent,
+                        )
+                        # Progress bar updates
+                        pbar.update()
+                        pbar.refresh()
 
         if not self.data:
             raise Exception(
@@ -308,7 +309,7 @@ class ArkindexExtractor:
 
 def run(
     database: Path,
-    dataset_id: UUID,
+    dataset_ids: List[UUID],
     element_type: List[str],
     output: Path,
     entity_separators: List[str],
@@ -327,7 +328,7 @@ def run(
     Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
 
     ArkindexExtractor(
-        dataset_id=dataset_id,
+        dataset_ids=dataset_ids,
         element_type=element_type,
         output=output,
         entity_separators=entity_separators,
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 3b89902c..25146799 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -22,7 +22,7 @@ def get_dataset_elements(
     Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus
     """
     query = (
-        DatasetElement.select(DatasetElement.element)
+        DatasetElement.select()
         .join(Element)
         .join(Image, on=(DatasetElement.element.image == Image.id))
         .where(
diff --git a/docs/usage/datasets/download.md b/docs/usage/datasets/download.md
index 77f3c7f9..221f6108 100644
--- a/docs/usage/datasets/download.md
+++ b/docs/usage/datasets/download.md
@@ -22,6 +22,7 @@ The `--output` directory should have a `split.json` JSON-formatted file with a s
 {
   "train": {
     "<element_id>": {
+      "dataset_id": "<dataset_id>",
       "image": {
         "iiif_url": "https://<iiif_server>/iiif/2/<path>",
         "polygon": [
@@ -32,7 +33,7 @@ The `--output` directory should have a `split.json` JSON-formatted file with a s
           [37, 191]
         ]
       },
-      "text": "ⓢCou⁇e⁇  ⓕBouis  ⓑ⁇.12.14"
+      "text": "â“¢Coufet â“•Bouis â“‘07.12.14"
     },
   },
   "val": {},
diff --git a/tests/conftest.py b/tests/conftest.py
index 2cb4c475..4ad5b89c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,7 +25,7 @@ from dan.datasets.extract.arkindex import SPLIT_NAMES
 from tests import FIXTURES
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture()
 def mock_database(tmp_path_factory):
     def create_transcription_entity(
         transcription: Transcription,
@@ -182,7 +182,10 @@ def mock_database(tmp_path_factory):
 
     # Create dataset
     dataset = Dataset.create(
-        id="dataset", name="Dataset", state="complete", sets=",".join(SPLIT_NAMES)
+        id="dataset_id",
+        name="Dataset",
+        state="complete",
+        sets=",".join(SPLIT_NAMES),
     )
 
     # Create dataset elements
diff --git a/tests/data/extraction/split.json b/tests/data/extraction/split.json
index 36a6ce49..e264f689 100644
--- a/tests/data/extraction/split.json
+++ b/tests/data/extraction/split.json
@@ -1,6 +1,7 @@
 {
     "test": {
         "test-page_1-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_1.jpg",
                 "polygon": [
@@ -29,6 +30,7 @@
             "text": "ⓢCou⁇e⁇  ⓕBouis  ⓑ⁇.12.14"
         },
         "test-page_1-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_2.jpg",
                 "polygon": [
@@ -57,6 +59,7 @@
             "text": "ⓢ⁇outrain  ⓕA⁇ol⁇⁇e  ⓑ9.4.13"
         },
         "test-page_1-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_3.jpg",
                 "polygon": [
@@ -85,6 +88,7 @@
             "text": "ⓢ⁇abale  ⓕ⁇ran⁇ais  ⓑ26.3.11"
         },
         "test-page_2-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_1.jpg",
                 "polygon": [
@@ -113,6 +117,7 @@
             "text": "ⓢ⁇urosoy  ⓕBouis  ⓑ22⁇4⁇18"
         },
         "test-page_2-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_2.jpg",
                 "polygon": [
@@ -141,6 +146,7 @@
             "text": "ⓢColaiani  ⓕAn⁇els  ⓑ28.11.1⁇"
         },
         "test-page_2-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_3.jpg",
                 "polygon": [
@@ -171,6 +177,7 @@
     },
     "train": {
         "train-page_1-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_1.jpg",
                 "polygon": [
@@ -199,6 +206,7 @@
             "text": "â“¢Caillet  â“•Maurice  â“‘28.9.06"
         },
         "train-page_1-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_2.jpg",
                 "polygon": [
@@ -227,6 +235,7 @@
             "text": "â“¢Reboul  â“•Jean  â“‘30.9.02"
         },
         "train-page_1-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_3.jpg",
                 "polygon": [
@@ -255,6 +264,7 @@
             "text": "â“¢Bareyre  â“•Jean  â“‘28.3.11"
         },
         "train-page_1-line_4": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_4.jpg",
                 "polygon": [
@@ -283,6 +293,7 @@
             "text": "â“¢Roussy  â“•Jean  â“‘4.11.14"
         },
         "train-page_2-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_1.jpg",
                 "polygon": [
@@ -311,6 +322,7 @@
             "text": "â“¢Marin  â“•Marcel  â“‘10.8.06"
         },
         "train-page_2-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_2.jpg",
                 "polygon": [
@@ -339,6 +351,7 @@
             "text": "â“¢Amical  â“•Eloi  â“‘11.10.04"
         },
         "train-page_2-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_3.jpg",
                 "polygon": [
@@ -369,6 +382,7 @@
     },
     "val": {
         "val-page_1-line_1": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_1.jpg",
                 "polygon": [
@@ -397,6 +411,7 @@
             "text": "ⓢMonar⁇  ⓕBouis  ⓑ29⁇⁇⁇04"
         },
         "val-page_1-line_2": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_2.jpg",
                 "polygon": [
@@ -425,6 +440,7 @@
             "text": "ⓢAstier  ⓕArt⁇ur  ⓑ11⁇2⁇13"
         },
         "val-page_1-line_3": {
+            "dataset_id": "dataset_id",
             "image": {
                 "iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_3.jpg",
                 "polygon": [
diff --git a/tests/test_db.py b/tests/test_db.py
index ac39d09f..0da81d8e 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -4,11 +4,9 @@ from operator import itemgetter
 
 import pytest
 
+from arkindex_export import Dataset, DatasetElement, Element
 from dan.datasets.extract.arkindex import TRAIN_NAME
 from dan.datasets.extract.db import (
-    Dataset,
-    DatasetElement,
-    Element,
     get_dataset_elements,
     get_elements,
     get_transcription_entities,
diff --git a/tests/test_download.py b/tests/test_download.py
index d5348ec3..174e69a4 100644
--- a/tests/test_download.py
+++ b/tests/test_download.py
@@ -58,9 +58,9 @@ def test_download(split_content, monkeypatch, tmp_path):
 
     # Check files
     IMAGE_DIR = output / "images"
-    TEST_DIR = IMAGE_DIR / "test"
-    TRAIN_DIR = IMAGE_DIR / "train"
-    VAL_DIR = IMAGE_DIR / "val"
+    TEST_DIR = IMAGE_DIR / "test" / "dataset_id"
+    TRAIN_DIR = IMAGE_DIR / "train" / "dataset_id"
+    VAL_DIR = IMAGE_DIR / "val" / "dataset_id"
 
     expected_paths = [
         # Images of test folder
diff --git a/tests/test_extract.py b/tests/test_extract.py
index a5ed2cd9..46721185 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -8,7 +8,12 @@ from typing import NamedTuple
 
 import pytest
 
-from arkindex_export import Element, Transcription, TranscriptionEntity
+from arkindex_export import (
+    DatasetElement,
+    Element,
+    Transcription,
+    TranscriptionEntity,
+)
 from dan.datasets.extract.arkindex import ArkindexExtractor
 from dan.datasets.extract.db import get_transcription_entities
 from dan.datasets.extract.exceptions import (
@@ -85,28 +90,18 @@ def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
     output = tmp_path / "extraction"
     arkindex_extractor = ArkindexExtractor(output=output)
 
-    # Create an element with an invalid transcription
-    element = Element.create(
-        id="element_id",
-        name="1",
-        type="page",
-        polygon="[]",
-        created=0.0,
-        updated=0.0,
-    )
-    Transcription.create(
-        id="transcription_id",
-        text="Is this text valid⁇",
-        element=element,
-    )
+    # Retrieve a dataset element and update its transcription with an invalid one
+    dataset_element = DatasetElement.select().first()
+    element = dataset_element.element
+    Transcription.update({Transcription.text: "Is this text valid⁇"}).execute()
 
     with pytest.raises(
         UnknownTokenInText,
         match=re.escape(
-            "Unknown token found in the transcription text of element (element_id)"
+            f"Unknown token found in the transcription text of element ({element.id})"
         ),
     ):
-        arkindex_extractor.process_element(element, "val")
+        arkindex_extractor.process_element(dataset_element, element)
 
 
 @pytest.mark.parametrize(
@@ -253,7 +248,7 @@ def test_extract(
     ]
 
     extractor = ArkindexExtractor(
-        dataset_id="dataset",
+        dataset_ids=["dataset_id"],
         element_type=["text_line"],
         output=output,
         # Keep the whole text
-- 
GitLab