From bbc80402621e3c085b2fa88cf86358cffe7b0a24 Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Wed, 8 Nov 2023 10:34:54 +0100
Subject: [PATCH] Nit

---
 tests/test_worker.py                      |  5 ++---
 worker_generic_training_dataset/worker.py | 19 +++++++++----------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/tests/test_worker.py b/tests/test_worker.py
index 05971b3..870c2dd 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -34,7 +34,7 @@ def test_process_split(tmp_path, downloaded_images):
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

     # The dataset should already be saved in database when we call `process_split`
-    cached_dataset = CachedDataset.create(
+    worker.cached_dataset = CachedDataset.create(
         id=uuid4(),
         name="My dataset",
         state="complete",
@@ -47,7 +47,6 @@ def test_process_split(tmp_path, downloaded_images):
             retrieve_element(first_page_id),
             retrieve_element(second_page_id),
         ],
-        cached_dataset,
     )

     # Should have created 19 elements in total
@@ -142,7 +141,7 @@ def test_process_split(tmp_path, downloaded_images):
     assert (
         CachedDatasetElement.select()
         .where(
-            CachedDatasetElement.dataset == cached_dataset,
+            CachedDatasetElement.dataset == worker.cached_dataset,
             CachedDatasetElement.set_name == "train",
         )
         .count()
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index c8123c7..5501704 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -233,7 +233,6 @@ class DatasetExtractor(DatasetWorker):
     def insert_element(
         self,
         element: Element,
-        dataset: CachedDataset,
         split_name: str,
         parent_id: Optional[UUID] = None,
     ) -> None:
@@ -297,18 +296,18 @@ class DatasetExtractor(DatasetWorker):
         self.insert_entities(transcriptions)

         # Link the element to the dataset
-        logger.info(f"Linking element {cached_element.id} to dataset ({dataset.id})")
+        logger.info(
+            f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
+        )
         with cache_database.atomic():
             cached_element: CachedDatasetElement = CachedDatasetElement.create(
                 id=uuid.uuid4(),
                 element=cached_element,
-                dataset=dataset,
+                dataset=self.cached_dataset,
                 set_name=split_name,
             )

-    def process_split(
-        self, split_name: str, elements: List[Element], dataset: CachedDataset
-    ) -> None:
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
         )
@@ -319,7 +318,7 @@ class DatasetExtractor(DatasetWorker):
             logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")

             # Insert page
-            self.insert_element(element, dataset, split_name)
+            self.insert_element(element, split_name)

             # List children
             children = list_children(element.id)
@@ -327,7 +326,7 @@ class DatasetExtractor(DatasetWorker):
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, dataset, split_name, parent_id=element.id)
+                self.insert_element(child, split_name, parent_id=element.id)

     def insert_dataset(self, dataset: Dataset) -> None:
         """
@@ -349,12 +348,12 @@ class DatasetExtractor(DatasetWorker):
         self.configure_storage()

         splits = self.list_dataset_elements_per_split(dataset)
-        cached_dataset = self.insert_dataset(dataset)
+        self.cached_dataset = self.insert_dataset(dataset)

         # Iterate over given splits
         for split_name, elements in splits:
             casted_elements = list(map(_format_element, elements))
-            self.process_split(split_name, casted_elements, cached_dataset)
+            self.process_split(split_name, casted_elements)

         # TAR + ZSTD the cache and the images folder, and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
--
GitLab