From 7fb7097c197db5ca406a37d8e52657d933448826 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 18 Apr 2023 15:47:20 +0200
Subject: [PATCH] Working implementation

---
 worker_generic_training_dataset/db.py     | 53 ++++++++++++++++--
 worker_generic_training_dataset/utils.py  | 37 +++++++++++--
 worker_generic_training_dataset/worker.py | 67 +++++++++++++++--------
 3 files changed, 126 insertions(+), 31 deletions(-)

diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index 05ccb67..f04319f 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from typing import NamedTuple
 
-from arkindex_export import Classification
+from arkindex_export import Classification, ElementPath, Image
 from arkindex_export.models import (
     Element,
     Entity,
@@ -23,8 +23,8 @@ def retrieve_element(element_id: str):
     return Element.get_by_id(element_id)
 
 
-def list_classifications(element: Element):
-    query = Classification.select().where(Classification.element == element)
+def list_classifications(element_id: str):
+    query = Classification.select().where(Classification.element_id == element_id)
     return query
 
 
@@ -33,7 +33,8 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement):
         id=transcription.id,
         element=element,
         text=transcription.text,
-        confidence=transcription.confidence,
+        # Dodge not-null constraint for now
+        confidence=transcription.confidence or 1.0,
         orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
         worker_version_id=transcription.worker_version.id
         if transcription.worker_version
@@ -89,3 +90,47 @@ def retrieve_entities(transcription: CachedTranscription):
         return [], []
 
     return zip(*data)
+
+
+def list_children(parent_id: str):
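+    """
+    Recursively list all children of the element `parent_id` through a recursive CTE
+    on ElementPath. Each row carries the child element, its image metadata and its
+    parent_id, ordered by the ElementPath ordering.
+    """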
+    # First, build the base query to get direct children
+    base = (
+        ElementPath.select(
+            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
+        )
+        .where(ElementPath.parent_id == parent_id)
+        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
+    )
+
+    # Then build the second recursive query, using an alias to join both queries on the same table
+    EP = ElementPath.alias()
+    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
+        base, on=(EP.parent_id == base.c.child_id)
+    )
+
+    # Combine both queries, using UNION and not UNION ALL to deduplicate elements
+    # that might be found multiple times with complex element structures
+    cte = base.union(recursion)
+
+    # And load all the elements found in the CTE
+    query = (
+        Element.select(
+            Element.id,
+            Element.type,
+            Image.id.alias("image_id"),
+            Image.width.alias("image_width"),
+            Image.height.alias("image_height"),
+            Image.url.alias("image_url"),
+            Element.polygon,
+            Element.rotation_angle,
+            Element.mirrored,
+            Element.worker_version,
+            Element.confidence,
+            cte.c.parent_id,
+        )
+        .with_cte(cte)
+        .join(cte, on=(Element.id == cte.c.child_id))
+        .join(Image, on=(Element.image_id == Image.id))
+        .order_by(cte.c.ordering.asc())
+    )
+    return query
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index b6e8c3e..245e587 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -1,13 +1,16 @@
 # -*- coding: utf-8 -*-
 import ast
 import logging
+import os
+import tarfile
+import tempfile
 import time
 from pathlib import Path
 from urllib.parse import urljoin
 
 import cv2
 import imageio.v2 as iio
-from arkindex_export.models import Element
+import zstandard as zstd
 from worker_generic_training_dataset.exceptions import ImageDownloadError
 
 logger = logging.getLogger(__name__)
@@ -24,14 +27,14 @@ def bounding_box(polygon: list):
     return int(x), int(y), int(width), int(height)
 
 
-def build_image_url(element: Element):
+def build_image_url(element):
     x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
     return urljoin(
-        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
 
 
-def download_image(element: Element, folder: Path):
+def download_image(element, folder: Path):
     """
     Download the image to `folder / {element.image.id}.jpg`
     """
@@ -43,7 +46,7 @@ def download_image(element: Element, folder: Path):
         try:
             image = iio.imread(build_image_url(element))
             cv2.imwrite(
-                str(folder / f"{element.image.id}.jpg"),
+                str(folder / f"{element.image_id}.jpg"),
                 cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
             )
             break
@@ -53,3 +56,27 @@ def download_image(element: Element, folder: Path):
             tries += 1
         except Exception as e:
             raise ImageDownloadError(element.id, e)
+
+
+def create_tar_zstd_archive(
+    folder_path: Path, destination: Path, chunk_size: int = 1024
+) -> None:
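+    """
+    Create a tar archive of the contents of `folder_path` and compress it with
+    Zstandard (level 3) into `destination`.
+    """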
+    compressor = zstd.ZstdCompressor(level=3)
+
+    # Create a temporary file to hold the uncompressed tar archive
+    tar_fd, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
+    # Only the path is needed, close the file descriptor returned by mkstemp
+    os.close(tar_fd)
+
+    # Create an uncompressed tar archive with all the needed files
+    # The file hierarchy is kept in the archive.
+    with tarfile.open(path_to_tar_archive, "w") as tar:
+        for p in folder_path.glob("**/*"):
+            x = p.relative_to(folder_path)
+            tar.add(p, arcname=x, recursive=False)
+
+    # Compress the tar archive into the destination file as a single zstd stream
+    with destination.open("wb") as archive_file:
+        with open(path_to_tar_archive, "rb") as tar_data:
+            compressor.copy_stream(tar_data, archive_file, read_size=chunk_size)
+
+    # Remove the tar archive
+    os.remove(path_to_tar_archive)
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index cbd0555..f24e062 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,12 +1,11 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
+import tempfile
 from pathlib import Path
 
 from apistar.exceptions import ErrorResponse
 from arkindex_export import open_database
-from arkindex_export.models import Element
-from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
     CachedElement,
@@ -21,16 +20,19 @@ from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.worker import ElementsWorker
 from worker_generic_training_dataset.db import (
+    list_children,
     list_classifications,
     list_transcriptions,
     retrieve_element,
     retrieve_entities,
 )
-from worker_generic_training_dataset.utils import download_image
+from worker_generic_training_dataset.utils import (
+    create_tar_zstd_archive,
+    download_image,
+)
 
 logger = logging.getLogger(__name__)
 
-IMAGE_FOLDER = Path("images")
 BULK_BATCH_SIZE = 50
 
 
@@ -57,8 +59,12 @@ class DatasetExtractor(ElementsWorker):
         # - self.workdir / "db.sqlite" in Arkindex mode
         # - self.args.database in dev mode
         database_path = (
-            self.args.database if self.is_read_only else self.workdir / "db.sqlite"
+            Path(self.args.database)
+            if self.is_read_only
+            else self.workdir / "db.sqlite"
         )
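+        # Remove any pre-existing database to start with an empty cache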
+        if database_path.exists():
+            database_path.unlink()
 
         init_cache_db(database_path)
 
@@ -83,6 +89,7 @@ class DatasetExtractor(ElementsWorker):
         exports = sorted(
             list(filter(lambda exp: exp["state"] == "done", exports)),
             key=operator.itemgetter("updated"),
+            reverse=True,
         )
         assert len(exports) > 0, "No available exports found."
 
@@ -102,33 +109,33 @@ class DatasetExtractor(ElementsWorker):
             )
             raise e
 
-    def insert_element(self, element: Element, parent_id: str):
+    def insert_element(self, element, image_folder: Path, root: bool = False):
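+        """
+        Insert the given element (a row from the export database) into the base-worker
+        cache, along with its image and its annotations (classifications, transcriptions
+        and their entities). Images are downloaded to `image_folder`.
+        When `root` is True, the element is inserted without a parent.
+        """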
         logger.info(f"Processing element ({element.id})")
-        if element.image and element.image.id not in self.cached_images:
+        if element.image_id and element.image_id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
-            download_image(element, folder=IMAGE_FOLDER)
+            download_image(element, folder=image_folder)
             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
-            self.cached_images[element.image.id] = CachedImage.create(
-                id=element.image.id,
-                width=element.image.width,
-                height=element.image.height,
-                url=element.image.url,
+            self.cached_images[element.image_id] = CachedImage.create(
+                id=element.image_id,
+                width=element.image_width,
+                height=element.image_height,
+                url=element.image_url,
             )
 
         # Insert element
         logger.info("Inserting element")
         cached_element = CachedElement.create(
             id=element.id,
-            parent_id=parent_id,
+            parent_id=None if root else element.parent_id,
             type=element.type,
-            image=element.image.id if element.image else None,
+            image=self.cached_images[element.image_id] if element.image_id else None,
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
-            worker_version_id=element.worker_version.id
+            worker_version_id=element.worker_version
             if element.worker_version
             else None,
             confidence=element.confidence,
@@ -144,10 +151,10 @@ class DatasetExtractor(ElementsWorker):
                 confidence=classification.confidence,
                 state=classification.state,
             )
-            for classification in list_classifications(element)
+            for classification in list_classifications(element.id)
         ]
         if classifications:
-            logger.info(f"Inserting {len(classifications)} classifications")
+            logger.info(f"Inserting {len(classifications)} classification(s)")
             with cache_database.atomic():
                 CachedClassification.bulk_create(
                     model_list=classifications,
@@ -158,7 +165,7 @@ class DatasetExtractor(ElementsWorker):
         logger.info("Listing transcriptions")
         transcriptions = list_transcriptions(cached_element)
         if transcriptions:
-            logger.info(f"Inserting {len(transcriptions)} transcriptions")
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
             with cache_database.atomic():
                 CachedTranscription.bulk_create(
                     model_list=transcriptions,
@@ -190,11 +197,27 @@ class DatasetExtractor(ElementsWorker):
                 )
 
     def process_element(self, element):
+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())
+
+        logger.info(
+            f"Filling the base-worker cache with information from the children of element ({element.id})"
+        )
+        # Fill cache
         # Retrieve parent and create parent
         parent = retrieve_element(element.id)
-        self.insert_element(parent, parent_id=None)
-        for child in list_children(parent_id=element.id):
-            self.insert_element(child, parent_id=element.id)
+        self.insert_element(parent, image_folder, root=True)
+        # Create children
+        children = list_children(parent_id=element.id)
+        nb_children = children.count()
+        for idx, child in enumerate(children.namedtuples(), start=1):
+            logger.info(f"Processing child ({idx}/{nb_children})")
+            self.insert_element(child, image_folder)
+
+        # Tar + zstd the image folder and store it as a task artifact
+        zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
 
 
 def main():
-- 
GitLab