From 7fb7097c197db5ca406a37d8e52657d933448826 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 18 Apr 2023 15:47:20 +0200
Subject: [PATCH] working implem

---
 worker_generic_training_dataset/db.py     | 53 ++++++++++++++++--
 worker_generic_training_dataset/utils.py  | 37 +++++++++++--
 worker_generic_training_dataset/worker.py | 67 +++++++++++++++--------
 3 files changed, 126 insertions(+), 31 deletions(-)

diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index 05ccb67..f04319f 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from typing import NamedTuple
 
-from arkindex_export import Classification
+from arkindex_export import Classification, ElementPath, Image
 from arkindex_export.models import (
     Element,
     Entity,
@@ -23,8 +23,8 @@ def retrieve_element(element_id: str):
     return Element.get_by_id(element_id)
 
 
-def list_classifications(element: Element):
-    query = Classification.select().where(Classification.element == element)
+def list_classifications(element_id: str):
+    query = Classification.select().where(Classification.element_id == element_id)
     return query
 
 
@@ -33,7 +33,8 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement):
         id=transcription.id,
         element=element,
         text=transcription.text,
-        confidence=transcription.confidence,
+        # Dodge not-null constraint for now
+        confidence=transcription.confidence or 1.0,
         orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
         worker_version_id=transcription.worker_version.id
         if transcription.worker_version
@@ -89,3 +90,47 @@ def retrieve_entities(transcription: CachedTranscription):
         return [], []
 
     return zip(*data)
+
+
+def list_children(parent_id):
+    # First, build the base query to get direct children
+    base = (
+        ElementPath.select(
+            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
+        )
+        .where(ElementPath.parent_id == parent_id)
+        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
+    )
+
+    # Then build the second recursive query, using an alias to join both queries on the same table
+    EP = ElementPath.alias()
+    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
+        base, on=(EP.parent_id == base.c.child_id)
+    )
+
+    # Combine both queries, using UNION and not UNION ALL to deduplicate parents
+    # that might be found multiple times with complex element structures
+    cte = base.union(recursion)
+
+    # And load all the elements found in the CTE
+    query = (
+        Element.select(
+            Element.id,
+            Element.type,
+            Image.id.alias("image_id"),
+            Image.width.alias("image_width"),
+            Image.height.alias("image_height"),
+            Image.url.alias("image_url"),
+            Element.polygon,
+            Element.rotation_angle,
+            Element.mirrored,
+            Element.worker_version,
+            Element.confidence,
+            cte.c.parent_id,
+        )
+        .with_cte(cte)
+        .join(cte, on=(Element.id == cte.c.child_id))
+        .join(Image, on=(Element.image_id == Image.id))
+        .order_by(cte.c.ordering.asc())
+    )
+    return query
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index b6e8c3e..245e587 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -1,13 +1,16 @@
 # -*- coding: utf-8 -*-
 import ast
 import logging
+import os
+import tarfile
+import tempfile
 import time
 from pathlib import Path
 from urllib.parse import urljoin
 
 import cv2
 import imageio.v2 as iio
-from arkindex_export.models import Element
+import zstandard as zstd
 from worker_generic_training_dataset.exceptions import ImageDownloadError
 
 logger = logging.getLogger(__name__)
@@ -24,14 +27,14 @@ def bounding_box(polygon: list):
     return int(x), int(y), int(width), int(height)
 
 
-def build_image_url(element: Element):
+def build_image_url(element):
     x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
     return urljoin(
-        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
 
 
-def download_image(element: Element, folder: Path):
+def download_image(element, folder: Path):
     """
     Download the image to `folder / {element.image.id}.jpg`
     """
@@ -43,7 +46,7 @@ def download_image(element, folder: Path):
         try:
             image = iio.imread(build_image_url(element))
             cv2.imwrite(
-                str(folder / f"{element.image.id}.jpg"),
+                str(folder / f"{element.image_id}.jpg"),
                 cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
             )
             break
@@ -53,3 +56,27 @@
             tries += 1
         except Exception as e:
             raise ImageDownloadError(element.id, e)
+
+
+def create_tar_zstd_archive(folder_path, destination: Path, chunk_size=1024):
+    compressor = zstd.ZstdCompressor(level=3)
+
+    # Create a temporary file to hold the uncompressed tar archive
+    _, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
+
+    # Create an uncompressed tar archive with all the needed files
+    # The file hierarchy is kept in the archive.
+    with tarfile.open(path_to_tar_archive, "w") as tar:
+        for p in folder_path.glob("**/*"):
+            x = p.relative_to(folder_path)
+            tar.add(p, arcname=x, recursive=False)
+
+    # Compress the archive
+    with destination.open("wb") as archive_file:
+        with open(path_to_tar_archive, "rb") as model_data:
+            for model_chunk in iter(lambda: model_data.read(chunk_size), b""):
+                compressed_chunk = compressor.compress(model_chunk)
+                archive_file.write(compressed_chunk)
+
+    # Remove the tar archive
+    os.remove(path_to_tar_archive)
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index cbd0555..f24e062 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,12 +1,11 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
+import tempfile
 from pathlib import Path
 
 from apistar.exceptions import ErrorResponse
 from arkindex_export import open_database
-from arkindex_export.models import Element
-from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
     CachedElement,
@@ -21,16 +20,19 @@
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.worker import ElementsWorker
 from worker_generic_training_dataset.db import (
+    list_children,
     list_classifications,
     list_transcriptions,
     retrieve_element,
     retrieve_entities,
 )
-from worker_generic_training_dataset.utils import download_image
+from worker_generic_training_dataset.utils import (
+    create_tar_zstd_archive,
+    download_image,
+)
 
 logger = logging.getLogger(__name__)
 
-IMAGE_FOLDER = Path("images")
 BULK_BATCH_SIZE = 50
 
@@ -57,8 +59,12 @@ class DatasetExtractor(ElementsWorker):
         # - self.workdir / "db.sqlite" in Arkindex mode
         # - self.args.database in dev mode
         database_path = (
-            self.args.database if self.is_read_only else self.workdir / "db.sqlite"
+            Path(self.args.database)
+            if self.is_read_only
+            else self.workdir / "db.sqlite"
        )
+        if database_path.exists():
+            database_path.unlink()
 
         init_cache_db(database_path)
 
@@ -83,6 +89,7 @@ class DatasetExtractor(ElementsWorker):
         exports = sorted(
             list(filter(lambda exp: exp["state"] == "done", exports)),
             key=operator.itemgetter("updated"),
+            reverse=True,
         )
         assert len(exports) > 0, "No available exports found."
 
@@ -102,33 +109,33 @@
             )
             raise e
 
-    def insert_element(self, element: Element, parent_id: str):
+    def insert_element(self, element, image_folder: Path, root: bool = False):
         logger.info(f"Processing element ({element.id})")
-        if element.image and element.image.id not in self.cached_images:
+        if element.image_id and element.image_id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
-            download_image(element, folder=IMAGE_FOLDER)
+            download_image(element, folder=image_folder)
             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
-            self.cached_images[element.image.id] = CachedImage.create(
-                id=element.image.id,
-                width=element.image.width,
-                height=element.image.height,
-                url=element.image.url,
+            self.cached_images[element.image_id] = CachedImage.create(
+                id=element.image_id,
+                width=element.image_width,
+                height=element.image_height,
+                url=element.image_url,
             )
 
         # Insert element
         logger.info("Inserting element")
         cached_element = CachedElement.create(
             id=element.id,
-            parent_id=parent_id,
+            parent_id=None if root else element.parent_id,
             type=element.type,
-            image=element.image.id if element.image else None,
+            image=self.cached_images[element.image_id] if element.image_id else None,
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
-            worker_version_id=element.worker_version.id
+            worker_version_id=element.worker_version
             if element.worker_version
             else None,
             confidence=element.confidence,
@@ -144,10 +151,10 @@
                 confidence=classification.confidence,
                 state=classification.state,
             )
-            for classification in list_classifications(element)
+            for classification in list_classifications(element.id)
         ]
         if classifications:
-            logger.info(f"Inserting {len(classifications)} classifications")
+            logger.info(f"Inserting {len(classifications)} classification(s)")
             with cache_database.atomic():
                 CachedClassification.bulk_create(
                     model_list=classifications,
@@ -158,7 +165,7 @@
         logger.info("Listing transcriptions")
         transcriptions = list_transcriptions(cached_element)
         if transcriptions:
-            logger.info(f"Inserting {len(transcriptions)} transcriptions")
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
             with cache_database.atomic():
                 CachedTranscription.bulk_create(
                     model_list=transcriptions,
@@ -190,11 +197,27 @@
         )
 
     def process_element(self, element):
+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())
+
+        logger.info(
+            f"Filling the Base-Worker cache with information from children under element ({element.id})"
+        )
+
         # Fill cache
         # Retrieve parent and create parent
         parent = retrieve_element(element.id)
-        self.insert_element(parent, parent_id=None)
-        for child in list_children(parent_id=element.id):
-            self.insert_element(child, parent_id=element.id)
+        self.insert_element(parent, image_folder, root=True)
+        # Create children
+        children = list_children(parent_id=element.id)
+        nb_children = children.count()
+        for idx, child in enumerate(children.namedtuples(), start=1):
+            logger.info(f"Processing child ({idx}/{nb_children})")
+            self.insert_element(child, image_folder)
+
+        # TAR + ZSTD the image folder and store it as a task artifact
+        zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
 
 def main():
--
GitLab
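Note on consuming the artifact: create_tar_zstd_archive compresses the tar file chunk by chunk with ZstdCompressor.compress(), so the resulting arkindex_data.zstd is a concatenation of many small zstd frames rather than one streamed frame. The sketch below is one way a downstream task might unpack it, assuming a python-zstandard version that supports read_across_frames; the helper name extract_tar_zstd_archive and its arguments are illustrative only and are not part of this patch.

import tarfile
from pathlib import Path

import zstandard as zstd


def extract_tar_zstd_archive(archive_path: Path, destination: Path):
    # Hypothetical downstream helper, not part of this patch.
    destination.mkdir(parents=True, exist_ok=True)
    dctx = zstd.ZstdDecompressor()
    with archive_path.open("rb") as compressed:
        # read_across_frames=True is needed because the archive holds one
        # zstd frame per compressed chunk instead of a single frame.
        with dctx.stream_reader(compressed, read_across_frames=True) as reader:
            # Stream the decompressed tar and extract it without writing an
            # intermediate .tar file to disk.
            with tarfile.open(fileobj=reader, mode="r|") as tar:
                tar.extractall(path=destination)

For example, extract_tar_zstd_archive(Path("arkindex_data.zstd"), Path("images")) would restore the downloaded images with the same hierarchy that was archived.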