From cee333fb010318c9f4d056566f9f3d3358d8eb4c Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Wed, 8 Nov 2023 10:24:16 +0100
Subject: [PATCH] Save Dataset and DatasetElements in cache database

---
 requirements.txt                          |  4 +-
 tests/test_worker.py                      | 28 ++++++++++++-
 worker_generic_training_dataset/worker.py | 52 ++++++++++++++++++++---
 3 files changed, 74 insertions(+), 10 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2fc24b3..be8c01c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-arkindex-base-worker==0.3.5rc4
-arkindex-export==0.1.7
+arkindex-base-worker==0.3.5rc5
+arkindex-export==0.1.8
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 2b4651c..05971b3 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -1,10 +1,13 @@
 # -*- coding: utf-8 -*-
+import json
 from argparse import Namespace
-from uuid import UUID
+from uuid import UUID, uuid4
 
 from arkindex_worker.cache import (
     CachedClassification,
+    CachedDataset,
+    CachedDatasetElement,
     CachedElement,
     CachedEntity,
     CachedImage,
@@ -30,15 +33,24 @@ def test_process_split(tmp_path, downloaded_images):
     first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
 
+    # The dataset should already be saved in database when we call `process_split`
+    cached_dataset = CachedDataset.create(
+        id=uuid4(),
+        name="My dataset",
+        state="complete",
+        sets=json.dumps(["train", "val", "test"]),
+    )
+
     worker.process_split(
         "train",
         [
             retrieve_element(first_page_id),
             retrieve_element(second_page_id),
         ],
+        cached_dataset,
     )
 
-    # Should have created 20 elements in total
+    # Should have created 19 elements in total
     assert CachedElement.select().count() == 19
 
     # Should have created two pages at root
@@ -125,6 +137,18 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
 
+    # Should have linked all the elements to the correct dataset & split
+    assert CachedDatasetElement.select().count() == 19
+    assert (
+        CachedDatasetElement.select()
+        .where(
+            CachedDatasetElement.dataset == cached_dataset,
+            CachedDatasetElement.set_name == "train",
+        )
+        .count()
+        == 19
+    )
+
     # Full structure of the archive
     assert sorted(tmp_path.rglob("*")) == [
         tmp_path / "db.sqlite",
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 20c0934..c8123c7 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
+import json
 import logging
 import tempfile
+import uuid
 from argparse import Namespace
 from operator import itemgetter
 from pathlib import Path
@@ -13,6 +15,8 @@ from arkindex_export import Element, open_database
 from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
+    CachedDataset,
+    CachedDatasetElement,
     CachedElement,
     CachedEntity,
     CachedImage,
@@ -227,7 +231,11 @@ class DatasetExtractor(DatasetWorker):
         )
 
     def insert_element(
-        self, element: Element, parent_id: Optional[UUID] = None
+        self,
+        element: Element,
+        dataset: CachedDataset,
+        split_name: str,
+        parent_id: Optional[UUID] = None,
     ) -> None:
         """
         Insert the given element in the cache database.
@@ -238,6 +246,8 @@ class DatasetExtractor(DatasetWorker):
         - its transcriptions
         - its transcriptions' entities (both Entity and TranscriptionEntity)
 
+        The element will also be linked to the appropriate split in the current dataset.
+
         :param element: Element to insert.
         :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
         """
@@ -286,7 +296,19 @@ class DatasetExtractor(DatasetWorker):
             # Insert entities
             self.insert_entities(transcriptions)
 
-    def process_split(self, split_name: str, elements: List[Element]) -> None:
+        # Link the element to the dataset
+        logger.info(f"Linking element {cached_element.id} to dataset ({dataset.id})")
+        with cache_database.atomic():
+            cached_dataset_element: CachedDatasetElement = CachedDatasetElement.create(
+                id=uuid.uuid4(),
+                element=cached_element,
+                dataset=dataset,
+                set_name=split_name,
+            )
+
+    def process_split(
+        self, split_name: str, elements: List[Element], dataset: CachedDataset
+    ) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
         )
@@ -297,7 +319,7 @@ class DatasetExtractor(DatasetWorker):
             logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
 
             # Insert page
-            self.insert_element(element)
+            self.insert_element(element, dataset, split_name)
 
             # List children
             children = list_children(element.id)
@@ -305,16 +327,34 @@ class DatasetExtractor(DatasetWorker):
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, parent_id=element.id)
+                self.insert_element(child, dataset, split_name, parent_id=element.id)
+
+    def insert_dataset(self, dataset: Dataset) -> CachedDataset:
+        """
+        Insert the given dataset in the cache database.
+
+        :param dataset: Dataset to insert.
+        """
+        logger.info(f"Inserting dataset ({dataset.id})")
+        with cache_database.atomic():
+            return CachedDataset.create(
+                id=dataset.id,
+                name=dataset.name,
+                state=dataset.state,
+                sets=json.dumps(dataset.sets),
+            )
 
     def process_dataset(self, dataset: Dataset):
         # Configure temporary storage for the dataset data (cache + images)
         self.configure_storage()
 
+        splits = self.list_dataset_elements_per_split(dataset)
+        cached_dataset = self.insert_dataset(dataset)
+
         # Iterate over given splits
-        for split_name, elements in self.list_dataset_elements_per_split(dataset):
+        for split_name, elements in splits:
             casted_elements = list(map(_format_element, elements))
-            self.process_split(split_name, casted_elements)
+            self.process_split(split_name, casted_elements, cached_dataset)
 
         # TAR + ZSTD the cache and the images folder, and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
-- 
GitLab
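
Note (illustrative, not part of the patch above): the CachedDataset and CachedDatasetElement rows written by insert_dataset and insert_element are regular peewee models, so a consumer of the generated cache database can read the split assignments back with the same kind of query the test uses. The sketch below assumes the cache database is already open through arkindex_worker.cache (as it is inside the worker), and elements_in_split is a hypothetical helper, not an API of the worker or of arkindex-base-worker.

import json
from typing import Iterator

from arkindex_worker.cache import CachedDataset, CachedDatasetElement, CachedElement


def elements_in_split(dataset: CachedDataset, split_name: str) -> Iterator[CachedElement]:
    """Yield every cached element linked to `split_name` for the given dataset."""
    # `sets` is stored as a JSON-encoded list (see insert_dataset), so decode it
    # to validate the requested split name early.
    assert split_name in json.loads(dataset.sets), f"Unknown split: {split_name}"

    links = CachedDatasetElement.select().where(
        CachedDatasetElement.dataset == dataset,
        CachedDatasetElement.set_name == split_name,
    )
    for link in links:
        # `element` is the foreign key populated by insert_element
        yield link.element

Against the fixture built in test_process_split above, list(elements_in_split(cached_dataset, "train")) would contain the same 19 elements that the new assertions count.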