diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py index 61a9f7afa83939295df0af64b066e37bb62fb2ca..2ea54f18fe3529a9972aefefa9ac0b16d30f8282 100644 --- a/arkindex_worker/worker/__init__.py +++ b/arkindex_worker/worker/__init__.py @@ -235,7 +235,7 @@ class ElementsWorker( self.report.error(element_id, e) # Save report as local artifact - self.report.save(os.path.join(self.work_dir, "ml_report.json")) + self.report.save(self.work_dir / "ml_report.json") if failed: logger.error( diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py index edb818a55fb2369b9e04ec7a955f58ecce164e6e..15d3863debbc08e59d692218f66b122c4baa9dac 100644 --- a/arkindex_worker/worker/base.py +++ b/arkindex_worker/worker/base.py @@ -112,15 +112,15 @@ class BaseWorker(object): # Setup workdir either in Ponos environment or on host's home if os.environ.get("PONOS_DATA"): - self.work_dir = os.path.join(os.environ["PONOS_DATA"], "current") + self.work_dir = Path(os.environ["PONOS_DATA"], "current") else: # We use the official XDG convention to store file for developers # https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html xdg_data_home = os.environ.get( "XDG_DATA_HOME", os.path.expanduser("~/.local/share") ) - self.work_dir = os.path.join(xdg_data_home, "arkindex") - os.makedirs(self.work_dir, exist_ok=True) + self.work_dir = Path(xdg_data_home, "arkindex") + self.work_dir.mkdir(parents=True, exist_ok=True) # Store task ID. This is only available when running in production # through a ponos agent @@ -377,7 +377,7 @@ class BaseWorker(object): if self.task_id: # When running in production with ponos, the agent # downloads the model and set it in the current task work dir - return Path(self.work_dir) + return self.work_dir else: model_dir = self.config.get("model_dir", self.args.model_dir) if model_dir is None: diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py index 88edf67fe89965fc96d5fabb9122e2a900909290..a85e9081c35aad04273b8d3975d558fa35d9b8a8 100644 --- a/tests/test_base_worker.py +++ b/tests/test_base_worker.py @@ -17,7 +17,7 @@ from arkindex_worker.worker.base import ModelNotFoundError def test_init_default_local_share(monkeypatch): worker = BaseWorker() - assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex") + assert str(worker.work_dir) == os.path.expanduser("~/.local/share/arkindex") def test_init_default_xdg_data_home(monkeypatch): @@ -25,13 +25,13 @@ def test_init_default_xdg_data_home(monkeypatch): monkeypatch.setenv("XDG_DATA_HOME", path) worker = BaseWorker() - assert worker.work_dir == f"{path}/arkindex" + assert str(worker.work_dir) == f"{path}/arkindex" def test_init_with_local_cache(monkeypatch): worker = BaseWorker(support_cache=True) - assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex") + assert str(worker.work_dir) == os.path.expanduser("~/.local/share/arkindex") assert worker.support_cache is True @@ -40,7 +40,7 @@ def test_init_var_ponos_data_given(monkeypatch): monkeypatch.setenv("PONOS_DATA", path) worker = BaseWorker() - assert worker.work_dir == f"{path}/current" + assert str(worker.work_dir) == f"{path}/current" def test_init_var_worker_run_id_missing(monkeypatch):