From 316f067b33b5bd8c18bd37682a62afe6ff1879ab Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Fri, 20 Oct 2023 10:51:22 +0200
Subject: [PATCH] Save the cache in the archive too
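
The temporary data folder now holds the SQLite cache database (db.sqlite) next to the downloaded images, which move under an images/ subfolder, and the whole folder is compressed into the dataset artifact instead of the images alone.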

---
 tests/test_worker.py                      | 17 +++++++++++------
 worker_generic_training_dataset/worker.py | 22 ++++++++++++++--------
 2 files changed, 25 insertions(+), 14 deletions(-)
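
The artifact is a tar archive compressed with zstd that now bundles db.sqlite together with the images/ folder. Below is a minimal sketch (not part of the patch) of how a consumer might unpack it and inspect the cache, assuming the third-party zstandard package is installed; archive_path and output_dir are placeholder names.

    import sqlite3
    import tarfile
    from pathlib import Path

    import zstandard  # assumption: third-party zstandard package is available

    archive_path = Path("dataset.zstd")  # placeholder for the downloaded artifact
    output_dir = Path("extracted")
    output_dir.mkdir(exist_ok=True)

    # Stream-decompress the zstd layer, then untar the payload
    with archive_path.open("rb") as fh:
        reader = zstandard.ZstdDecompressor().stream_reader(fh)
        with tarfile.open(fileobj=reader, mode="r|") as tar:
            tar.extractall(output_dir)

    # The archive now carries the cache database next to the images/ folder
    db_path = next(output_dir.rglob("db.sqlite"))
    with sqlite3.connect(db_path) as connection:
        tables = connection.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table'"
        ).fetchall()
    print(tables)
    print(sorted(output_dir.rglob("*.jpg")))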

diff --git a/tests/test_worker.py b/tests/test_worker.py
index cd01c8e..c64e177 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -19,11 +19,13 @@ def test_process_split(tmp_path, downloaded_images):
     worker = DatasetExtractor()
     # Parse some arguments
     worker.args = Namespace(database=None)
+    worker.data_folder = tmp_path
     worker.configure_cache()
     worker.cached_images = dict()
 
     # Where to save the downloaded images
-    worker.image_folder = tmp_path
+    worker.images_folder = tmp_path / "images"
+    worker.images_folder.mkdir(parents=True)
 
     first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
@@ -80,11 +82,6 @@ def test_process_split(tmp_path, downloaded_images):
             == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
         )
 
-    assert sorted(tmp_path.rglob("*")) == [
-        tmp_path / f"{first_image_id}.jpg",
-        tmp_path / f"{second_image_id}.jpg",
-    ]
-
     # Should have created 17 transcriptions
     assert CachedTranscription.select().count() == 17
     # Check transcription of first line on first page
@@ -127,3 +124,11 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.length == 23
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
+
+    # Full structure of the archive
+    assert sorted(tmp_path.rglob("*")) == [
+        tmp_path / "db.sqlite",
+        tmp_path / "images",
+        tmp_path / "images" / f"{first_image_id}.jpg",
+        tmp_path / "images" / f"{second_image_id}.jpg",
+    ]
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 93a4c75..3e54552 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -61,6 +61,9 @@ class DatasetExtractor(DatasetWorker):
         # Download corpus
         self.download_latest_export()
 
+    def configure_storage(self) -> None:
+        self.data_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
+
         # Initialize db that will be written
         self.configure_cache()
 
@@ -68,17 +71,17 @@ class DatasetExtractor(DatasetWorker):
         self.cached_images = dict()
 
         # Where to save the downloaded images
-        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
-        logger.info(f"Images will be saved at `{self.image_folder}`.")
+        self.images_folder = self.data_folder / "images"
+        self.images_folder.mkdir(parents=True)
+        logger.info(f"Images will be saved at `{self.images_folder}`.")
 
     def configure_cache(self) -> None:
         """
         Create an SQLite database compatible with base-worker cache and initialize it.
         """
         self.use_cache = True
-        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
-        # Remove previous execution result if present
-        self.cache_path.unlink(missing_ok=True)
+        self.cache_path: Path = self.data_folder / "db.sqlite"
+        logger.info(f"Cached database will be saved at `{self.cache_path}`.")
 
         init_cache_db(self.cache_path)
 
@@ -242,7 +245,7 @@ class DatasetExtractor(DatasetWorker):
             # Download image
             logger.info("Downloading image")
             download_image(url=build_image_url(element)).save(
-                self.image_folder / f"{element.image.id}.jpg"
+                self.images_folder / f"{element.image.id}.jpg"
             )
             # Insert image
             logger.info("Inserting image")
@@ -304,15 +307,18 @@ class DatasetExtractor(DatasetWorker):
                 self.insert_element(child, parent_id=element.id)
 
     def process_dataset(self, dataset: Dataset):
+        # Configure temporary storage for the dataset data (cache + images)
+        self.configure_storage()
+
         # Iterate over given splits
         for split_name, elements in self.list_dataset_elements_per_split(dataset):
             casted_elements = list(map(_format_element, elements))
             self.process_split(split_name, casted_elements)
 
-        # TAR + ZSTD Image folder and store as task artifact
+        # TAR + ZSTD the cache and the images folder, and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
-        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
+        create_tar_zst_archive(source=self.data_folder, destination=zstd_archive_path)
 
 
 def main():
-- 
GitLab