Skip to content
Snippets Groups Projects

Save Dataset and DatasetElements in cache database

Merged Eva Bardou requested to merge save-dataset-db into main
All threads resolved!
3 files
+ 69
7
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 25
2
# -*- coding: utf-8 -*-
import json
from argparse import Namespace
from uuid import UUID
from uuid import UUID, uuid4
from arkindex_worker.cache import (
CachedClassification,
CachedDataset,
CachedDatasetElement,
CachedElement,
CachedEntity,
CachedImage,
@@ -30,6 +33,14 @@ def test_process_split(tmp_path, downloaded_images):
first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
# The dataset should already be saved in database when we call `process_split`
worker.cached_dataset = CachedDataset.create(
id=uuid4(),
name="My dataset",
state="complete",
sets=json.dumps(["train", "val", "test"]),
)
worker.process_split(
"train",
[
@@ -38,7 +49,7 @@ def test_process_split(tmp_path, downloaded_images):
],
)
# Should have created 20 elements in total
# Should have created 19 elements in total
assert CachedElement.select().count() == 19
# Should have created two pages at root
@@ -125,6 +136,18 @@ def test_process_split(tmp_path, downloaded_images):
assert tr_entity.confidence == 1.0
assert tr_entity.worker_run_id is None
# Should have linked all the elements to the correct dataset & split
assert CachedDatasetElement.select().count() == 19
assert (
CachedDatasetElement.select()
.where(
CachedDatasetElement.dataset == worker.cached_dataset,
CachedDatasetElement.set_name == "train",
)
.count()
== 19
)
# Full structure of the archive
assert sorted(tmp_path.rglob("*")) == [
tmp_path / "db.sqlite",
Loading