diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py index 1ff5ba7631d7f648a6ce398034fc738665cb3502..67f01a1e331c0c764b01dfd76d5dee8373b0ec63 100644 --- a/arkindex_worker/worker/__init__.py +++ b/arkindex_worker/worker/__init__.py @@ -36,8 +36,8 @@ class ElementsWorker( EntityMixin, MetaDataMixin, ): - def __init__(self, description="Arkindex Elements Worker", use_cache=False): - super().__init__(description, use_cache) + def __init__(self, description="Arkindex Elements Worker", support_cache=False): + super().__init__(description, support_cache) # Add report concerning elements self.report = Reporter("unknown worker") diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py index d9ab11cea1204ad31fcd3dfd81daf9b125236f94..ff0905530aa21dfca4e520bec2dd466e13194f6b 100644 --- a/arkindex_worker/worker/base.py +++ b/arkindex_worker/worker/base.py @@ -33,7 +33,7 @@ def _is_500_error(exc): class BaseWorker(object): - def __init__(self, description="Arkindex Base Worker", use_cache=False): + def __init__(self, description="Arkindex Base Worker", support_cache=False): self.parser = argparse.ArgumentParser(description=description) # Setup workdir either in Ponos environment or on host's home @@ -56,7 +56,10 @@ class BaseWorker(object): logger.info(f"Worker will use {self.work_dir} as working directory") - self.use_cache = use_cache + self.support_cache = support_cache + # use_cache will be updated in configure() if the cache is supported and if there + # is at least one available sqlite database either given or in the parent tasks + self.use_cache = False @property def is_read_only(self): @@ -133,22 +136,24 @@ class BaseWorker(object): # Load all required secrets self.secrets = {name: self.load_secret(name) for name in required_secrets} - if self.args.database is not None: + task_id = os.environ.get("PONOS_TASK") + task = None + if self.support_cache and self.args.database is not None: self.use_cache = True + elif self.support_cache and task_id: + task = self.request("RetrieveTaskFromAgent", id=task_id) + self.use_cache = len(task["parents"]) > 0 - task_id = os.environ.get("PONOS_TASK") - if self.use_cache is True: + if self.use_cache: if self.args.database is not None: assert os.path.isfile( self.args.database ), f"Database in {self.args.database} does not exist" self.cache_path = self.args.database - elif task_id: + else: cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id) assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" self.cache_path = os.path.join(cache_dir, "db.sqlite") - else: - self.cache_path = os.path.join(os.getcwd(), "db.sqlite") init_cache_db(self.cache_path) create_tables() @@ -156,8 +161,7 @@ class BaseWorker(object): logger.debug("Cache is disabled") # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden - if self.use_cache and self.args.database is None and task_id is not None: - task = self.request("RetrieveTaskFromAgent", id=task_id) + if self.use_cache and self.args.database is None and task: merge_parents_cache( task["parents"], self.cache_path, diff --git a/tests/conftest.py b/tests/conftest.py index 840784cadf09ef204381dbd2eb50c2f3f8defea6..449572aa624911afed207f3f2277aa548c1ccff8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -177,7 +177,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api): """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK""" monkeypatch.setattr(sys, "argv", ["worker"]) - worker = BaseWorker(use_cache=True) + worker = BaseWorker(support_cache=True) monkeypatch.setenv("PONOS_TASK", "my_task") return worker @@ -188,7 +188,7 @@ def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api): monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111") - worker = ElementsWorker(use_cache=True) + worker = ElementsWorker(support_cache=True) worker.configure() return worker diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py index 40279ea0edab216f73cabd3245407d5fa6049995..acd5c5576f9c807208ee00dce496d971178acb48 100644 --- a/tests/test_base_worker.py +++ b/tests/test_base_worker.py @@ -29,7 +29,7 @@ def test_init_default_xdg_data_home(monkeypatch): def test_init_with_local_cache(monkeypatch): - worker = BaseWorker(use_cache=True) + worker = BaseWorker(support_cache=True) assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex") assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"