Commit b017a586 authored by Yoann Schneider

Merge branch 'save-dataset-db' into 'main'

Save Dataset and DatasetElements in cache database

Closes #8

See merge request workers/generic-training-dataset!12
parents f56a8e78 e6fb5f20
Pipeline #141240 passed
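The change below boils down to two new cache tables being filled: one CachedDataset row per dataset, and one CachedDatasetElement row linking each cached element to a split of that dataset. Here is a minimal sketch of that pattern, using the peewee models from arkindex_worker.cache that appear in the diff; init_cache_db, create_tables and the standalone CachedElement row are illustrative assumptions, not part of this merge request.

import json
from uuid import uuid4

from arkindex_worker.cache import (
    CachedDataset,
    CachedDatasetElement,
    CachedElement,
    create_tables,
    init_cache_db,
)

# Open a fresh cache database (init_cache_db/create_tables are assumed
# helpers from arkindex_worker.cache, not part of this merge request)
init_cache_db("db.sqlite")
create_tables()

# One row describing the dataset itself, with its splits serialized as JSON
dataset = CachedDataset.create(
    id=uuid4(),
    name="My dataset",
    state="complete",
    sets=json.dumps(["train", "val", "test"]),
)

# A minimal cached element, for illustration only
element = CachedElement.create(id=uuid4(), type="page")

# One row linking that element to the `train` split of the dataset
CachedDatasetElement.create(
    id=uuid4(),
    element=element,
    dataset=dataset,
    set_name="train",
)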
 # -*- coding: utf-8 -*-
+import json
 from argparse import Namespace
-from uuid import UUID
+from uuid import UUID, uuid4
 from arkindex_worker.cache import (
     CachedClassification,
+    CachedDataset,
+    CachedDatasetElement,
     CachedElement,
     CachedEntity,
     CachedImage,
@@ -30,6 +33,14 @@ def test_process_split(tmp_path, downloaded_images):
     first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
+    # The dataset should already be saved in database when we call `process_split`
+    worker.cached_dataset = CachedDataset.create(
+        id=uuid4(),
+        name="My dataset",
+        state="complete",
+        sets=json.dumps(["train", "val", "test"]),
+    )
     worker.process_split(
         "train",
         [
@@ -38,7 +49,7 @@ def test_process_split(tmp_path, downloaded_images):
         ],
     )
-    # Should have created 20 elements in total
+    # Should have created 19 elements in total
     assert CachedElement.select().count() == 19
     # Should have created two pages at root
@@ -125,6 +136,18 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
+    # Should have linked all the elements to the correct dataset & split
+    assert CachedDatasetElement.select().count() == 19
+    assert (
+        CachedDatasetElement.select()
+        .where(
+            CachedDatasetElement.dataset == worker.cached_dataset,
+            CachedDatasetElement.set_name == "train",
+        )
+        .count()
+        == 19
+    )
     # Full structure of the archive
     assert sorted(tmp_path.rglob("*")) == [
         tmp_path / "db.sqlite",
......
 # -*- coding: utf-8 -*-
+import json
 import logging
 import tempfile
+import uuid
 from argparse import Namespace
 from operator import itemgetter
 from pathlib import Path
@@ -13,6 +15,8 @@ from arkindex_export import Element, open_database
 from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
+    CachedDataset,
+    CachedDatasetElement,
     CachedElement,
     CachedEntity,
     CachedImage,
@@ -227,7 +231,10 @@ class DatasetExtractor(DatasetWorker):
         )
     def insert_element(
-        self, element: Element, parent_id: Optional[UUID] = None
+        self,
+        element: Element,
+        split_name: str,
+        parent_id: Optional[UUID] = None,
     ) -> None:
         """
         Insert the given element in the cache database.
@@ -238,6 +245,8 @@ class DatasetExtractor(DatasetWorker):
         - its transcriptions
         - its transcriptions' entities (both Entity and TranscriptionEntity)
+        The element will also be linked to the appropriate split in the current dataset.
         :param element: Element to insert.
         :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
         """
@@ -286,6 +295,18 @@ class DatasetExtractor(DatasetWorker):
         # Insert entities
         self.insert_entities(transcriptions)
+        # Link the element to the dataset
+        logger.info(
+            f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
+        )
+        with cache_database.atomic():
+            cached_element: CachedDatasetElement = CachedDatasetElement.create(
+                id=uuid.uuid4(),
+                element=cached_element,
+                dataset=self.cached_dataset,
+                set_name=split_name,
+            )
     def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
             f"Filling the cache with information from elements in the split {split_name}"
@@ -297,7 +318,7 @@ class DatasetExtractor(DatasetWorker):
             logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
             # Insert page
-            self.insert_element(element)
+            self.insert_element(element, split_name)
             # List children
             children = list_children(element.id)
@@ -305,12 +326,30 @@ class DatasetExtractor(DatasetWorker):
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, parent_id=element.id)
+                self.insert_element(child, split_name, parent_id=element.id)
+    def insert_dataset(self, dataset: Dataset) -> None:
+        """
+        Insert the given dataset in the cache database.
+        :param dataset: Dataset to insert.
+        """
+        logger.info(f"Inserting dataset ({dataset.id})")
+        with cache_database.atomic():
+            self.cached_dataset = CachedDataset.create(
+                id=dataset.id,
+                name=dataset.name,
+                state=dataset.state,
+                sets=json.dumps(dataset.sets),
+            )
     def process_dataset(self, dataset: Dataset):
         # Configure temporary storage for the dataset data (cache + images)
         self.configure_storage()
+        # Insert dataset in cache database
+        self.insert_dataset(dataset)
         # Iterate over given splits
         for split_name, elements in self.list_dataset_elements_per_split(dataset):
             casted_elements = list(map(_format_element, elements))
......
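Once process_dataset has run, the new split links can be checked directly on the cache database, mirroring the assertions in the test above. A short sketch, assuming the worker (or a test fixture) has already opened and filled the cache database; the split names come from the dataset used in the test.

from arkindex_worker.cache import CachedDataset, CachedDatasetElement

# The worker inserts exactly one dataset row, so take the first one
dataset = CachedDataset.select().first()

# Count the elements linked to each split of that dataset
for split in ("train", "val", "test"):
    count = (
        CachedDatasetElement.select()
        .where(
            CachedDatasetElement.dataset == dataset,
            CachedDatasetElement.set_name == split,
        )
        .count()
    )
    print(f"{split}: {count} elements")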