From 316f067b33b5bd8c18bd37682a62afe6ff1879ab Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Fri, 20 Oct 2023 10:51:22 +0200
Subject: [PATCH] Save the cache in the archive too

---
 tests/test_worker.py                      | 17 +++++++++++------
 worker_generic_training_dataset/worker.py | 22 ++++++++++++++--------
 2 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/tests/test_worker.py b/tests/test_worker.py
index cd01c8e..c64e177 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -19,11 +19,13 @@ def test_process_split(tmp_path, downloaded_images):
     worker = DatasetExtractor()
     # Parse some arguments
     worker.args = Namespace(database=None)
+    worker.data_folder = tmp_path
     worker.configure_cache()
     worker.cached_images = dict()
 
     # Where to save the downloaded images
-    worker.image_folder = tmp_path
+    worker.images_folder = tmp_path / "images"
+    worker.images_folder.mkdir(parents=True)
 
     first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
@@ -80,11 +82,6 @@ def test_process_split(tmp_path, downloaded_images):
         == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
     )
 
-    assert sorted(tmp_path.rglob("*")) == [
-        tmp_path / f"{first_image_id}.jpg",
-        tmp_path / f"{second_image_id}.jpg",
-    ]
-
     # Should have created 17 transcriptions
     assert CachedTranscription.select().count() == 17
     # Check transcription of first line on first page
@@ -127,3 +124,11 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.length == 23
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
+
+    # Full structure of the archive
+    assert sorted(tmp_path.rglob("*")) == [
+        tmp_path / "db.sqlite",
+        tmp_path / "images",
+        tmp_path / "images" / f"{first_image_id}.jpg",
+        tmp_path / "images" / f"{second_image_id}.jpg",
+    ]
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 93a4c75..3e54552 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -61,6 +61,9 @@ class DatasetExtractor(DatasetWorker):
         # Download corpus
         self.download_latest_export()
 
+    def configure_storage(self) -> None:
+        self.data_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
+
         # Initialize db that will be written
         self.configure_cache()
 
@@ -68,17 +71,17 @@ class DatasetExtractor(DatasetWorker):
         self.cached_images = dict()
 
         # Where to save the downloaded images
-        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
-        logger.info(f"Images will be saved at `{self.image_folder}`.")
+        self.images_folder = self.data_folder / "images"
+        self.images_folder.mkdir(parents=True)
+        logger.info(f"Images will be saved at `{self.images_folder}`.")
 
     def configure_cache(self) -> None:
         """
         Create an SQLite database compatible with base-worker cache and initialize it.
""" self.use_cache = True - self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite" - # Remove previous execution result if present - self.cache_path.unlink(missing_ok=True) + self.cache_path: Path = self.data_folder / "db.sqlite" + logger.info(f"Cached database will be saved at `{self.data_folder}`.") init_cache_db(self.cache_path) @@ -242,7 +245,7 @@ class DatasetExtractor(DatasetWorker): # Download image logger.info("Downloading image") download_image(url=build_image_url(element)).save( - self.image_folder / f"{element.image.id}.jpg" + self.images_folder / f"{element.image.id}.jpg" ) # Insert image logger.info("Inserting image") @@ -304,15 +307,18 @@ class DatasetExtractor(DatasetWorker): self.insert_element(child, parent_id=element.id) def process_dataset(self, dataset: Dataset): + # Configure temporary storage for the dataset data (cache + images) + self.configure_storage() + # Iterate over given splits for split_name, elements in self.list_dataset_elements_per_split(dataset): casted_elements = list(map(_format_element, elements)) self.process_split(split_name, casted_elements) - # TAR + ZSTD Image folder and store as task artifact + # TAR + ZSTD the cache and the images folder, and store as task artifact zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd" logger.info(f"Compressing the images to {zstd_archive_path}") - create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path) + create_tar_zst_archive(source=self.data_folder, destination=zstd_archive_path) def main(): -- GitLab