Skip to content
Snippets Groups Projects
Commit a81a4857 authored by Eva Bardou's avatar Eva Bardou :frog:
Browse files

Bump arkindex-base-worker + Fix tests

parent b1bbda3e
No related branches found
No related tags found
1 merge request: !8 New DatasetExtractor using a DatasetWorker
......@@ -11,13 +11,11 @@ from arkindex_worker.cache import (
CachedTranscription,
CachedTranscriptionEntity,
)
from worker_generic_training_dataset.db import retrieve_element
from worker_generic_training_dataset.worker import DatasetExtractor
def test_process_split(tmp_path, downloaded_images):
# Parent is train folder
parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")
worker = DatasetExtractor()
# Parse some arguments
worker.args = Namespace(database=None)
......@@ -27,18 +25,22 @@ def test_process_split(tmp_path, downloaded_images):
# Where to save the downloaded images
worker.image_folder = tmp_path
worker.process_split("train", parent_id)
# Should have created 20 elements in total
assert CachedElement.select().count() == 20
first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
# Should have created two pages under root folder
assert (
CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
worker.process_split(
"train",
[
retrieve_element(first_page_id),
retrieve_element(second_page_id),
],
)
first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
# Should have created 19 elements in total
assert CachedElement.select().count() == 19
# Should have created two pages at root
assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2
# Should have created 8 text_lines under first page
assert (
......
......@@ -35,6 +35,7 @@ from worker_generic_training_dataset.db import (
list_classifications,
list_transcription_entities,
list_transcriptions,
retrieve_element,
)
from worker_generic_training_dataset.utils import build_image_url
......@@ -70,7 +71,7 @@ class DatasetWorker(BaseWorker, DatasetMixin):
"""
def format_element(element: Tuple[str, WorkerElement]) -> Element:
return Element.get(Element.id == element[1].id)
return retrieve_element(element[1].id)
def format_split(
split: Tuple[str, Iterator[Tuple[str, WorkerElement]]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment