Commit bbc80402 authored by Eva Bardou

Nit

parent cee333fb
1 merge request: !12 Save Dataset and DatasetElements in cache database
Pipeline #140298 passed
@@ -34,7 +34,7 @@ def test_process_split(tmp_path, downloaded_images):
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

     # The dataset should already be saved in database when we call `process_split`
-    cached_dataset = CachedDataset.create(
+    worker.cached_dataset = CachedDataset.create(
         id=uuid4(),
         name="My dataset",
         state="complete",
@@ -47,7 +47,6 @@
             retrieve_element(first_page_id),
             retrieve_element(second_page_id),
         ],
-        cached_dataset,
     )

     # Should have created 19 elements in total
@@ -142,7 +141,7 @@ def test_process_split(tmp_path, downloaded_images):
     assert (
         CachedDatasetElement.select()
         .where(
-            CachedDatasetElement.dataset == cached_dataset,
+            CachedDatasetElement.dataset == worker.cached_dataset,
             CachedDatasetElement.set_name == "train",
         )
         .count()
......
@@ -233,7 +233,6 @@ class DatasetExtractor(DatasetWorker):
     def insert_element(
         self,
         element: Element,
-        dataset: CachedDataset,
         split_name: str,
         parent_id: Optional[UUID] = None,
     ) -> None:
@@ -297,18 +296,18 @@
         self.insert_entities(transcriptions)

         # Link the element to the dataset
-        logger.info(f"Linking element {cached_element.id} to dataset ({dataset.id})")
+        logger.info(
+            f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
+        )
         with cache_database.atomic():
             cached_element: CachedDatasetElement = CachedDatasetElement.create(
                 id=uuid.uuid4(),
                 element=cached_element,
-                dataset=dataset,
+                dataset=self.cached_dataset,
                 set_name=split_name,
             )

-    def process_split(
-        self, split_name: str, elements: List[Element], dataset: CachedDataset
-    ) -> None:
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
         )
@@ -319,7 +318,7 @@
             logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")

             # Insert page
-            self.insert_element(element, dataset, split_name)
+            self.insert_element(element, split_name)

             # List children
             children = list_children(element.id)
@@ -327,7 +326,7 @@
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, dataset, split_name, parent_id=element.id)
+                self.insert_element(child, split_name, parent_id=element.id)

     def insert_dataset(self, dataset: Dataset) -> None:
         """
@@ -349,12 +348,12 @@
         self.configure_storage()

         splits = self.list_dataset_elements_per_split(dataset)
-        cached_dataset = self.insert_dataset(dataset)
+        self.cached_dataset = self.insert_dataset(dataset)

         # Iterate over given splits
         for split_name, elements in splits:
             casted_elements = list(map(_format_element, elements))
-            self.process_split(split_name, casted_elements, cached_dataset)
+            self.process_split(split_name, casted_elements)

         # TAR + ZSTD the cache and the images folder, and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
......
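Taken together, these hunks replace parameter threading with instance state: `run` stores the cached dataset once as `self.cached_dataset`, and `insert_element` reads it there, so `process_split` no longer has to forward it. A minimal, runnable sketch of that pattern, using simplified stand-ins rather than the project's actual `DatasetWorker`, cache database, and API helpers:

from dataclasses import dataclass
from typing import Dict, List, Optional
from uuid import UUID, uuid4


# Simplified stand-ins for the project's models; not the real classes.
@dataclass
class CachedDataset:
    id: UUID
    name: str


@dataclass
class Element:
    id: UUID
    name: str


class DatasetExtractor:
    def __init__(self) -> None:
        # Set once by run(); read everywhere else instead of being
        # passed through process_split() and insert_element().
        self.cached_dataset: Optional[CachedDataset] = None

    def insert_dataset(self, name: str) -> CachedDataset:
        return CachedDataset(id=uuid4(), name=name)

    def insert_element(
        self, element: Element, split_name: str, parent_id: Optional[UUID] = None
    ) -> None:
        assert self.cached_dataset is not None, "run() must set the dataset first"
        print(
            f"Linking element {element.id} to dataset ({self.cached_dataset.id})"
            f" in split {split_name!r}"
        )

    def process_split(self, split_name: str, elements: List[Element]) -> None:
        # No `dataset` parameter: insert_element reads the instance attribute.
        for element in elements:
            self.insert_element(element, split_name)

    def run(self, dataset_name: str, splits: Dict[str, List[Element]]) -> None:
        self.cached_dataset = self.insert_dataset(dataset_name)
        for split_name, elements in splits.items():
            self.process_split(split_name, elements)


if __name__ == "__main__":
    worker = DatasetExtractor()
    worker.run("My dataset", {"train": [Element(id=uuid4(), name="page")]})

The test-side change mirrors this: assigning `worker.cached_dataset` directly lets `process_split` be exercised without the extra argument.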