diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 92be49137058a56f421fba1d7bdb7dffe565d20b..2ada2ce97b404e2d3f3b15618a76243cd5aac928 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -186,13 +186,9 @@ def create_tables():
     db.create_tables(MODELS)
 
 
-def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
-    """
-    Merge all the potential parent task's databases into the existing local one
-    """
+def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
     assert isinstance(parent_ids, list)
     assert os.path.isdir(data_dir)
-    assert os.path.exists(current_database)
 
     # Handle possible chunk in parent task name
     # This is needed to support the init_elements databases
@@ -203,7 +199,7 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
         filenames.append(f"db_{chunk}.sqlite")
 
     # Find all the paths for these databases
-    paths = list(
+    return list(
         filter(
             lambda p: os.path.isfile(p),
             [
@@ -214,6 +210,13 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
         )
     )
 
+
+def merge_parents_cache(paths, current_database):
+    """
+    Merge all the potential parent task's databases into the existing local one
+    """
+    assert os.path.exists(current_database)
+
     if not paths:
         logger.info("No parents cache to use")
         return
diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py
index 1ff5ba7631d7f648a6ce398034fc738665cb3502..67f01a1e331c0c764b01dfd76d5dee8373b0ec63 100644
--- a/arkindex_worker/worker/__init__.py
+++ b/arkindex_worker/worker/__init__.py
@@ -36,8 +36,8 @@ class ElementsWorker(
     EntityMixin,
     MetaDataMixin,
 ):
-    def __init__(self, description="Arkindex Elements Worker", use_cache=False):
-        super().__init__(description, use_cache)
+    def __init__(self, description="Arkindex Elements Worker", support_cache=False):
+        super().__init__(description, support_cache)
 
         # Add report concerning elements
         self.report = Reporter("unknown worker")
diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py
index d9ab11cea1204ad31fcd3dfd81daf9b125236f94..2f684ff30d3985503397e32409c24eb7d8c2556e 100644
--- a/arkindex_worker/worker/base.py
+++ b/arkindex_worker/worker/base.py
@@ -18,7 +18,12 @@ from tenacity import (
 
 from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
-from arkindex_worker.cache import create_tables, init_cache_db, merge_parents_cache
+from arkindex_worker.cache import (
+    create_tables,
+    init_cache_db,
+    merge_parents_cache,
+    retrieve_parents_cache_path,
+)
 
 
 def _is_500_error(exc):
@@ -33,7 +38,7 @@ def _is_500_error(exc):
 
 
 class BaseWorker(object):
-    def __init__(self, description="Arkindex Base Worker", use_cache=False):
+    def __init__(self, description="Arkindex Base Worker", support_cache=False):
         self.parser = argparse.ArgumentParser(description=description)
 
         # Setup workdir either in Ponos environment or on host's home
@@ -56,7 +61,10 @@ class BaseWorker(object):
 
         logger.info(f"Worker will use {self.work_dir} as working directory")
 
-        self.use_cache = use_cache
+        self.support_cache = support_cache
+        # use_cache will be updated in configure() if the cache is supported and if there
+        # is at least one available sqlite database either given or in the parent tasks
+        self.use_cache = False
 
     @property
     def is_read_only(self):
@@ -133,38 +141,39 @@ class BaseWorker(object):
         # Load all required secrets
         self.secrets = {name: self.load_secret(name) for name in required_secrets}
 
-        if self.args.database is not None:
+        task_id = os.environ.get("PONOS_TASK")
+        paths = None
+        if self.support_cache and self.args.database is not None:
             self.use_cache = True
+        elif self.support_cache and task_id:
+            task = self.request("RetrieveTaskFromAgent", id=task_id)
+            paths = retrieve_parents_cache_path(
+                task["parents"],
+                data_dir=os.environ.get("PONOS_DATA", "/data"),
+                chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
+            )
+            self.use_cache = len(paths) > 0
 
-        task_id = os.environ.get("PONOS_TASK")
-        if self.use_cache is True:
+        if self.use_cache:
             if self.args.database is not None:
                 assert os.path.isfile(
                     self.args.database
                 ), f"Database in {self.args.database} does not exist"
                 self.cache_path = self.args.database
-            elif task_id:
+            else:
                 cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
                 assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
                 self.cache_path = os.path.join(cache_dir, "db.sqlite")
-            else:
-                self.cache_path = os.path.join(os.getcwd(), "db.sqlite")
 
             init_cache_db(self.cache_path)
             create_tables()
+
+            # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
+            if self.args.database is None and paths is not None:
+                merge_parents_cache(paths, self.cache_path)
         else:
             logger.debug("Cache is disabled")
 
-        # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
-        if self.use_cache and self.args.database is None and task_id is not None:
-            task = self.request("RetrieveTaskFromAgent", id=task_id)
-            merge_parents_cache(
-                task["parents"],
-                self.cache_path,
-                data_dir=os.environ.get("PONOS_DATA", "/data"),
-                chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
-            )
-
     def load_secret(self, name):
         """Load all secrets described in the worker configuration"""
         secret = None
diff --git a/tests/conftest.py b/tests/conftest.py
index 840784cadf09ef204381dbd2eb50c2f3f8defea6..83a55661a3ed8d60eef3890cb38698063f3a1525 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -91,14 +91,6 @@ def setup_api(responses, monkeypatch, cache_yaml):
     monkeypatch.setenv("ARKINDEX_API_TOKEN", "unittest1234")
 
 
-@pytest.fixture(autouse=True)
-def temp_working_directory(monkeypatch, tmp_path):
-    def _getcwd():
-        return str(tmp_path)
-
-    monkeypatch.setattr(os, "getcwd", _getcwd)
-
-
 @pytest.fixture(autouse=True)
 def give_worker_version_id_env_variable(monkeypatch):
     monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
@@ -177,18 +169,20 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
     """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
     monkeypatch.setattr(sys, "argv", ["worker"])
 
-    worker = BaseWorker(use_cache=True)
+    worker = BaseWorker(support_cache=True)
     monkeypatch.setenv("PONOS_TASK", "my_task")
     return worker
 
 
 @pytest.fixture
-def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
+def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api, tmp_path):
     """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
-    monkeypatch.setattr(sys, "argv", ["worker"])
+    cache_path = tmp_path / "db.sqlite"
+    cache_path.touch()
+    monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
     monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
 
-    worker = ElementsWorker(use_cache=True)
+    worker = ElementsWorker(support_cache=True)
     worker.configure()
     return worker
 
diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py
index 40279ea0edab216f73cabd3245407d5fa6049995..444d5131b80b925d4a98acd27dd123ad011e6583 100644
--- a/tests/test_base_worker.py
+++ b/tests/test_base_worker.py
@@ -29,11 +29,11 @@ def test_init_default_xdg_data_home(monkeypatch):
 
 
 def test_init_with_local_cache(monkeypatch):
-    worker = BaseWorker(use_cache=True)
+    worker = BaseWorker(support_cache=True)
 
     assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
     assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
-    assert worker.use_cache is True
+    assert worker.support_cache is True
 
 
 def test_init_var_ponos_data_given(monkeypatch):
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index 018bd9cbde155445c240c6fbef30b846eff56c7f..744009fbf99b432ce99840f9d6ca44c66a0f5cf4 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -145,7 +145,7 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
         ),
     )
 
-    worker = ElementsWorker()
+    worker = ElementsWorker(support_cache=True)
     worker.configure()
 
     assert worker.use_cache is True
diff --git a/tests/test_merge.py b/tests/test_merge.py
index 60b745b6ecd316e4df7b540025d69c5711751be1..f02e32daf9e1af9196f29f034e3cde9dca194e56 100644
--- a/tests/test_merge.py
+++ b/tests/test_merge.py
@@ -7,9 +7,12 @@ from arkindex_worker.cache import (
     MODELS,
     CachedClassification,
     CachedElement,
+    CachedEntity,
     CachedImage,
     CachedTranscription,
+    CachedTranscriptionEntity,
     merge_parents_cache,
+    retrieve_parents_cache_path,
 )
 
 
@@ -82,13 +85,14 @@ def test_merge_databases(
         assert CachedElement.select().count() == 0
         assert CachedTranscription.select().count() == 0
         assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
+
+    # Retrieve parents databases paths
+    paths = retrieve_parents_cache_path(parents, data_dir=tmpdir)
 
     # Merge all requested parents databases into our target
-    merge_parents_cache(
-        parents,
-        mock_databases["target"]["path"],
-        data_dir=tmpdir,
-    )
+    merge_parents_cache(paths, mock_databases["target"]["path"])
 
     # The target now should have the expected elements and transcriptions
     with mock_databases["target"]["db"].bind_ctx(MODELS):
@@ -96,6 +100,8 @@
         assert CachedImage.select().count() == 0
         assert CachedElement.select().count() == len(expected_elements)
         assert CachedTranscription.select().count() == len(expected_transcriptions)
         assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
         assert [
             e.id for e in CachedElement.select().order_by("id")
         ] == expected_elements
@@ -111,22 +117,26 @@ def test_merge_chunk(mock_databases, tmpdir, monkeypatch):
     """
     # At first we have nothing in target
    with mock_databases["target"]["db"].bind_ctx(MODELS):
+        assert CachedImage.select().count() == 0
         assert CachedElement.select().count() == 0
         assert CachedTranscription.select().count() == 0
+        assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
 
     # Check filenames
     assert mock_databases["chunk_42"]["path"].basename == "db_42.sqlite"
     assert mock_databases["second"]["path"].basename == "db.sqlite"
 
-    merge_parents_cache(
+    paths = retrieve_parents_cache_path(
         [
             "chunk_42",
             "first",
         ],
-        mock_databases["target"]["path"],
         data_dir=tmpdir,
         chunk="42",
     )
+    merge_parents_cache(paths, mock_databases["target"]["path"])
 
     # The target should now have 3 elements and 0 transcription
     with mock_databases["target"]["db"].bind_ctx(MODELS):
@@ -134,6 +144,8 @@
         assert CachedElement.select().count() == 3
         assert CachedTranscription.select().count() == 0
         assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
         assert [e.id for e in CachedElement.select().order_by("id")] == [
             UUID("42424242-4242-4242-4242-424242424242"),
             UUID("12341234-1234-1234-1234-123412341234"),
@@ -154,11 +166,14 @@ def test_merge_from_worker(
         json={"parents": ["first", "second"]},
     )
 
-    # At first we have no data in our main database
-    assert CachedImage.select().count() == 0
-    assert CachedElement.select().count() == 0
-    assert CachedTranscription.select().count() == 0
-    assert CachedClassification.select().count() == 0
+    # At first we have nothing in target
+    with mock_databases["target"]["db"].bind_ctx(MODELS):
+        assert CachedImage.select().count() == 0
+        assert CachedElement.select().count() == 0
+        assert CachedTranscription.select().count() == 0
+        assert CachedClassification.select().count() == 0
+        assert CachedEntity.select().count() == 0
+        assert CachedTranscriptionEntity.select().count() == 0
 
     # Configure worker with a specific data directory
     monkeypatch.setenv("PONOS_DATA", str(tmpdir))
@@ -171,6 +186,8 @@
     assert CachedElement.select().count() == 3
     assert CachedTranscription.select().count() == 1
     assert CachedClassification.select().count() == 0
+    assert CachedEntity.select().count() == 0
+    assert CachedTranscriptionEntity.select().count() == 0
     assert [e.id for e in CachedElement.select().order_by("id")] == [
         UUID("12341234-1234-1234-1234-123412341234"),
         UUID("56785678-5678-5678-5678-567856785678"),
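
Not part of the patch, just a minimal sketch of how the two helpers split out of the old merge_parents_cache are meant to be chained, mirroring the new BaseWorker.configure() flow above. The parent task ids and the target database path are placeholders; merge_parents_cache() still expects the target database file to already exist.

```python
import os

from arkindex_worker.cache import merge_parents_cache, retrieve_parents_cache_path

# Resolve the parent tasks' SQLite databases that actually exist on disk.
# The task ids here are placeholders; configure() gets the real ones from
# the RetrieveTaskFromAgent endpoint.
paths = retrieve_parents_cache_path(
    ["parent_task_1", "parent_task_2"],
    data_dir=os.environ.get("PONOS_DATA", "/data"),
    chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)

# Merge them into the current task's local cache (placeholder path below);
# the target database must already exist since merge_parents_cache() asserts on it.
if paths:
    merge_parents_cache(paths, "/data/my_task/db.sqlite")
```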