From cee333fb010318c9f4d056566f9f3d3358d8eb4c Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Wed, 8 Nov 2023 10:24:16 +0100
Subject: [PATCH] Save Dataset and DatasetElements in cache database

---
 requirements.txt                          |  4 +-
 tests/test_worker.py                      | 28 ++++++++++++-
 worker_generic_training_dataset/worker.py | 52 ++++++++++++++++++++---
 3 files changed, 74 insertions(+), 10 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2fc24b3..be8c01c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-arkindex-base-worker==0.3.5rc4
-arkindex-export==0.1.7
+arkindex-base-worker==0.3.5rc5
+arkindex-export==0.1.8
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 2b4651c..05971b3 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -1,10 +1,13 @@
 # -*- coding: utf-8 -*-
+import json
 from argparse import Namespace
-from uuid import UUID
+from uuid import UUID, uuid4
 
 from arkindex_worker.cache import (
     CachedClassification,
+    CachedDataset,
+    CachedDatasetElement,
     CachedElement,
     CachedEntity,
     CachedImage,
@@ -30,15 +33,24 @@ def test_process_split(tmp_path, downloaded_images):
     first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
 
+    # The dataset should already be saved in database when we call `process_split`
+    cached_dataset = CachedDataset.create(
+        id=uuid4(),
+        name="My dataset",
+        state="complete",
+        sets=json.dumps(["train", "val", "test"]),
+    )
+
     worker.process_split(
         "train",
         [
             retrieve_element(first_page_id),
             retrieve_element(second_page_id),
         ],
+        cached_dataset,
     )
 
-    # Should have created 20 elements in total
+    # Should have created 19 elements in total
     assert CachedElement.select().count() == 19
 
     # Should have created two pages at root
@@ -125,6 +137,18 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
 
+    # Should have linked all the elements to the correct dataset & split
+    assert CachedDatasetElement.select().count() == 19
+    assert (
+        CachedDatasetElement.select()
+        .where(
+            CachedDatasetElement.dataset == cached_dataset,
+            CachedDatasetElement.set_name == "train",
+        )
+        .count()
+        == 19
+    )
+
     # Full structure of the archive
     assert sorted(tmp_path.rglob("*")) == [
         tmp_path / "db.sqlite",
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 20c0934..c8123c7 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
+import json
 import logging
 import tempfile
+import uuid
 from argparse import Namespace
 from operator import itemgetter
 from pathlib import Path
@@ -13,6 +15,8 @@ from arkindex_export import Element, open_database
 from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
+    CachedDataset,
+    CachedDatasetElement,
     CachedElement,
     CachedEntity,
     CachedImage,
@@ -227,7 +231,11 @@ class DatasetExtractor(DatasetWorker):
         )
 
     def insert_element(
-        self, element: Element, parent_id: Optional[UUID] = None
+        self,
+        element: Element,
+        dataset: CachedDataset,
+        split_name: str,
+        parent_id: Optional[UUID] = None,
     ) -> None:
         """
         Insert the given element in the cache database.
@@ -238,6 +246,8 @@ class DatasetExtractor(DatasetWorker):
         - its transcriptions
         - its transcriptions' entities (both Entity and TranscriptionEntity)
 
+        The element will also be linked to the appropriate split in the current dataset.
+
         :param element: Element to insert.
         :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
         """
@@ -286,7 +296,19 @@ class DatasetExtractor(DatasetWorker):
             # Insert entities
             self.insert_entities(transcriptions)
 
-    def process_split(self, split_name: str, elements: List[Element]) -> None:
+        # Link the element to the dataset
+        logger.info(f"Linking element {cached_element.id} to dataset ({dataset.id})")
+        with cache_database.atomic():
+            cached_dataset_element: CachedDatasetElement = CachedDatasetElement.create(
+                id=uuid.uuid4(),
+                element=cached_element,
+                dataset=dataset,
+                set_name=split_name,
+            )
+
+    def process_split(
+        self, split_name: str, elements: List[Element], dataset: CachedDataset
+    ) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
         )
@@ -297,7 +319,7 @@ class DatasetExtractor(DatasetWorker):
             logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
 
             # Insert page
-            self.insert_element(element)
+            self.insert_element(element, dataset, split_name)
 
             # List children
             children = list_children(element.id)
@@ -305,16 +327,34 @@ class DatasetExtractor(DatasetWorker):
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, parent_id=element.id)
+                self.insert_element(child, dataset, split_name, parent_id=element.id)
+
+    def insert_dataset(self, dataset: Dataset) -> CachedDataset:
+        """
+        Insert the given dataset in the cache database.
+
+        :param dataset: Dataset to insert.
+        """
+        logger.info(f"Inserting dataset ({dataset.id})")
+        with cache_database.atomic():
+            return CachedDataset.create(
+                id=dataset.id,
+                name=dataset.name,
+                state=dataset.state,
+                sets=json.dumps(dataset.sets),
+            )
 
     def process_dataset(self, dataset: Dataset):
         # Configure temporary storage for the dataset data (cache + images)
         self.configure_storage()
 
+        splits = self.list_dataset_elements_per_split(dataset)
+        cached_dataset = self.insert_dataset(dataset)
+
         # Iterate over given splits
-        for split_name, elements in self.list_dataset_elements_per_split(dataset):
+        for split_name, elements in splits:
             casted_elements = list(map(_format_element, elements))
-            self.process_split(split_name, casted_elements)
+            self.process_split(split_name, casted_elements, cached_dataset)
 
         # TAR + ZSTD the cache and the images folder, and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
-- 
GitLab
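
Note (illustrative, not part of the patch above): the CachedDataset and CachedDatasetElement rows written by insert_dataset and insert_element are regular peewee models, so a consumer of the generated cache database can read the split assignments back with the same kind of query the test uses. The sketch below assumes the cache database is already open through arkindex_worker.cache (as it is inside the worker), and elements_in_split is a hypothetical helper, not an API of the worker or of arkindex-base-worker.

import json
from typing import Iterator

from arkindex_worker.cache import CachedDataset, CachedDatasetElement, CachedElement


def elements_in_split(dataset: CachedDataset, split_name: str) -> Iterator[CachedElement]:
    """Yield every cached element linked to `split_name` for the given dataset."""
    # `sets` is stored as a JSON-encoded list (see insert_dataset), so decode it
    # to validate the requested split name early.
    assert split_name in json.loads(dataset.sets), f"Unknown split: {split_name}"

    links = CachedDatasetElement.select().where(
        CachedDatasetElement.dataset == dataset,
        CachedDatasetElement.set_name == split_name,
    )
    for link in links:
        # `element` is the foreign key populated by insert_element
        yield link.element

Against the fixture built in test_process_split above, list(elements_in_split(cached_dataset, "train")) would contain the same 19 elements that the new assertions count.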