From 56b1ba7c53c5d476366aa07b5b0a6856aa265a77 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Tue, 9 May 2023 16:29:41 +0200 Subject: [PATCH] major refactoring --- tests/conftest.py | 2 + tests/test_worker.py | 14 +- worker_generic_training_dataset/cache.py | 300 ------------------ worker_generic_training_dataset/db.py | 79 +---- worker_generic_training_dataset/exceptions.py | 36 --- worker_generic_training_dataset/utils.py | 31 -- worker_generic_training_dataset/worker.py | 257 +++++++++------ 7 files changed, 177 insertions(+), 542 deletions(-) delete mode 100644 worker_generic_training_dataset/cache.py delete mode 100644 worker_generic_training_dataset/exceptions.py diff --git a/tests/conftest.py b/tests/conftest.py index 0121c27..b0da596 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,8 @@ def setup_environment(responses, monkeypatch): "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json", ) responses.add_passthru(schema_url) + # To allow image download + responses.add_passthru("https://europe-gamma.iiif.teklia.com/iiif/2") # Set schema url in environment os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url diff --git a/tests/test_worker.py b/tests/test_worker.py index 3c3703f..cf638a8 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- -import tempfile from argparse import Namespace -from pathlib import Path from arkindex_worker.cache import ( CachedClassification, @@ -15,7 +13,7 @@ from arkindex_worker.cache import ( from worker_generic_training_dataset.worker import DatasetExtractor -def test_process_split(): +def test_process_split(tmp_path): # Parent is train folder parent_id = "a0c4522d-2d80-4766-a01c-b9d686f41f6a" @@ -26,9 +24,9 @@ def test_process_split(): worker.cached_images = dict() # Where to save the downloaded images - image_folder = Path(tempfile.mkdtemp()) + worker.image_folder = tmp_path - worker.process_split("train", parent_id, image_folder) + worker.process_split("train", parent_id) # Should have created two pages under root folder assert ( @@ -53,9 +51,9 @@ def test_process_split(): # Should have created two images assert CachedImage.select().count() == 2 - assert sorted(image_folder.rglob("*")) == [ - image_folder / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg", - image_folder / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg", + assert sorted(tmp_path.rglob("*")) == [ + tmp_path / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg", + tmp_path / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg", ] # Should have created a transcription linked to first line of first page diff --git a/worker_generic_training_dataset/cache.py b/worker_generic_training_dataset/cache.py deleted file mode 100644 index b3b4063..0000000 --- a/worker_generic_training_dataset/cache.py +++ /dev/null @@ -1,300 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Database mappings and helper methods for the experimental worker caching feature. - -On methods that support caching, the database will be used for all reads, -and writes will go both to the Arkindex API and the database, -reducing network usage. 
-""" - -import json -from pathlib import Path -from typing import Optional, Union - -from arkindex_worker import logger -from peewee import ( - BooleanField, - CharField, - Check, - CompositeKey, - Field, - FloatField, - ForeignKeyField, - IntegerField, - Model, - OperationalError, - SqliteDatabase, - TextField, - UUIDField, -) -from PIL import Image - -db = SqliteDatabase(None) - - -class JSONField(Field): - """ - A Peewee field that stores a JSON payload as a string and parses it automatically. - """ - - field_type = "text" - - def db_value(self, value): - if value is None: - return - return json.dumps(value) - - def python_value(self, value): - if value is None: - return - return json.loads(value) - - -class Version(Model): - """ - Cache version table, used to warn about incompatible cache databases - when a worker uses an outdated version of ``base-worker``. - """ - - version = IntegerField(primary_key=True) - - class Meta: - database = db - table_name = "version" - - -class CachedImage(Model): - """ - Cache image table - """ - - id = UUIDField(primary_key=True) - width = IntegerField() - height = IntegerField() - url = TextField() - - class Meta: - database = db - table_name = "images" - - -class CachedElement(Model): - """ - Cache element table - """ - - id = UUIDField(primary_key=True) - parent_id = UUIDField(null=True) - type = CharField(max_length=50) - image = ForeignKeyField(CachedImage, backref="elements", null=True) - polygon = JSONField(null=True) - rotation_angle = IntegerField(default=0) - mirrored = BooleanField(default=False) - initial = BooleanField(default=False) - # Needed to filter elements with cache - worker_version_id = UUIDField(null=True) - worker_run_id = UUIDField(null=True) - confidence = FloatField(null=True) - - class Meta: - database = db - table_name = "elements" - - def open_image(self, *args, max_size: Optional[int] = None, **kwargs) -> Image: - """ - Open this element's image as a Pillow image. - This does not crop the image to the element's polygon. - IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported. - - :param *args: Positional arguments passed to [arkindex_worker.image.open_image][] - :param max_size: Subresolution of the image. - :param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][] - :raises ValueError: When this element does not have an image ID or a polygon. - :return: A Pillow image. 
- """ - from arkindex_worker.image import open_image, polygon_bounding_box - - if not self.image_id or not self.polygon: - raise ValueError(f"Element {self.id} has no image") - - # Always fetch the image from the bounding box when size differs from full image - bounding_box = polygon_bounding_box(self.polygon) - if ( - bounding_box.width != self.image.width - or bounding_box.height != self.image.height - ): - box = f"{bounding_box.x},{bounding_box.y},{bounding_box.width},{bounding_box.height}" - else: - box = "full" - - if max_size is None: - resize = "full" - else: - # Do not resize for polygons that do not exactly match the images - # as the resize is made directly by the IIIF server using the box parameter - if ( - bounding_box.width != self.image.width - or bounding_box.height != self.image.height - ): - resize = "full" - - # Do not resize when the image is below the maximum size - elif self.image.width <= max_size and self.image.height <= max_size: - resize = "full" - else: - ratio = max_size / max(self.image.width, self.image.height) - new_width, new_height = int(self.image.width * ratio), int( - self.image.height * ratio - ) - resize = f"{new_width},{new_height}" - - url = self.image.url - if not url.endswith("/"): - url += "/" - - return open_image( - f"{url}{box}/{resize}/0/default.jpg", - *args, - rotation_angle=self.rotation_angle, - mirrored=self.mirrored, - **kwargs, - ) - - -class CachedTranscription(Model): - """ - Cache transcription table - """ - - id = UUIDField(primary_key=True) - element = ForeignKeyField(CachedElement, backref="transcriptions") - text = TextField() - confidence = FloatField(null=True) - orientation = CharField(max_length=50) - # Needed to filter transcriptions with cache - worker_version_id = UUIDField(null=True) - worker_run_id = UUIDField(null=True) - - class Meta: - database = db - table_name = "transcriptions" - - -class CachedClassification(Model): - """ - Cache classification table - """ - - id = UUIDField(primary_key=True) - element = ForeignKeyField(CachedElement, backref="classifications") - class_name = TextField() - confidence = FloatField() - state = CharField(max_length=10) - worker_run_id = UUIDField(null=True) - - class Meta: - database = db - table_name = "classifications" - - -class CachedEntity(Model): - """ - Cache entity table - """ - - id = UUIDField(primary_key=True) - type = CharField(max_length=50) - name = TextField() - validated = BooleanField(default=False) - metas = JSONField(null=True) - worker_run_id = UUIDField(null=True) - - class Meta: - database = db - table_name = "entities" - - -class CachedTranscriptionEntity(Model): - """ - Cache transcription entity table - """ - - transcription = ForeignKeyField( - CachedTranscription, backref="transcription_entities" - ) - entity = ForeignKeyField(CachedEntity, backref="transcription_entities") - offset = IntegerField(constraints=[Check("offset >= 0")]) - length = IntegerField(constraints=[Check("length > 0")]) - worker_run_id = UUIDField(null=True) - confidence = FloatField(null=True) - - class Meta: - primary_key = CompositeKey("transcription", "entity") - database = db - table_name = "transcription_entities" - - -# Add all the managed models in that list -# It's used here, but also in unit tests -MODELS = [ - CachedImage, - CachedElement, - CachedTranscription, - CachedClassification, - CachedEntity, - CachedTranscriptionEntity, -] -SQL_VERSION = 2 - - -def init_cache_db(path: str): - """ - Create the cache database on the given path - :param path: Where the new database 
should be created - """ - db.init( - path, - pragmas={ - # SQLite ignores foreign keys and check constraints by default! - "foreign_keys": 1, - "ignore_check_constraints": 0, - }, - ) - db.connect() - logger.info(f"Connected to cache on {path}") - - -def create_tables(): - """ - Creates the tables in the cache DB only if they do not already exist. - """ - db.create_tables(MODELS) - - -def create_version_table(): - """ - Creates the Version table in the cache DB. - This step must be independent from other tables creation since we only - want to create the table and add the one and only Version entry when the - cache is created from scratch. - """ - db.create_tables([Version]) - Version.create(version=SQL_VERSION) - - -def check_version(cache_path: Union[str, Path]): - """ - Check the validity of the SQLite version - - :param cache_path: Path towards a local SQLite database - """ - with SqliteDatabase(cache_path) as provided_db: - with provided_db.bind_ctx([Version]): - try: - version = Version.get().version - except OperationalError: - version = None - - assert ( - version == SQL_VERSION - ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}" diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py index e81e652..bb18f7e 100644 --- a/worker_generic_training_dataset/db.py +++ b/worker_generic_training_dataset/db.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -from typing import NamedTuple -from arkindex_export import Classification, Image +from arkindex_export import Classification from arkindex_export.models import ( Element, Entity, @@ -9,15 +8,7 @@ from arkindex_export.models import ( Transcription, TranscriptionEntity, ) -from arkindex_export.queries import list_children -from arkindex_worker.cache import ( - CachedElement, - CachedEntity, - CachedTranscription, - CachedTranscriptionEntity, -) - -DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr" +from arkindex_worker.cache import CachedElement, CachedTranscription def retrieve_element(element_id: str): @@ -28,72 +19,14 @@ def list_classifications(element_id: str): return Classification.select().where(Classification.element_id == element_id) -def parse_transcription(transcription: Transcription, element: CachedElement): - return CachedTranscription( - id=transcription.id, - element=element, - text=transcription.text, - # Dodge not-null constraint for now - confidence=transcription.confidence or 1.0, - orientation=DEFAULT_TRANSCRIPTION_ORIENTATION, - worker_version_id=transcription.worker_version.id - if transcription.worker_version - else None, - ) - - def list_transcriptions(element: CachedElement): - query = Transcription.select().where(Transcription.element_id == element.id) - return [parse_transcription(transcription, element) for transcription in query] - - -def parse_entities(data: NamedTuple, transcription: CachedTranscription): - entity = CachedEntity( - id=data.entity_id, - type=data.type, - name=data.name, - validated=data.validated, - metas=data.metas, - ) - return entity, CachedTranscriptionEntity( - id=data.transcription_entity_id, - transcription=transcription, - entity=entity, - offset=data.offset, - length=data.length, - confidence=data.confidence, - ) + return Transcription.select().where(Transcription.element_id == element.id) -def retrieve_entities(transcription: CachedTranscription): - query = ( - TranscriptionEntity.select( - TranscriptionEntity.id.alias("transcription_entity_id"), - TranscriptionEntity.length.alias("length"), - 
TranscriptionEntity.offset.alias("offset"), - TranscriptionEntity.confidence.alias("confidence"), - Entity.id.alias("entity_id"), - EntityType.name.alias("type"), - Entity.name, - Entity.validated, - Entity.metas, - ) +def list_transcription_entities(transcription: CachedTranscription): + return ( + TranscriptionEntity.select() .where(TranscriptionEntity.transcription_id == transcription.id) .join(Entity, on=TranscriptionEntity.entity) .join(EntityType, on=Entity.type) ) - data = [ - parse_entities(entity_data, transcription) - for entity_data in query.namedtuples() - ] - if not data: - return [], [] - - return zip(*data) - - -def get_children(parent_id: UUID, element_type=None): - query = list_children(parent_id).join(Image) - if element_type: - query = query.where(Element.type == element_type) - return query diff --git a/worker_generic_training_dataset/exceptions.py b/worker_generic_training_dataset/exceptions.py deleted file mode 100644 index 062d580..0000000 --- a/worker_generic_training_dataset/exceptions.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- - - -class ElementProcessingError(Exception): - """ - Raised when a problem is encountered while processing an element - """ - - element_id: str - """ - ID of the element being processed. - """ - - def __init__(self, element_id: str, *args: object) -> None: - super().__init__(*args) - self.element_id = element_id - - -class ImageDownloadError(ElementProcessingError): - """ - Raised when an element's image could not be downloaded - """ - - error: Exception - """ - Error encountered. - """ - - def __init__(self, element_id: str, error: Exception, *args: object) -> None: - super().__init__(element_id, *args) - self.error = error - - def __str__(self) -> str: - return ( - f"Couldn't retrieve image of element ({self.element_id}: {str(self.error)})" - ) diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py index 01ee0f9..6bfb519 100644 --- a/worker_generic_training_dataset/utils.py +++ b/worker_generic_training_dataset/utils.py @@ -1,16 +1,9 @@ # -*- coding: utf-8 -*- import ast import logging -import time -from pathlib import Path from urllib.parse import urljoin -import cv2 -import imageio.v2 as iio -from worker_generic_training_dataset.exceptions import ImageDownloadError - logger = logging.getLogger(__name__) -MAX_RETRIES = 5 def bounding_box(polygon: list): @@ -28,27 +21,3 @@ def build_image_url(element): return urljoin( element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg" ) - - -def download_image(element, folder: Path): - """ - Download the image to `folder / {element.image.id}.jpg` - """ - tries = 1 - # retry loop - while True: - if tries > MAX_RETRIES: - raise ImageDownloadError(element.id, Exception("Maximum retries reached.")) - try: - image = iio.imread(build_image_url(element)) - cv2.imwrite( - str(folder / f"{element.image.id}.jpg"), - cv2.cvtColor(image, cv2.COLOR_BGR2RGB), - ) - break - except TimeoutError: - logger.warning("Timeout, retry in 1 second.") - time.sleep(1) - tries += 1 - except Exception as e: - raise ImageDownloadError(element.id, e) diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py index 8174350..518f562 100644 --- a/worker_generic_training_dataset/worker.py +++ b/worker_generic_training_dataset/worker.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- import logging import operator -import shutil import tempfile from pathlib import Path -from typing import Optional +from typing import List, 
Optional from uuid import UUID from apistar.exceptions import ErrorResponse -from arkindex_export import open_database +from arkindex_export import Element, Image, open_database +from arkindex_export.queries import list_children from arkindex_worker.cache import ( CachedClassification, CachedElement, @@ -21,20 +21,21 @@ from arkindex_worker.cache import ( ) from arkindex_worker.cache import db as cache_database from arkindex_worker.cache import init_cache_db +from arkindex_worker.image import download_image from arkindex_worker.utils import create_tar_zst_archive from arkindex_worker.worker.base import BaseWorker from worker_generic_training_dataset.db import ( - get_children, list_classifications, + list_transcription_entities, list_transcriptions, retrieve_element, - retrieve_entities, ) -from worker_generic_training_dataset.utils import download_image +from worker_generic_training_dataset.utils import build_image_url logger = logging.getLogger(__name__) BULK_BATCH_SIZE = 50 +DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr" class DatasetExtractor(BaseWorker): @@ -49,11 +50,6 @@ class DatasetExtractor(BaseWorker): logger.info("Overriding with user_configuration") self.config.update(self.user_configuration) - # database arg is mandatory in dev mode - assert ( - not self.is_read_only or self.args.database is not None - ), "`--database` arg is mandatory in developer mode." - # Read process information self.read_training_related_information() @@ -66,6 +62,10 @@ class DatasetExtractor(BaseWorker): # CachedImage downloaded and created in DB self.cached_images = dict() + # Where to save the downloaded images + self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data")) + logger.info(f"Images will be saved at `{self.image_folder}`.") + def read_training_related_information(self): """ Read from process information @@ -87,15 +87,11 @@ class DatasetExtractor(BaseWorker): self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None def initialize_database(self): - # Create db at - # - self.workdir / "db.sqlite" in Arkindex mode - # - self.args.database in dev mode - database_path = ( - self.args.database - if self.is_read_only - else self.work_dir / "db.sqlite" - ) - if database_path.exists(): + """ + Create an SQLite database compatible with base-worker cache and initialize it. + """ + database_path = self.work_dir / "db.sqlite" + # Remove previous execution result if present database_path.unlink(missing_ok=True) init_cache_db(database_path) @@ -105,12 +101,17 @@ class DatasetExtractor(BaseWorker): create_tables() def download_latest_export(self): - # Find export of corpus + """ + Download the latest export of the current corpus. + Export must be in `"done"` state. + """ try: - exports = self.api_client.request( - "ListExports", - id=self.corpus_id, - )["results"] + exports = list( + self.api_client.paginate( + "ListExports", + id=self.corpus_id, + ) + ) except ErrorResponse as e: logger.error( f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}" @@ -123,7 +124,9 @@ class DatasetExtractor(BaseWorker): key=operator.itemgetter("updated"), reverse=True, ) - assert len(exports) > 0, f"No available exports found for the corpus {self.corpus_id}." + assert ( + len(exports) > 0 + ), f"No available exports found for the corpus {self.corpus_id}." 
# Download latest it in a tmpfile try: @@ -137,18 +140,129 @@ class DatasetExtractor(BaseWorker): open_database(self.export.name) except ErrorResponse as e: logger.error( - f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e)}" + f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}" ) raise e - def insert_element( - self, element, image_folder: Path, parent_id: Optional[str] = None - ): + def insert_classifications(self, element: CachedElement) -> None: + logger.info("Listing classifications") + classifications = [ + CachedClassification( + id=classification.id, + element=element, + class_name=classification.class_name, + confidence=classification.confidence, + state=classification.state, + ) + for classification in list_classifications(element.id) + ] + if classifications: + logger.info(f"Inserting {len(classifications)} classification(s)") + with cache_database.atomic(): + CachedClassification.bulk_create( + model_list=classifications, + batch_size=BULK_BATCH_SIZE, + ) + + def insert_transcriptions( + self, element: CachedElement + ) -> List[CachedTranscription]: + logger.info("Listing transcriptions") + transcriptions = [ + CachedTranscription( + id=transcription.id, + element=element, + text=transcription.text, + # Dodge not-null constraint for now + confidence=transcription.confidence or 1.0, + orientation=DEFAULT_TRANSCRIPTION_ORIENTATION, + worker_version_id=transcription.worker_version.id + if transcription.worker_version + else None, + ) + for transcription in list_transcriptions(element) + ] + if not transcriptions: + return [] + + logger.info(f"Inserting {len(transcriptions)} transcription(s)") + with cache_database.atomic(): + CachedTranscription.bulk_create( + model_list=transcriptions, + batch_size=BULK_BATCH_SIZE, + ) + return transcriptions + + def insert_entities(self, transcriptions: List[CachedTranscription]): + logger.info("Listing entities") + extracted_entities = [] + for transcription in transcriptions: + for transcription_entity in list_transcription_entities(transcription): + entity = CachedEntity( + id=transcription_entity.entity.id, + type=transcription_entity.entity.type.name, + name=transcription_entity.entity.name, + validated=transcription_entity.entity.validated, + metas=transcription_entity.entity.metas, + ) + extracted_entities.append( + ( + entity, + CachedTranscriptionEntity( + id=transcription_entity.id, + transcription=transcription, + entity=entity, + offset=transcription_entity.offset, + length=transcription_entity.length, + confidence=transcription_entity.confidence, + ), + ) + ) + if not extracted_entities: + # Early return if no entities found + return + + entities, transcription_entities = zip(*extracted_entities) + + # First insert entities since they are foreign keys on transcription entities + logger.info(f"Inserting {len(entities)} entities") + with cache_database.atomic(): + CachedEntity.bulk_create( + model_list=entities, + batch_size=BULK_BATCH_SIZE, + ) + + if transcription_entities: + # Insert transcription entities + logger.info( + f"Inserting {len(transcription_entities)} transcription entities" + ) + with cache_database.atomic(): + CachedTranscriptionEntity.bulk_create( + model_list=transcription_entities, + batch_size=BULK_BATCH_SIZE, + ) + + def insert_element(self, element: Element, parent_id: Optional[str] = None): + """ + Insert the given element's children in the cache database. + Their image will also be saved to disk, if they weren't already. 
+
+        The insertion of an element includes:
+        - its classifications
+        - its transcriptions
+        - its transcriptions' entities (both Entity and TranscriptionEntity)
+
+        :param element: Element to insert. All its children will be inserted as well.
+        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
+        """
         logger.info(f"Processing element ({element.id})")
         if element.image and element.image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
-            download_image(element, folder=image_folder)
+            download_image(url=build_image_url(element)).save(
+                self.image_folder / f"{element.image.id}.jpg"
+            )

             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
@@ -169,6 +283,6 @@ class DatasetExtractor(BaseWorker):
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
-            worker_version_id=element.worker_version
+            worker_version_id=element.worker_version.id
             if element.worker_version
             else None,
@@ -178,89 +290,48 @@ class DatasetExtractor(BaseWorker):
         )

         # Insert classifications
-        logger.info("Listing classifications")
-        classifications = [
-            CachedClassification(
-                id=classification.id,
-                element=cached_element,
-                class_name=classification.class_name,
-                confidence=classification.confidence,
-                state=classification.state,
-            )
-            for classification in list_classifications(element.id)
-        ]
-        if classifications:
-            logger.info(f"Inserting {len(classifications)} classification(s)")
-            with cache_database.atomic():
-                CachedClassification.bulk_create(
-                    model_list=classifications,
-                    batch_size=BULK_BATCH_SIZE,
-                )
+        self.insert_classifications(cached_element)

         # Insert transcriptions
-        logger.info("Listing transcriptions")
-        transcriptions = list_transcriptions(cached_element)
-        if transcriptions:
-            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
-            with cache_database.atomic():
-                CachedTranscription.bulk_create(
-                    model_list=transcriptions,
-                    batch_size=BULK_BATCH_SIZE,
-                )
+        transcriptions = self.insert_transcriptions(cached_element)

         # Insert entities
-        logger.info("Listing entities")
-        entities, transcription_entities = zip(*[retrieve_entities(transcription) for transcription in transcriptions])
-
-        if entities:
-            logger.info(f"Inserting {len(entities)} entities")
-            with cache_database.atomic():
-                CachedEntity.bulk_create(
-                    model_list=entities,
-                    batch_size=BULK_BATCH_SIZE,
-                )
-            # Insert transcription entities
-            logger.info(
-                f"Inserting {len(transcription_entities)} transcription entities"
-            )
-            with cache_database.atomic():
-                CachedTranscriptionEntity.bulk_create(
-                    model_list=transcription_entities,
-                    batch_size=BULK_BATCH_SIZE,
-                )
+        self.insert_entities(transcriptions)

-    def process_split(self, split_name, split_id, image_folder):
+    def process_split(self, split_name: str, split_id: UUID):
+        """
+        Insert all elements under the given parent folder.
+ - `page` elements are linked to this folder (via parent_id foreign key) + - `page` element children are linked to their `page` parent (via parent_id foreign key) + """ logger.info( f"Filling the Base-Worker cache with information from children under element ({split_id})" ) # Fill cache # Retrieve parent and create parent parent = retrieve_element(split_id) - self.insert_element(parent, image_folder) + self.insert_element(parent) # First list all pages - pages = get_children(parent_id=split_id, element_type="page") + pages = list_children(split_id).join(Image).where(Element.type == "page") nb_pages = pages.count() for idx, page in enumerate(pages, start=1): logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})") # Insert page - self.insert_element(page, image_folder, parent_id=split_id) + self.insert_element(page, parent_id=split_id) # List children - children = get_children(parent_id=page.id) + children = list_children(page.id) nb_children = children.count() for child_idx, child in enumerate(children, start=1): logger.info(f"Processing child ({child_idx}/{nb_children})") # Insert child - self.insert_element(child, image_folder, parent_id=page.id) + self.insert_element(child, parent_id=page.id) def run(self): self.configure() - # Where to save the downloaded images - image_folder = Path(tempfile.mkdtemp()) - # Iterate over given split for split_name, split_id in [ ("Train", self.training_folder_id), @@ -269,20 +340,18 @@ class DatasetExtractor(BaseWorker): ]: if not split_id: continue - self.process_split(split_name, split_id, image_folder) + self.process_split(split_name, split_id) # TAR + ZSTD Image folder and store as task artifact zstd_archive_path = self.work_dir / "arkindex_data.zstd" logger.info(f"Compressing the images to {zstd_archive_path}") - create_tar_zst_archive(source=image_folder, destination=zstd_archive_path) - - # Cleanup image folder - shutil.rmtree(image_folder) + create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path) def main(): DatasetExtractor( - description="Fill base-worker cache with information about dataset and extract images", support_cache=True + description="Fill base-worker cache with information about dataset and extract images", + support_cache=True, ).run() -- GitLab
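For reference, a minimal sketch of how the artifacts produced by this refactored worker might be read back by a downstream training script, assuming the db.sqlite cache and the extracted image folder are available locally. The paths and the "text_line" element type below are hypothetical, not taken from the patch:

# -*- coding: utf-8 -*-
# Rough usage sketch: read back the cache built by DatasetExtractor.
# Assumes the task artifacts were downloaded and arkindex_data.zstd was
# extracted to ./arkindex_data/images next to db.sqlite (hypothetical layout).
from pathlib import Path

from arkindex_worker.cache import CachedElement, init_cache_db

CACHE_DB = Path("arkindex_data/db.sqlite")  # hypothetical path
IMAGE_FOLDER = Path("arkindex_data/images")  # hypothetical path

# Bind the peewee models to the SQLite cache produced by the worker
init_cache_db(CACHE_DB)

# Pair each cached element with its image file and transcriptions,
# e.g. to build (image, text) training samples.
for element in CachedElement.select().where(CachedElement.type == "text_line"):
    if element.image_id is None:
        continue
    image_path = IMAGE_FOLDER / f"{element.image_id}.jpg"
    for transcription in element.transcriptions:
        print(image_path, transcription.text)

Since the worker writes images once per image UUID, element-level crops would still have to be cut from those files using the polygons stored in the cache.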