diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 92be49137058a56f421fba1d7bdb7dffe565d20b..2ada2ce97b404e2d3f3b15618a76243cd5aac928 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -186,13 +186,9 @@ def create_tables():
     db.create_tables(MODELS)
 
 
-def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
-    """
-    Merge all the potential parent task's databases into the existing local one
-    """
+def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
     assert isinstance(parent_ids, list)
     assert os.path.isdir(data_dir)
-    assert os.path.exists(current_database)
 
     # Handle possible chunk in parent task name
     # This is needed to support the init_elements databases
@@ -203,7 +199,7 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
         filenames.append(f"db_{chunk}.sqlite")
 
     # Find all the paths for these databases
-    paths = list(
+    return list(
         filter(
             lambda p: os.path.isfile(p),
             [
@@ -214,6 +210,13 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
         )
     )
 
+
+def merge_parents_cache(paths, current_database):
+    """
+    Merge all the potential parent task's databases into the existing local one
+    """
+    assert os.path.exists(current_database)
+
     if not paths:
         logger.info("No parents cache to use")
         return
diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py
index ff0905530aa21dfca4e520bec2dd466e13194f6b..2bc880e6e36402f5a8d692ef887d9c382b3be0c6 100644
--- a/arkindex_worker/worker/base.py
+++ b/arkindex_worker/worker/base.py
@@ -18,7 +18,12 @@ from tenacity import (
 
 from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
-from arkindex_worker.cache import create_tables, init_cache_db, merge_parents_cache
+from arkindex_worker.cache import (
+    create_tables,
+    init_cache_db,
+    merge_parents_cache,
+    retrieve_parents_cache_path,
+)
 
 
 def _is_500_error(exc):
@@ -137,12 +142,17 @@ class BaseWorker(object):
         self.secrets = {name: self.load_secret(name) for name in required_secrets}
 
         task_id = os.environ.get("PONOS_TASK")
-        task = None
+        paths = None
         if self.support_cache and self.args.database is not None:
             self.use_cache = True
         elif self.support_cache and task_id:
             task = self.request("RetrieveTaskFromAgent", id=task_id)
-            self.use_cache = len(task["parents"]) > 0
+            paths = retrieve_parents_cache_path(
+                task["parents"],
+                data_dir=os.environ.get("PONOS_DATA", "/data"),
+                chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
+            )
+            self.use_cache = len(paths) > 0
 
         if self.use_cache:
             if self.args.database is not None:
@@ -161,13 +171,8 @@ class BaseWorker(object):
             logger.debug("Cache is disabled")
 
         # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
-        if self.use_cache and self.args.database is None and task:
-            merge_parents_cache(
-                task["parents"],
-                self.cache_path,
-                data_dir=os.environ.get("PONOS_DATA", "/data"),
-                chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
-            )
+        if self.use_cache and self.args.database is None and paths is not None:
+            merge_parents_cache(paths, self.cache_path)
 
     def load_secret(self, name):
         """Load all secrets described in the worker configuration"""
diff --git a/tests/test_merge.py b/tests/test_merge.py
index 60b745b6ecd316e4df7b540025d69c5711751be1..f02e32daf9e1af9196f29f034e3cde9dca194e56 100644
--- a/tests/test_merge.py
+++ b/tests/test_merge.py
@@ -7,9 +7,12 @@ from arkindex_worker.cache import (
     MODELS,
     CachedClassification,
     CachedElement,
+    CachedEntity,
     CachedImage,
     CachedTranscription,
+    CachedTranscriptionEntity,
     merge_parents_cache,
+    retrieve_parents_cache_path,
 )
 
 
@@ -82,13 +85,14 @@ def test_merge_databases(
         assert CachedElement.select().count() == 0
         assert CachedTranscription.select().count() == 0
         assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
+
+    # Retrieve parents databases paths
+    paths = retrieve_parents_cache_path(parents, data_dir=tmpdir)
 
     # Merge all requested parents databases into our target
-    merge_parents_cache(
-        parents,
-        mock_databases["target"]["path"],
-        data_dir=tmpdir,
-    )
+    merge_parents_cache(paths, mock_databases["target"]["path"])
 
     # The target now should have the expected elements and transcriptions
     with mock_databases["target"]["db"].bind_ctx(MODELS):
@@ -96,6 +100,8 @@
         assert CachedElement.select().count() == len(expected_elements)
         assert CachedTranscription.select().count() == len(expected_transcriptions)
         assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
         assert [
             e.id for e in CachedElement.select().order_by("id")
         ] == expected_elements
@@ -111,22 +117,26 @@ def test_merge_chunk(mock_databases, tmpdir, monkeypatch):
     """
     # At first we have nothing in target
    with mock_databases["target"]["db"].bind_ctx(MODELS):
+        assert CachedImage.select().count() == 0
         assert CachedElement.select().count() == 0
         assert CachedTranscription.select().count() == 0
+        assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
 
     # Check filenames
     assert mock_databases["chunk_42"]["path"].basename == "db_42.sqlite"
     assert mock_databases["second"]["path"].basename == "db.sqlite"
 
-    merge_parents_cache(
+    paths = retrieve_parents_cache_path(
         [
             "chunk_42",
             "first",
         ],
-        mock_databases["target"]["path"],
         data_dir=tmpdir,
         chunk="42",
     )
+    merge_parents_cache(paths, mock_databases["target"]["path"])
 
     # The target should now have 3 elements and 0 transcription
     with mock_databases["target"]["db"].bind_ctx(MODELS):
@@ -134,6 +144,8 @@
         assert CachedElement.select().count() == 3
         assert CachedTranscription.select().count() == 0
         assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
         assert [e.id for e in CachedElement.select().order_by("id")] == [
             UUID("42424242-4242-4242-4242-424242424242"),
             UUID("12341234-1234-1234-1234-123412341234"),
@@ -154,11 +166,14 @@
         json={"parents": ["first", "second"]},
     )
 
-    # At first we have no data in our main database
-    assert CachedImage.select().count() == 0
-    assert CachedElement.select().count() == 0
-    assert CachedTranscription.select().count() == 0
-    assert CachedClassification.select().count() == 0
+    # At first we have nothing in target
+    with mock_databases["target"]["db"].bind_ctx(MODELS):
+        assert CachedImage.select().count() == 0
+        assert CachedElement.select().count() == 0
+        assert CachedTranscription.select().count() == 0
+        assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
 
     # Configure worker with a specific data directory
     monkeypatch.setenv("PONOS_DATA", str(tmpdir))
@@ -171,6 +186,8 @@ def test_merge_from_worker(
     assert CachedElement.select().count() == 3
     assert CachedTranscription.select().count() == 1
     assert CachedClassification.select().count() == 0
+    assert CachedEntity.select().count() == 0
+    assert CachedTranscriptionEntity.select().count() == 0
     assert [e.id for e in CachedElement.select().order_by("id")] == [
         UUID("12341234-1234-1234-1234-123412341234"),
         UUID("56785678-5678-5678-5678-567856785678"),
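
For reference, a minimal sketch (not part of the patch) of how the two helpers split by this change are meant to be chained, mirroring the new logic in BaseWorker.configure(). The parent task IDs and the target database path below are hypothetical placeholders; the environment variable names and function signatures come from the patch itself.

```python
import os

from arkindex_worker.cache import merge_parents_cache, retrieve_parents_cache_path

# Hypothetical parent task IDs; in the worker they come from RetrieveTaskFromAgent
parent_ids = ["first", "second"]

# Step 1: resolve which parent databases actually exist on disk
paths = retrieve_parents_cache_path(
    parent_ids,
    data_dir=os.environ.get("PONOS_DATA", "/data"),
    chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)

# Step 2: merge only when at least one parent database was found
# ("/data/current/db.sqlite" is a placeholder for the task's local cache path)
if paths:
    merge_parents_cache(paths, "/data/current/db.sqlite")
```

Separating path discovery from the merge lets the worker decide whether to enable its cache (use_cache = len(paths) > 0) before any database is opened or merged.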