diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 2391c45ba1c4a10353fa978f8e71f6b8af178485..15f1514f143a2d0b76e87da9e77636bdec7039dd 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -102,13 +102,15 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No filenames.append(f"db_{chunk}.sqlite") # Find all the paths for these databases - paths = filter( - lambda p: os.path.isfile(p), - [ - os.path.join(data_dir, parent, name) - for parent in parent_ids - for name in filenames - ], + paths = list( + filter( + lambda p: os.path.isfile(p), + [ + os.path.join(data_dir, parent, name) + for parent in parent_ids + for name in filenames + ], + ) ) if not paths: @@ -133,5 +135,3 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No for statement in statements: cursor.execute(statement) connection.commit() - - # TODO: maybe reopen peewee connection ? diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index c76c01186132202944bb0173f86d15b8599ff671..f1aaa5ae4768f9ac709dd7fd538ba8c3bb26f7cc 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -160,7 +160,7 @@ class BaseWorker(object): merge_parents_cache( task["parents"], self.cache_path, - data_dir=os.environ.get("PONOS_DATA_DIR"), + data_dir=os.environ.get("PONOS_DATA_DIR", "/data"), chunk=os.environ.get("ARKINDEX_TASK_CHUNK"), ) diff --git a/tests/test_merge.py b/tests/test_merge.py index 5b405989283a1a874c8139ab91d40c1f051a3a71..f2f4203a180aaf731d076b93a383aee0d7614242 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -1,7 +1,12 @@ # -*- coding: utf-8 -*- from uuid import UUID -from arkindex_worker.cache import MODELS, CachedElement, merge_parents_cache +from arkindex_worker.cache import ( + MODELS, + CachedElement, + CachedTranscription, + merge_parents_cache, +) def test_merge_no_parent(mock_databases, tmpdir): @@ -271,4 +276,65 @@ def test_merge_chunk(mock_databases, tmpdir, 
monkeypatch):
     ]
 
 
-# TODO: add a unit test using base worker
+def test_merge_from_worker(
+    responses, mock_base_worker_with_cache, mock_databases, tmpdir, monkeypatch
+):
+    """
+    High-level merge of the parent caches, triggered through the base worker
+    """
+    responses.add(
+        responses.GET,
+        "http://testserver/ponos/v1/task/my_task/from-agent/",
+        status=200,
+        json={"parents": ["first", "second"]},
+    )
+
+    # Add two elements in the first parent database
+    with mock_databases["first"]["db"].bind_ctx(MODELS):
+        CachedElement.create(
+            id=UUID("12341234-1234-1234-1234-123412341234"),
+            type="page",
+            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+        CachedElement.create(
+            id=UUID("56785678-5678-5678-5678-567856785678"),
+            type="page",
+            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+
+    # Add an already-known element and a transcription in second parent database
+    with mock_databases["second"]["db"].bind_ctx(MODELS):
+        CachedElement.create(
+            id=UUID("12341234-1234-1234-1234-123412341234"),
+            type="page",
+            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+        CachedTranscription.create(
+            id=UUID("11111111-1111-1111-1111-111111111111"),
+            element_id=UUID("12341234-1234-1234-1234-123412341234"),
+            text="Hello!",
+            confidence=0.42,
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+
+    # At first we have no data in our main database
+    assert CachedElement.select().count() == 0
+    assert CachedTranscription.select().count() == 0
+
+    # Configure worker with a specific data directory
+    monkeypatch.setenv("PONOS_DATA_DIR", str(tmpdir))
+    mock_base_worker_with_cache.configure()
+
+    # Then we have 2 elements (sorted by their id field) and a transcription
+    assert CachedElement.select().count() == 2
+    assert CachedTranscription.select().count() == 1
+    assert [e.id for e in CachedElement.select().order_by(CachedElement.id)] == [
+        UUID("12341234-1234-1234-1234-123412341234"),
+        UUID("56785678-5678-5678-5678-567856785678"),
+    ]
+    assert [t.id for t in CachedTranscription.select()] == [
+        UUID("11111111-1111-1111-1111-111111111111"),
+    ]