From ec6c64fa18e90d8a203b7a991eafd1b30cf7a89d Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 31 Mar 2023 17:31:02 +0200
Subject: [PATCH] first working version

---
 requirements.txt                              |   3 +
 worker_generic_training_dataset/db.py         |  89 +++++++++++
 worker_generic_training_dataset/exceptions.py |  36 +++++
 worker_generic_training_dataset/utils.py      |  55 +++++++
 worker_generic_training_dataset/worker.py     | 140 ++++++++++++++++--
 5 files changed, 310 insertions(+), 13 deletions(-)
 create mode 100644 worker_generic_training_dataset/db.py
 create mode 100644 worker_generic_training_dataset/exceptions.py
 create mode 100644 worker_generic_training_dataset/utils.py

diff --git a/requirements.txt b/requirements.txt
index 5ff1be6..5611000 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,4 @@
 arkindex-base-worker==0.3.2
+arkindex-export==0.1.2
+imageio==2.27.0
+opencv-python==4.7.0.72
diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
new file mode 100644
index 0000000..c4a38ea
--- /dev/null
+++ b/worker_generic_training_dataset/db.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+from typing import NamedTuple
+
+from arkindex_export import Classification
+from arkindex_export.models import (
+    Element,
+    Entity,
+    EntityType,
+    Transcription,
+    TranscriptionEntity,
+)
+from arkindex_worker.cache import (
+    CachedElement,
+    CachedEntity,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
+
+DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
+
+
+def retrieve_element(element_id: str):
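+    """Retrieve a single element from the Arkindex export database."""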
+    return Element.get_by_id(element_id)
+
+
+def list_classifications(element: Element):
+    return Classification.select().where(Classification.element == element)
+
+
+def parse_transcription(transcription: NamedTuple, element: CachedElement):
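+    """
+    Build a CachedTranscription from an export transcription.
+    The orientation is not read from the export, so a default value is used.
+    """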
+    return CachedTranscription(
+        id=transcription.id,
+        element=element,
+        text=transcription.text,
+        confidence=transcription.confidence,
+        orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+        worker_version_id=transcription.worker_version.id
+        if transcription.worker_version
+        else None,
+    )
+
+
+def list_transcriptions(element: CachedElement):
+    query = Transcription.select().where(Transcription.element_id == element.id)
+    return [parse_transcription(x, element) for x in query]
+
+
+def parse_entities(data: NamedTuple, transcription: CachedTranscription):
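+    """
+    Build a CachedEntity and the CachedTranscriptionEntity linking it to the
+    given transcription, from a single row of the entity query.
+    """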
+    entity = CachedEntity(
+        id=data.entity_id,
+        type=data.type,
+        name=data.name,
+        validated=data.validated,
+        metas=data.metas,
+    )
+    return entity, CachedTranscriptionEntity(
+        id=data.transcription_entity_id,
+        transcription=transcription,
+        entity=entity,
+        offset=data.offset,
+        length=data.length,
+        confidence=data.confidence,
+    )
+
+
+def retrieve_entities(transcription: CachedTranscription):
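+    """
+    List the entities linked to a transcription and return two parallel
+    collections: the CachedEntity objects and the CachedTranscriptionEntity
+    rows linking them to the transcription.
+    """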
+    query = (
+        TranscriptionEntity.select(
+            TranscriptionEntity.id.alias("transcription_entity_id"),
+            TranscriptionEntity.length.alias("length"),
+            TranscriptionEntity.offset.alias("offset"),
+            TranscriptionEntity.confidence.alias("confidence"),
+            Entity.id.alias("entity_id"),
+            EntityType.name.alias("type"),
+            Entity.name,
+            Entity.validated,
+            Entity.metas,
+        )
+        .where(TranscriptionEntity.transcription_id == transcription.id)
+        .join(Entity, on=TranscriptionEntity.entity)
+        .join(EntityType, on=Entity.type)
+    )
+    results = [
+        parse_entities(entity_data, transcription)
+        for entity_data in query.namedtuples()
+    ]
+    # Guard against transcriptions without entities: zip(*[]) cannot be unpacked
+    return zip(*results) if results else ([], [])
diff --git a/worker_generic_training_dataset/exceptions.py b/worker_generic_training_dataset/exceptions.py
new file mode 100644
index 0000000..062d580
--- /dev/null
+++ b/worker_generic_training_dataset/exceptions.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+
+class ElementProcessingError(Exception):
+    """
+    Raised when a problem is encountered while processing an element
+    """
+
+    element_id: str
+    """
+    ID of the element being processed.
+    """
+
+    def __init__(self, element_id: str, *args: object) -> None:
+        super().__init__(*args)
+        self.element_id = element_id
+
+
+class ImageDownloadError(ElementProcessingError):
+    """
+    Raised when an element's image could not be downloaded
+    """
+
+    error: Exception
+    """
+    Error encountered.
+    """
+
+    def __init__(self, element_id: str, error: Exception, *args: object) -> None:
+        super().__init__(element_id, *args)
+        self.error = error
+
+    def __str__(self) -> str:
+        return (
+            f"Couldn't retrieve image of element ({self.element_id}): {str(self.error)}"
+        )
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
new file mode 100644
index 0000000..151f903
--- /dev/null
+++ b/worker_generic_training_dataset/utils.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+import ast
+import logging
+import time
+from pathlib import Path
+from urllib.parse import urljoin
+
+import cv2
+import imageio.v2 as iio
+from arkindex_export.models import Element
+from worker_generic_training_dataset.exceptions import ImageDownloadError
+
+logger = logging.getLogger(__name__)
+MAX_RETRIES = 5
+
+
+def bounding_box(polygon: list):
+    """
+    Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
+    """
+    all_x, all_y = zip(*polygon)
+    x, y = min(all_x), min(all_y)
+    width, height = max(all_x) - x, max(all_y) - y
+    return int(x), int(y), int(width), int(height)
+
+
+def build_image_url(element: Element):
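+    """Build the IIIF URL of the image region delimited by the element's polygon."""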
+    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
+    return urljoin(
+        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+    )
+
+
+def download_image(element: Element, folder: Path):
+    """
+    Download the image to `folder / {element.id}.jpg`
+    """
+    tries = 1
+    # retry loop
+    while True:
+        if tries > MAX_RETRIES:
+            raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
+        try:
+            image = iio.imread(build_image_url(element))
+            # cv2.imwrite needs a BGR array and fails silently without the folder
+            folder.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(
+                str(folder / f"{element.id}.jpg"),
+                cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
+            )
+            break
+        except TimeoutError:
+            logger.warning("Timeout, retry in 1 second.")
+            time.sleep(1)
+            tries += 1
+        except Exception as e:
+            raise ImageDownloadError(element.id, e)
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 5b18a7b..48420c3 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,13 +1,38 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
+from pathlib import Path
+from typing import Optional
 
 from apistar.exceptions import ErrorResponse
-from arkindex_worker.cache import create_tables, create_version_table, init_cache_db
+from arkindex_export import open_database
+from arkindex_export.models import Element
+from arkindex_export.queries import list_children
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+    create_tables,
+    create_version_table,
+)
+from arkindex_worker.cache import db as cache_database
+from arkindex_worker.cache import init_cache_db
 from arkindex_worker.worker import ElementsWorker
+from worker_generic_training_dataset.db import (
+    list_classifications,
+    list_transcriptions,
+    retrieve_element,
+    retrieve_entities,
+)
+from worker_generic_training_dataset.utils import download_image
 
 logger = logging.getLogger(__name__)
 
+IMAGE_FOLDER = Path("images")
+BULK_BATCH_SIZE = 50
+
 
 class DatasetExtractor(ElementsWorker):
     def configure(self):
@@ -26,12 +51,10 @@ class DatasetExtractor(ElementsWorker):
 
     def initialize_database(self):
         # Create db at
-        # - self.workdir.parent / self.task_id in Arkindex mode
+        # - self.workdir / "db.sqlite" in Arkindex mode
         # - self.args.database in dev mode
         database_path = (
-            self.args.database
-            if self.is_read_only
-            else self.workdir.parent / self.task_id
+            self.args.database if self.is_read_only else self.workdir / "db.sqlite"
         )
 
         init_cache_db(database_path)
@@ -49,8 +72,9 @@ class DatasetExtractor(ElementsWorker):
             )["results"]
         except ErrorResponse as e:
             logger.error(
-                f"Could not list exports of corpus ({self.corpus_id}): {str(e)}"
+                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
             )
+            raise e
 
         # Find latest that is in "done" state
         exports = sorted(
@@ -62,21 +86,111 @@ class DatasetExtractor(ElementsWorker):
         # Download latest it in a tmpfile
         try:
             export_id = exports[0]["id"]
-            download_url = self.api_client.request(
+            logger.info(f"Downloading export ({export_id})...")
+            self.export = self.api_client.request(
                 "DownloadExport",
                 id=export_id,
-            )["results"]
+            )
+            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
+            open_database(self.export.name)
         except ErrorResponse as e:
             logger.error(
                 f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e)}"
             )
-        print(download_url)
+            raise e
+
+    def insert_element(self, element: Element, parent_id: Optional[str]) -> None:
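+        """
+        Insert the given export element into the cache database, together with
+        its image, classifications, transcriptions and entities.
+        """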
+        logger.info(f"Processing element ({element.id})")
+        if element.image:
+            # Download image
+            logger.info("Downloading image")
+            download_image(element, folder=IMAGE_FOLDER)
+            # Insert image
+            logger.info("Inserting image")
+            CachedImage.create(
+                id=element.image.id,
+                width=element.image.width,
+                height=element.image.height,
+                url=element.image.url,
+            )
 
-    def process_element(self, element):
-        ...
+        # Insert element
+        logger.info("Inserting element")
+        cached_element = CachedElement.create(
+            id=element.id,
+            parent_id=parent_id,
+            type=element.type,
+            image=element.image.id if element.image else None,
+            polygon=element.polygon,
+            rotation_angle=element.rotation_angle,
+            mirrored=element.mirrored,
+            worker_version_id=element.worker_version.id
+            if element.worker_version
+            else None,
+            confidence=element.confidence,
+        )
 
-    # List Transcriptions, Metas
-    #
+        # Insert classifications
+        logger.info("Listing classifications")
+        classifications = [
+            CachedClassification(
+                id=classification.id,
+                element=cached_element,
+                class_name=classification.class_name,
+                confidence=classification.confidence,
+                state=classification.state,
+            )
+            for classification in list_classifications(element)
+        ]
+        if classifications:
+            logger.info(f"Inserting {len(classifications)} classifications")
+            with cache_database.atomic():
+                CachedClassification.bulk_create(
+                    model_list=classifications,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        # Insert transcriptions
+        logger.info("Listing transcriptions")
+        transcriptions = list_transcriptions(cached_element)
+        if transcriptions:
+            logger.info(f"Inserting {len(transcriptions)} transcriptions")
+            with cache_database.atomic():
+                CachedTranscription.bulk_create(
+                    model_list=transcriptions,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        logger.info("Listing entities")
+        entities, transcription_entities = [], []
+        for transcription in transcriptions:
+            ents, transc_ents = retrieve_entities(transcription)
+            entities.extend(ents)
+            transcription_entities.extend(transc_ents)
+
+        if entities:
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+            # Insert transcription entities
+            logger.info(
+                f"Inserting {len(transcription_entities)} transcription entities"
+            )
+            with cache_database.atomic():
+                CachedTranscriptionEntity.bulk_create(
+                    model_list=transcription_entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def process_element(self, element):
+        # Retrieve the processed element from the export and insert it first
+        parent = retrieve_element(element.id)
+        self.insert_element(parent, parent_id=None)
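+        # Then insert every child of the processed element, attached to it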
+        for child in list_children(parent_id=element.id):
+            self.insert_element(child, parent_id=element.id)
 
 
 def main():
-- 
GitLab