From 40b3099243d51035ed18cc045dcc5c31035da10a Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Wed, 18 Oct 2023 12:21:25 +0200
Subject: [PATCH] New DatasetExtractor using a DatasetWorker

---
 setup.py                                      |   7 +-
 .../dataset_worker.py                         | 458 ++++++++++++++++++
 2 files changed, 464 insertions(+), 1 deletion(-)
 create mode 100644 worker_generic_training_dataset/dataset_worker.py

diff --git a/setup.py b/setup.py
index 81f4440..4f4d657 100755
--- a/setup.py
+++ b/setup.py
@@ -47,6 +47,11 @@ setup(
     author="Teklia",
     author_email="contact@teklia.com",
     install_requires=parse_requirements(),
-    entry_points={"console_scripts": [f"{COMMAND}={MODULE}.worker:main"]},
+    entry_points={
+        "console_scripts": [
+            f"{COMMAND}={MODULE}.worker:main",
+            "worker-generic-training-dataset-new=worker_generic_training_dataset.dataset_worker:main",
+        ]
+    },
     packages=find_packages(),
 )
diff --git a/worker_generic_training_dataset/dataset_worker.py b/worker_generic_training_dataset/dataset_worker.py
new file mode 100644
index 0000000..a82c464
--- /dev/null
+++ b/worker_generic_training_dataset/dataset_worker.py
@@ -0,0 +1,458 @@
+# -*- coding: utf-8 -*-
+import logging
+import sys
+import tempfile
+from argparse import Namespace
+from itertools import groupby
+from operator import itemgetter
+from pathlib import Path
+from tempfile import _TemporaryFileWrapper
+from typing import List, Optional, Tuple, Union
+from uuid import UUID
+
+from apistar.exceptions import ErrorResponse
+from arkindex_export import Element, open_database
+from arkindex_export.queries import list_children
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+    create_tables,
+    create_version_table,
+)
+from arkindex_worker.cache import db as cache_database
+from arkindex_worker.cache import init_cache_db
+from arkindex_worker.image import download_image
+from arkindex_worker.models import Dataset
+from arkindex_worker.utils import create_tar_zst_archive
+from arkindex_worker.worker.base import BaseWorker
+from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
+from worker_generic_training_dataset.db import (
+    list_classifications,
+    list_transcription_entities,
+    list_transcriptions,
+)
+from worker_generic_training_dataset.utils import build_image_url
+from worker_generic_training_dataset.worker import (
+    BULK_BATCH_SIZE,
+    DEFAULT_TRANSCRIPTION_ORIENTATION,
+)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class DatasetWorker(BaseWorker, DatasetMixin):
+    def __init__(
+        self,
+        description: str = "Arkindex Dataset Worker",
+        support_cache: bool = False,
+        generator: bool = False,
+    ):
+        super().__init__(description, support_cache)
+
+        self.parser.add_argument(
+            "--dataset",
+            type=UUID,
+            nargs="+",
+            help="One or more Arkindex dataset ID",
+        )
+
+        self.generator = generator
+
+    def list_dataset_elements_per_set(
+        self, dataset: Dataset
+    ) -> List[Tuple[str, List[Element]]]:
+        """
+        Calls `list_dataset_elements` and returns its results grouped by set name.
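+
+        For example (shape only, the actual set names depend on the dataset):
+            [
+                ("train", [<Element>, ...]),
+                ("validation", [<Element>, ...]),
+            ]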
+        """
+
+        def format_element(element):
+            return Element.get(Element.id == element[1].id)
+
+        def format_set(grouped_set):
+            return (grouped_set[0], list(map(format_element, grouped_set[1])))
+
+        return list(
+            map(
+                format_set,
+                groupby(
+                    sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
+                    key=itemgetter(0),
+                ),
+            )
+        )
+
+    def process_dataset(self, dataset: Dataset):
+        """
+        Override this method to implement your worker and process a single Arkindex dataset at a time.
+
+        :param dataset: The dataset to process.
+        """
+
+    def list_datasets(self) -> Union[List[Dataset], List[str]]:
+        """
+        Calls `list_process_datasets` if not is_read_only,
+        otherwise simply returns the list of dataset IDs provided via the CLI.
+        """
+        if self.is_read_only:
+            return list(map(str, self.args.dataset))
+
+        return self.list_process_datasets()
+
+    def run(self):
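+        # Dataset lifecycle: when running as a generator, each dataset must be
+        # in the Open state, is moved to Building while being processed, then
+        # to Complete on success (or Error on failure). When processing an
+        # existing dataset, it must already be in the Complete state.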
+        self.configure()
+
+        datasets: Union[List[Dataset], List[str]] = self.list_datasets()
+        if not datasets:
+            logger.warning("No datasets to process, stopping.")
+            sys.exit(1)
+
+        # Process every dataset
+        count = len(datasets)
+        failed = 0
+        for i, item in enumerate(datasets, start=1):
+            dataset = None
+            try:
+                if not self.is_read_only:
+                    # Just use the result of list_datasets as the dataset
+                    dataset = item
+                else:
+                    # Load dataset using the Arkindex API
+                    dataset = Dataset(**self.request("RetrieveDataset", id=item))
+
+                if self.generator:
+                    assert (
+                        dataset.state == DatasetState.Open.value
+                    ), "When generating a new dataset, its state should be Open"
+                else:
+                    assert (
+                        dataset.state == DatasetState.Complete.value
+                    ), "When processing an existing dataset, its state should be Complete"
+
+                if self.generator:
+                    # Update the dataset state to Building
+                    logger.info(f"Building {dataset} ({i}/{count})")
+                    self.update_dataset_state(dataset, DatasetState.Building)
+
+                # Process the dataset
+                self.process_dataset(dataset)
+
+                if self.generator:
+                    # Update the dataset state to Complete
+                    logger.info(f"Completed {dataset} ({i}/{count})")
+                    self.update_dataset_state(dataset, DatasetState.Complete)
+            except Exception as e:
+                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
+                failed += 1
+
+                # Handle the case where we failed to retrieve the dataset
+                dataset_id = dataset.id if dataset else item
+
+                if isinstance(e, ErrorResponse):
+                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
+                else:
+                    message = (
+                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
+                    )
+
+                logger.warning(
+                    message,
+                    exc_info=e if self.args.verbose else None,
+                )
+                if dataset and self.generator:
+                    # Try to update the state to Error regardless of the response
+                    try:
+                        self.update_dataset_state(dataset, DatasetState.Error)
+                    except Exception:
+                        pass
+
+        if failed:
+            logger.error(
+                f"Ran on {count} dataset(s): {count - failed} completed, {failed} failed"
+            )
+            if failed >= count:  # Everything failed!
+                sys.exit(1)
+
+
+class DatasetExtractor(DatasetWorker):
+    def configure(self) -> None:
+        self.args: Namespace = self.parser.parse_args()
+        if self.is_read_only:
+            super().configure_for_developers()
+        else:
+            super().configure()
+
+        if self.user_configuration:
+            logger.info("Overriding with user_configuration")
+            self.config.update(self.user_configuration)
+
+        # Download the latest export of the corpus
+        self.download_latest_export()
+
+        # Initialize the cache database that will be filled
+        self.configure_cache()
+
+        # CachedImage rows already created in the DB, indexed by image ID
+        self.cached_images = dict()
+
+        # Where to save the downloaded images
+        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
+        logger.info(f"Images will be saved at `{self.image_folder}`.")
+
+    def configure_cache(self) -> None:
+        """
+        Create an SQLite database compatible with the base-worker cache and initialize its tables.
+        """
+        self.use_cache = True
+        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
+        # Remove previous execution result if present
+        self.cache_path.unlink(missing_ok=True)
+
+        init_cache_db(self.cache_path)
+
+        create_version_table()
+
+        create_tables()
+
+    def download_latest_export(self) -> None:
+        """
+        Download the latest export of the current corpus.
+        The export must be in the `"done"` state.
+        """
+        try:
+            exports = list(
+                self.api_client.paginate(
+                    "ListExports",
+                    id=self.corpus_id,
+                )
+            )
+        except ErrorResponse as e:
+            logger.error(
+                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
+            )
+            raise e
+
+        # Keep only the exports in "done" state, most recent first
+        exports: List[dict] = sorted(
+            filter(lambda exp: exp["state"] == "done", exports),
+            key=itemgetter("updated"),
+            reverse=True,
+        )
+        assert (
+            len(exports) > 0
+        ), f"No available exports found for the corpus {self.corpus_id}."
+
+        # Download latest export
+        try:
+            export_id: str = exports[0]["id"]
+            logger.info(f"Downloading export ({export_id})...")
+            self.export: _TemporaryFileWrapper = self.api_client.request(
+                "DownloadExport",
+                id=export_id,
+            )
+            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
+            open_database(self.export.name)
+        except ErrorResponse as e:
+            logger.error(
+                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
+            )
+            raise e
+
+    def insert_classifications(self, element: CachedElement) -> None:
+        logger.info("Listing classifications")
+        classifications: List[CachedClassification] = [
+            CachedClassification(
+                id=classification.id,
+                element=element,
+                class_name=classification.class_name,
+                confidence=classification.confidence,
+                state=classification.state,
+                worker_run_id=classification.worker_run,
+            )
+            for classification in list_classifications(element.id)
+        ]
+        if classifications:
+            logger.info(f"Inserting {len(classifications)} classification(s)")
+            with cache_database.atomic():
+                CachedClassification.bulk_create(
+                    model_list=classifications,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_transcriptions(
+        self, element: CachedElement
+    ) -> List[CachedTranscription]:
+        logger.info("Listing transcriptions")
+        transcriptions: List[CachedTranscription] = [
+            CachedTranscription(
+                id=transcription.id,
+                element=element,
+                text=transcription.text,
+                confidence=transcription.confidence,
+                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+                worker_version_id=transcription.worker_version,
+                worker_run_id=transcription.worker_run,
+            )
+            for transcription in list_transcriptions(element.id)
+        ]
+        if transcriptions:
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
+            with cache_database.atomic():
+                CachedTranscription.bulk_create(
+                    model_list=transcriptions,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+        return transcriptions
+
+    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
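+        """
+        Insert the entities attached to the given transcriptions in the cache database,
+        as both CachedEntity and CachedTranscriptionEntity rows.
+        """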
+        logger.info("Listing entities")
+        entities: List[CachedEntity] = []
+        transcription_entities: List[CachedTranscriptionEntity] = []
+        for transcription in transcriptions:
+            for transcription_entity in list_transcription_entities(transcription.id):
+                entity = CachedEntity(
+                    id=transcription_entity.entity.id,
+                    type=transcription_entity.entity.type.name,
+                    name=transcription_entity.entity.name,
+                    validated=transcription_entity.entity.validated,
+                    metas=transcription_entity.entity.metas,
+                    worker_run_id=transcription_entity.entity.worker_run,
+                )
+                entities.append(entity)
+                transcription_entities.append(
+                    CachedTranscriptionEntity(
+                        id=transcription_entity.id,
+                        transcription=transcription,
+                        entity=entity,
+                        offset=transcription_entity.offset,
+                        length=transcription_entity.length,
+                        confidence=transcription_entity.confidence,
+                        worker_run_id=transcription_entity.worker_run,
+                    )
+                )
+        if entities:
+            # Insert entities first since transcription entities reference them through foreign keys
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        if transcription_entities:
+            # Insert transcription entities
+            logger.info(
+                f"Inserting {len(transcription_entities)} transcription entities"
+            )
+            with cache_database.atomic():
+                CachedTranscriptionEntity.bulk_create(
+                    model_list=transcription_entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_element(
+        self, element: Element, parent_id: Optional[UUID] = None
+    ) -> None:
+        """
+        Insert the given element in the cache database.
+        Its image will also be saved to disk, if it wasn't already.
+
+        The insertion of an element includes:
+        - its classifications
+        - its transcriptions
+        - its transcriptions' entities (both Entity and TranscriptionEntity)
+
+        :param element: Element to insert.
+        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
+        """
+        logger.info(f"Processing element ({element.id})")
+        if element.image and element.image.id not in self.cached_images:
+            # Download image
+            logger.info("Downloading image")
+            download_image(url=build_image_url(element)).save(
+                self.image_folder / f"{element.image.id}.jpg"
+            )
+            # Insert image
+            logger.info("Inserting image")
+            # Store the image in case other elements use it as well
+            with cache_database.atomic():
+                self.cached_images[element.image.id] = CachedImage.create(
+                    id=element.image.id,
+                    width=element.image.width,
+                    height=element.image.height,
+                    url=element.image.url,
+                )
+
+        # Insert element
+        logger.info("Inserting element")
+        with cache_database.atomic():
+            cached_element: CachedElement = CachedElement.create(
+                id=element.id,
+                parent_id=parent_id,
+                type=element.type,
+                image=self.cached_images[element.image.id] if element.image else None,
+                polygon=element.polygon,
+                rotation_angle=element.rotation_angle,
+                mirrored=element.mirrored,
+                worker_version_id=element.worker_version,
+                worker_run_id=element.worker_run,
+                confidence=element.confidence,
+            )
+
+        # Insert classifications
+        self.insert_classifications(cached_element)
+
+        # Insert transcriptions
+        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
+            cached_element
+        )
+
+        # Insert entities
+        self.insert_entities(transcriptions)
+
+    def process_set(self, set_name: str, elements: List[Element]) -> None:
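+        """
+        Insert all elements of the given set, along with their children, into the cache database.
+
+        :param set_name: Name of the dataset set being processed.
+        :param elements: Elements belonging to this set.
+        """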
+        logger.info(
+            f"Filling the cache with information from elements in the set {set_name}"
+        )
+
+        # Insert each element of the set, followed by its children
+        nb_elements: int = len(elements)
+        for idx, element in enumerate(elements, start=1):
+            logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")
+
+            # Insert the element itself
+            self.insert_element(element)
+
+            # List children
+            children = list_children(element.id)
+            nb_children: int = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, parent_id=element.id)
+
+    def process_dataset(self, dataset: Dataset):
+        # Iterate over given sets
+        for set_name, elements in self.list_dataset_elements_per_set(dataset):
+            self.process_set(set_name, elements)
+
+        # Compress the image folder to a tar+zstd archive stored as a task artifact
+        zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
+
+
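+# Entry point for the `worker-generic-training-dataset-new` console script
+# declared in setup.py. A local developer run could look like the following
+# (only --dataset is defined here; any other flags come from BaseWorker):
+#   worker-generic-training-dataset-new --dataset <dataset_uuid> [<dataset_uuid> ...]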
+def main():
+    DatasetExtractor(
+        description="Fill base-worker cache with information about dataset and extract images",
+        generator=True,
+    ).run()
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab