From bbc80402621e3c085b2fa88cf86358cffe7b0a24 Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Wed, 8 Nov 2023 10:34:54 +0100
Subject: [PATCH] Store the cached dataset as a worker attribute

Keep the CachedDataset created by insert_dataset on the worker as
self.cached_dataset, so insert_element and process_split no longer need
to receive it as a parameter. Update test_process_split accordingly.

---
 tests/test_worker.py                      |  5 ++---
 worker_generic_training_dataset/worker.py | 19 +++++++++----------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/tests/test_worker.py b/tests/test_worker.py
index 05971b3..870c2dd 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -34,7 +34,7 @@ def test_process_split(tmp_path, downloaded_images):
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
 
     # The dataset should already be saved in database when we call `process_split`
-    cached_dataset = CachedDataset.create(
+    worker.cached_dataset = CachedDataset.create(
         id=uuid4(),
         name="My dataset",
         state="complete",
@@ -47,7 +47,6 @@ def test_process_split(tmp_path, downloaded_images):
             retrieve_element(first_page_id),
             retrieve_element(second_page_id),
         ],
-        cached_dataset,
     )
 
     # Should have created 19 elements in total
@@ -142,7 +141,7 @@ def test_process_split(tmp_path, downloaded_images):
     assert (
         CachedDatasetElement.select()
         .where(
-            CachedDatasetElement.dataset == cached_dataset,
+            CachedDatasetElement.dataset == worker.cached_dataset,
             CachedDatasetElement.set_name == "train",
         )
         .count()
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index c8123c7..5501704 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -233,7 +233,6 @@ class DatasetExtractor(DatasetWorker):
     def insert_element(
         self,
         element: Element,
-        dataset: CachedDataset,
         split_name: str,
         parent_id: Optional[UUID] = None,
     ) -> None:
@@ -297,18 +296,18 @@ class DatasetExtractor(DatasetWorker):
         self.insert_entities(transcriptions)
 
         # Link the element to the dataset
-        logger.info(f"Linking element {cached_element.id} to dataset ({dataset.id})")
+        logger.info(
+            f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
+        )
         with cache_database.atomic():
             cached_element: CachedDatasetElement = CachedDatasetElement.create(
                 id=uuid.uuid4(),
                 element=cached_element,
-                dataset=dataset,
+                dataset=self.cached_dataset,
                 set_name=split_name,
             )
 
-    def process_split(
-        self, split_name: str, elements: List[Element], dataset: CachedDataset
-    ) -> None:
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
         )
@@ -319,7 +318,7 @@ class DatasetExtractor(DatasetWorker):
             logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
 
             # Insert page
-            self.insert_element(element, dataset, split_name)
+            self.insert_element(element, split_name)
 
             # List children
             children = list_children(element.id)
@@ -327,7 +326,7 @@ class DatasetExtractor(DatasetWorker):
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, dataset, split_name, parent_id=element.id)
+                self.insert_element(child, split_name, parent_id=element.id)
 
     def insert_dataset(self, dataset: Dataset) -> None:
         """
@@ -349,12 +348,12 @@ class DatasetExtractor(DatasetWorker):
         self.configure_storage()
 
         splits = self.list_dataset_elements_per_split(dataset)
-        cached_dataset = self.insert_dataset(dataset)
+        self.cached_dataset = self.insert_dataset(dataset)
 
         # Iterate over given splits
         for split_name, elements in splits:
             casted_elements = list(map(_format_element, elements))
-            self.process_split(split_name, casted_elements, cached_dataset)
+            self.process_split(split_name, casted_elements)
 
         # TAR + ZSTD the cache and the images folder, and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
-- 
GitLab