diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py index 78cbb8d61795de05a60e8c305413401f2a256c2b..25c1def0338804997784a412221c3fde31a75b12 100644 --- a/arkindex_worker/worker/base.py +++ b/arkindex_worker/worker/base.py @@ -57,18 +57,6 @@ class BaseWorker(object): logger.info(f"Worker will use {self.work_dir} as working directory") self.use_cache = use_cache - if self.use_cache is True: - if os.environ.get("TASK_ID"): - cache_dir = f"/data/{os.environ.get('TASK_ID')}" - assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" - self.cache_path = os.path.join(cache_dir, "db.sqlite") - else: - self.cache_path = os.path.join(os.getcwd(), "db.sqlite") - - init_cache_db(self.cache_path) - create_tables() - else: - logger.debug("Cache is disabled") @property def is_read_only(self): @@ -85,6 +73,13 @@ class BaseWorker(object): help="Alternative configuration file when running without a Worker Version ID", type=open, ) + self.parser.add_argument( + "-d", + "--database", + help="Alternative SQLite database to use for worker caching", + type=str, + default=None, + ) self.parser.add_argument( "-v", "--verbose", @@ -138,9 +133,32 @@ class BaseWorker(object): # Load all required secrets self.secrets = {name: self.load_secret(name) for name in required_secrets} - # Merging parents caches (if there are any) in the current task local cache + if self.args.database is not None: + self.use_cache = True + + if self.use_cache is True: + if self.args.database is not None: + assert os.path.isfile( + self.args.database + ), f"Database in {self.args.database} does not exist" + self.cache_path = self.args.database + elif os.environ.get("TASK_ID"): + cache_dir = os.path.join( + os.environ.get("PONOS_DATA", "/data"), os.environ.get("TASK_ID") + ) + assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" + self.cache_path = os.path.join(cache_dir, "db.sqlite") + else: + self.cache_path = os.path.join(os.getcwd(), "db.sqlite") + + init_cache_db(self.cache_path) + create_tables() + else: + logger.debug("Cache is disabled") + + # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden task_id = os.environ.get("TASK_ID") - if self.use_cache and task_id is not None: + if self.use_cache and self.args.database is None and task_id is not None: task = self.request("RetrieveTaskFromAgent", id=task_id) merge_parents_cache( task["parents"], diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index d6bcd9d4f4ecce31f4c562f600501a2ef4dc2fb9..a903a74df12950a7050000c149315ab0a3c3c88a 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -85,7 +85,10 @@ def test_list_elements_element_arg(mocker, mock_elements_worker): mocker.patch( "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", return_value=Namespace( - element=["volumeid", "pageid"], verbose=False, elements_list=None + element=["volumeid", "pageid"], + verbose=False, + elements_list=None, + database=None, ), ) @@ -115,6 +118,7 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker): element=["anotherid", "againanotherid"], verbose=False, elements_list=open(path), + database=None, ), ) @@ -127,6 +131,27 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker): assert str(e.value) == "elements-list and element CLI args shouldn't be both set" +def test_database_arg(mocker, mock_elements_worker, tmp_path): + database_path = tmp_path / "my_database.sqlite" + database_path.touch() + + mocker.patch( + "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", + return_value=Namespace( + element=["volumeid", "pageid"], + verbose=False, + elements_list=None, + database=str(database_path), + ), + ) + + worker = ElementsWorker() + worker.configure() + + assert worker.use_cache is True + assert worker.cache_path == str(database_path) + + def test_load_corpus_classes_api_error(responses, mock_elements_worker): corpus_id = "12341234-1234-1234-1234-123412341234" responses.add( diff --git a/tests/test_merge.py b/tests/test_merge.py index a2061622f38b6d6ce6ddec4c15e3d59173eac695..ac3c6b140f8049e798bf372af0097b6b94190f51 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -152,6 +152,8 @@ def test_merge_from_worker( # Configure worker with a specific data directory monkeypatch.setenv("PONOS_DATA", str(tmpdir)) + # Create the task's output dir, so that it can create its own database + (tmpdir / "my_task").mkdir() mock_base_worker_with_cache.configure() # Then we have 2 elements and a transcription