diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py index 8df49146508800ea6437df50132d164ea8054f38..7e3f97769d76eb18348fb28ec95213a0db94ed73 100644 --- a/arkindex_worker/worker/base.py +++ b/arkindex_worker/worker/base.py @@ -45,6 +45,12 @@ def _is_500_error(exc: Exception) -> bool: return 500 <= exc.status_code < 600 +class ModelNotFoundError(Exception): + """ + Exception raised when the path to the model is missing or invalid + """ + + class BaseWorker(object): """ Base class for Arkindex workers. @@ -94,6 +100,12 @@ class BaseWorker(object): action="store_true", default=False, ) + # To load models locally + self.parser.add_argument( + "--model-dir", + help=("The path to a local model's directory (development only)."), + type=Path, + ) # Call potential extra arguments self.add_arguments() @@ -110,6 +122,10 @@ class BaseWorker(object): self.work_dir = os.path.join(xdg_data_home, "arkindex") os.makedirs(self.work_dir, exist_ok=True) + # Store task ID. This is only available when running in production + # through a ponos agent + self.task_id = os.environ.get("PONOS_TASK") + self.worker_version_id = os.environ.get("WORKER_VERSION_ID") if not self.worker_version_id: logger.warning( @@ -245,12 +261,11 @@ class BaseWorker(object): """ Setup the necessary attribute when using the cache system of `Base-Worker`.
""" - task_id = os.environ.get("PONOS_TASK") paths = None if self.support_cache and self.args.database is not None: self.use_cache = True - elif self.support_cache and task_id: - task = self.request("RetrieveTaskFromAgent", id=task_id) + elif self.support_cache and self.task_id: + task = self.request("RetrieveTaskFromAgent", id=self.task_id) paths = retrieve_parents_cache_path( task["parents"], data_dir=os.environ.get("PONOS_DATA", "/data"), @@ -265,7 +280,9 @@ class BaseWorker(object): ), f"Database in {self.args.database} does not exist" self.cache_path = self.args.database else: - cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id) + cache_dir = os.path.join( + os.environ.get("PONOS_DATA", "/data"), self.task_id + ) assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" self.cache_path = os.path.join(cache_dir, "db.sqlite") @@ -335,6 +352,35 @@ class BaseWorker(object): # By default give raw secret payload return secret + def find_model_directory(self) -> Path: + """ + Find the local path to the model. This supports two modes: + - the worker runs in ponos, the model is available at `/data/current` + - the worker runs locally, the developer may specify it using either + - the `model_dir` configuration parameter + - the `--model-dir` CLI parameter + + :return: Path to the model on disk + """ + if self.task_id: + # When running in production with ponos, the agent + # downloads the model and places it in the current task work dir + return Path(self.work_dir) + else: + model_dir = self.config.get("model_dir", self.args.model_dir) + if model_dir is None: + raise ModelNotFoundError( + "No path to the model was provided. " + "Please provide model_dir either through configuration " + "or as CLI argument."
+ ) + model_dir = Path(model_dir) + if not model_dir.exists(): + raise ModelNotFoundError( + f"The path {model_dir} does not link to any directory" + ) + return model_dir + @retry( retry=retry_if_exception(_is_500_error), wait=wait_exponential(multiplier=2, min=3), diff --git a/tests/conftest.py b/tests/conftest.py index 703575452a5b784f597e4b6a72f83a2be290d866..447588ed928dca58d4abf58f762abbd2add3a0c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -237,9 +237,9 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api): """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK""" monkeypatch.setattr(sys, "argv", ["worker"]) + monkeypatch.setenv("PONOS_TASK", "my_task") worker = BaseWorker(support_cache=True) worker.setup_api_client() - monkeypatch.setenv("PONOS_TASK", "my_task") return worker diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py index 6239c81b27495f3892d6376caa1b10eebad55b2a..013cf121ec21475b5326401f8d836b679172f3c3 100644 --- a/tests/test_base_worker.py +++ b/tests/test_base_worker.py @@ -11,6 +11,7 @@ import pytest from arkindex.mock import MockApiClient from arkindex_worker import logger from arkindex_worker.worker import BaseWorker +from arkindex_worker.worker.base import ModelNotFoundError def test_init_default_local_share(monkeypatch): @@ -545,3 +546,55 @@ def test_load_local_secret(monkeypatch, tmpdir): # The remote api is checked first assert len(worker.api_client.history) == 1 assert worker.api_client.history[0].operation == "RetrieveSecret" + + +def test_find_model_directory_ponos(monkeypatch): + monkeypatch.setenv("PONOS_TASK", "my_task") + monkeypatch.setenv("PONOS_DATA", "/data") + worker = BaseWorker() + assert worker.find_model_directory() == Path("/data/current") + + +def test_find_model_directory_from_cli(monkeypatch): + monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", "models"]) + monkeypatch.setattr("pathlib.Path.exists", lambda x: True) + worker = BaseWorker() + 
worker.args = worker.parser.parse_args() + worker.config = {} + assert worker.find_model_directory() == Path("models") + + +def test_find_model_directory_from_config(monkeypatch): + monkeypatch.setattr(sys, "argv", ["worker"]) + monkeypatch.setattr("pathlib.Path.exists", lambda x: True) + worker = BaseWorker() + worker.args = worker.parser.parse_args() + worker.config = {"model_dir": "models"} + assert worker.find_model_directory() == Path("models") + + +@pytest.mark.parametrize( + "model_path, exists, error", + ( + [ + None, + True, + "No path to the model was provided. Please provide model_dir either through configuration or as CLI argument.", + ], + ["models", False, "The path models does not link to any directory"], + ), +) +def test_find_model_directory_not_found(monkeypatch, model_path, exists, error): + if model_path: + monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", model_path]) + else: + monkeypatch.setattr(sys, "argv", ["worker"]) + + monkeypatch.setattr("pathlib.Path.exists", lambda x: exists) + + worker = BaseWorker() + worker.args = worker.parser.parse_args() + worker.config = {"model_dir": model_path} + + with pytest.raises(ModelNotFoundError, match=error): + worker.find_model_directory()