diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index f04319f5c464ffa904cfb57e62379cb2a7e5d0d0..a608ab93e5cfe626b273d022966282067a9fab89 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from typing import NamedTuple
 
-from arkindex_export import Classification, ElementPath, Image
+from arkindex_export import Classification, Image
 from arkindex_export.models import (
     Element,
     Entity,
@@ -9,6 +9,7 @@ from arkindex_export.models import (
     Transcription,
     TranscriptionEntity,
 )
+from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedElement,
     CachedEntity,
@@ -92,45 +93,8 @@ def retrieve_entities(transcription: CachedTranscription):
     return zip(*data)
 
 
-def list_children(parent_id):
-    # First, build the base query to get direct children
-    base = (
-        ElementPath.select(
-            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
-        )
-        .where(ElementPath.parent_id == parent_id)
-        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
-    )
-
-    # Then build the second recursive query, using an alias to join both queries on the same table
-    EP = ElementPath.alias()
-    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
-        base, on=(EP.parent_id == base.c.child_id)
-    )
-
-    # Combine both queries, using UNION and not UNION ALL to deduplicate parents
-    # that might be found multiple times with complex element structures
-    cte = base.union(recursion)
-
-    # And load all the elements found in the CTE
-    query = (
-        Element.select(
-            Element.id,
-            Element.type,
-            Image.id.alias("image_id"),
-            Image.width.alias("image_width"),
-            Image.height.alias("image_height"),
-            Image.url.alias("image_url"),
-            Element.polygon,
-            Element.rotation_angle,
-            Element.mirrored,
-            Element.worker_version,
-            Element.confidence,
-            cte.c.parent_id,
-        )
-        .with_cte(cte)
-        .join(cte, on=(Element.id == cte.c.child_id))
-        .join(Image, on=(Element.image_id == Image.id))
-        .order_by(cte.c.ordering.asc())
-    )
+def get_children(parent_id, element_type=None):
+    query = list_children(parent_id).join(Image)
+    if element_type:
+        query = query.where(Element.type == element_type)
     return query
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 47f2bfc44638a46328c9f0551a6c034c69488fd9..39b034ce98001e23880d7970fc2d6a1027542e75 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -30,7 +30,7 @@ def bounding_box(polygon: list):
 def build_image_url(element):
     x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
     return urljoin(
-        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
 
 
@@ -46,7 +46,7 @@ def download_image(element, folder: Path):
     try:
         image = iio.imread(build_image_url(element))
         cv2.imwrite(
-            str(folder / f"{element.image_id}.png"),
+            str(folder / f"{element.image.id}.jpg"),
             cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
         )
         break
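
Note (illustrative, not part of the patch): the new `get_children` helper delegates the recursive tree traversal to `arkindex_export.queries.list_children` and pre-joins `Image`, so callers can filter on `Element.type` and read `element.image.*` attributes directly, as `build_image_url` and `download_image` now do. A minimal sketch, assuming an already-downloaded export database and using placeholder path and UUID values:

    from arkindex_export import open_database
    from worker_generic_training_dataset.db import get_children

    # Open the SQLite export before querying (placeholder path)
    open_database("path/to/export.sqlite")

    # List all pages under a folder (placeholder UUID)
    pages = get_children(
        parent_id="00000000-0000-0000-0000-000000000000", element_type="page"
    )
    print(pages.count())  # still a peewee query: countable before iterating
    for page in pages:
        print(page.id, page.type, page.image.url)
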
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index f24e062f0d9cde9d5e88c19fd480fc23f2f65a9c..4c475a4ae5331d271f286a6ab12da4e94002127b 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,8 +1,11 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
+import shutil
 import tempfile
 from pathlib import Path
+from typing import Optional
+from uuid import UUID
 
 from apistar.exceptions import ErrorResponse
 from arkindex_export import open_database
@@ -18,9 +21,9 @@ from arkindex_worker.cache import (
 )
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
-from arkindex_worker.worker import ElementsWorker
+from arkindex_worker.worker.base import BaseWorker
 from worker_generic_training_dataset.db import (
-    list_children,
+    get_children,
     list_classifications,
     list_transcriptions,
     retrieve_element,
@@ -36,15 +39,27 @@ logger = logging.getLogger(__name__)
 BULK_BATCH_SIZE = 50
 
 
-class DatasetExtractor(ElementsWorker):
+class DatasetExtractor(BaseWorker):
     def configure(self):
-        super().configure()
+        self.args = self.parser.parse_args()
+        if self.is_read_only:
+            super().configure_for_developers()
+            self.process_information = {
+                "train_folder_id": "47a0e07b-d07a-4969-aced-44450d132f0d",
+                "validation_folder_id": "8cbc4b53-9e07-4a72-b4e6-93f7f5b0cbed",
+                "test_folder_id": "659a37ea-3b26-42f0-8b65-78964f9e433e",
+            }
+        else:
+            super().configure()
 
         # database arg is mandatory in dev mode
         assert (
             not self.is_read_only or self.args.database is not None
         ), "`--database` arg is mandatory in developer mode."
 
+        # Read process information
+        self.read_training_related_information()
+
         # Download corpus
         self.download_latest_export()
 
@@ -54,6 +69,27 @@ class DatasetExtractor(ElementsWorker):
         # Cached Images downloaded and created in DB
         self.cached_images = dict()
 
+    def read_training_related_information(self):
+        """
+        Read from process information
+        - train_folder_id
+        - validation_folder_id
+        - test_folder_id (optional)
+        - model_id
+        """
+        logger.info("Retrieving information from process_information")
+
+        train_folder_id = self.process_information.get("train_folder_id")
+        assert train_folder_id, "A training folder id is necessary to use this worker"
+        self.training_folder_id = UUID(train_folder_id)
+
+        val_folder_id = self.process_information.get("validation_folder_id")
+        assert val_folder_id, "A validation folder id is necessary to use this worker"
+        self.validation_folder_id = UUID(val_folder_id)
+
+        test_folder_id = self.process_information.get("test_folder_id")
+        self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
+
     def initialize_database(self):
         # Create db at
         # - self.workdir / "db.sqlite" in Arkindex mode
@@ -109,9 +145,11 @@ class DatasetExtractor(ElementsWorker):
             )
             raise e
 
-    def insert_element(self, element, image_folder: Path, root: bool = False):
+    def insert_element(
+        self, element, image_folder: Path, parent_id: Optional[str] = None
+    ):
         logger.info(f"Processing element ({element.id})")
-        if element.image_id and element.image_id not in self.cached_images:
+        if element.image and element.image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
             download_image(element, folder=image_folder)
@@ -119,19 +157,19 @@ class DatasetExtractor(ElementsWorker):
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
             self.cached_images[element.image_id] = CachedImage.create(
-                id=element.image_id,
-                width=element.image_width,
-                height=element.image_height,
-                url=element.image_url,
+                id=element.image.id,
+                width=element.image.width,
+                height=element.image.height,
+                url=element.image.url,
             )
 
         # Insert element
         logger.info("Inserting element")
         cached_element = CachedElement.create(
             id=element.id,
-            parent_id=None if root else element.parent_id,
+            parent_id=parent_id,
             type=element.type,
-            image=self.cached_images[element.image_id] if element.image_id else None,
+            image=self.cached_images[element.image.id] if element.image else None,
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
@@ -196,29 +234,56 @@ class DatasetExtractor(ElementsWorker):
             batch_size=BULK_BATCH_SIZE,
         )
 
-    def process_element(self, element):
-        # Where to save the downloaded images
-        image_folder = Path(tempfile.mkdtemp())
-
+    def process_split(self, split_name, split_id, image_folder):
         logger.info(
-            f"Filling the Base-Worker cache with information from children under element ({element.id})"
+            f"Filling the Base-Worker cache with information from children under element ({split_id})"
        )
         # Fill cache
         # Retrieve parent and create parent
-        parent = retrieve_element(element.id)
-        self.insert_element(parent, image_folder, root=True)
-        # Create children
-        children = list_children(parent_id=element.id)
-        nb_children = children.count()
-        for idx, child in enumerate(children.namedtuples(), start=1):
-            logger.info(f"Processing child ({idx}/{nb_children})")
-            self.insert_element(child, image_folder)
+        parent = retrieve_element(split_id)
+        self.insert_element(parent, image_folder)
+
+        # First list all pages
+        pages = get_children(parent_id=split_id, element_type="page")
+        nb_pages = pages.count()
+        for idx, page in enumerate(pages, start=1):
+            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
+
+            # Insert page
+            self.insert_element(page, image_folder, parent_id=split_id)
+
+            # List children
+            children = get_children(parent_id=page.id)
+            nb_children = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, image_folder, parent_id=page.id)
+
+    def run(self):
+        self.configure()
+
+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())
+
+        # Iterate over given split
+        for split_name, split_id in [
+            ("Train", self.training_folder_id),
+            ("Validation", self.validation_folder_id),
+            ("Test", self.testing_folder_id),
+        ]:
+            if not split_id:
+                continue
+            self.process_split(split_name, split_id, image_folder)
 
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
         create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
 
+        # Cleanup image folder
+        shutil.rmtree(image_folder)
+
 
 def main():
     DatasetExtractor(
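
Note (illustrative, not part of the patch): with the move from `ElementsWorker` to `BaseWorker`, the worker now drives itself through `run()` (configure, one `process_split` call per non-empty split folder, then archive and cleanup) instead of being invoked once per element by the framework. A minimal driver sketch mirroring the truncated `main()` above; the constructor argument is an assumption, not taken from this patch:

    from worker_generic_training_dataset.worker import DatasetExtractor

    # The real constructor arguments are truncated in the diff above; passing
    # support_cache=True is an assumption based on BaseWorker's signature,
    # since this worker writes CachedElement/CachedImage rows.
    DatasetExtractor(support_cache=True).run()

In developer mode (`is_read_only`), `configure()` falls back to `configure_for_developers()` and the hard-coded `process_information` folder ids, and the `--database` argument must point at an existing export, as the assert in `configure()` requires.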