From f99050e98b409fcfb9a0d6e9e668272f2479f011 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 20 Sep 2022 12:39:11 +0000
Subject: [PATCH] Find local model directory method

---
 arkindex_worker/worker/base.py | 54 +++++++++++++++++++++++++++++++---
 tests/conftest.py              |  2 +-
 tests/test_base_worker.py      | 53 +++++++++++++++++++++++++++++++++
 3 files changed, 104 insertions(+), 5 deletions(-)

diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py
index 8df49146..7e3f9776 100644
--- a/arkindex_worker/worker/base.py
+++ b/arkindex_worker/worker/base.py
@@ -45,6 +45,12 @@ def _is_500_error(exc: Exception) -> bool:
     return 500 <= exc.status_code < 600
 
 
+class ModelNotFoundError(Exception):
+    """
+    Exception raised when no model path is provided or the given path does not exist
+    """
+
+
 class BaseWorker(object):
     """
     Base class for Arkindex workers.
@@ -94,6 +100,12 @@ class BaseWorker(object):
             action="store_true",
             default=False,
         )
+        # To load models locally
+        self.parser.add_argument(
+            "--model-dir",
+            help=("The path to a local model's directory (development only)."),
+            type=Path,
+        )
 
         # Call potential extra arguments
         self.add_arguments()
@@ -110,6 +122,10 @@ class BaseWorker(object):
             self.work_dir = os.path.join(xdg_data_home, "arkindex")
             os.makedirs(self.work_dir, exist_ok=True)
 
+        # Store task ID. This is only available when running in production
+        # through a ponos agent
+        self.task_id = os.environ.get("PONOS_TASK")
+
         self.worker_version_id = os.environ.get("WORKER_VERSION_ID")
         if not self.worker_version_id:
             logger.warning(
@@ -245,12 +261,11 @@ class BaseWorker(object):
         """
         Setup the necessary attribute when using the cache system of `Base-Worker`.
         """
-        task_id = os.environ.get("PONOS_TASK")
         paths = None
         if self.support_cache and self.args.database is not None:
             self.use_cache = True
-        elif self.support_cache and task_id:
-            task = self.request("RetrieveTaskFromAgent", id=task_id)
+        elif self.support_cache and self.task_id:
+            task = self.request("RetrieveTaskFromAgent", id=self.task_id)
             paths = retrieve_parents_cache_path(
                 task["parents"],
                 data_dir=os.environ.get("PONOS_DATA", "/data"),
@@ -265,7 +280,9 @@ class BaseWorker(object):
                 ), f"Database in {self.args.database} does not exist"
                 self.cache_path = self.args.database
             else:
-                cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
+                cache_dir = os.path.join(
+                    os.environ.get("PONOS_DATA", "/data"), self.task_id
+                )
                 assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
                 self.cache_path = os.path.join(cache_dir, "db.sqlite")
 
@@ -335,6 +352,35 @@ class BaseWorker(object):
         # By default give raw secret payload
         return secret
 
+    def find_model_directory(self) -> Path:
+        """
+        Find the local path to the model. This supports two modes:
+        - the worker runs in ponos: the task work dir (e.g. `/data/current`)
+        - the worker runs locally: the developer provides it through the
+          `model_dir` configuration parameter or the `--model-dir` CLI parameter
+
+        :return: Path to the model on disk
+        :raises ModelNotFoundError: No path was provided or it does not exist
+        """
+        if self.task_id:
+            # In production, the ponos agent downloads the model into the task work dir
+            return Path(self.work_dir)
+        model_dir = self.config.get("model_dir")
+        if model_dir is None:  # also covers an explicit `model_dir: null` entry
+            model_dir = self.args.model_dir
+        if model_dir is None:
+            raise ModelNotFoundError(
+                "No path to the model was provided. "
+                "Please provide model_dir either through configuration "
+                "or as CLI argument."
+            )
+        model_dir = Path(model_dir)
+        if not model_dir.exists():
+            raise ModelNotFoundError(
+                f"The path {model_dir} does not link to any directory"
+            )
+        return model_dir
+
     @retry(
         retry=retry_if_exception(_is_500_error),
         wait=wait_exponential(multiplier=2, min=3),
diff --git a/tests/conftest.py b/tests/conftest.py
index 70357545..447588ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -237,9 +237,9 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api):
     """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
     monkeypatch.setattr(sys, "argv", ["worker"])
 
+    monkeypatch.setenv("PONOS_TASK", "my_task")
     worker = BaseWorker(support_cache=True)
     worker.setup_api_client()
-    monkeypatch.setenv("PONOS_TASK", "my_task")
     return worker
 
 
diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py
index 6239c81b..013cf121 100644
--- a/tests/test_base_worker.py
+++ b/tests/test_base_worker.py
@@ -11,6 +11,7 @@ import pytest
 from arkindex.mock import MockApiClient
 from arkindex_worker import logger
 from arkindex_worker.worker import BaseWorker
+from arkindex_worker.worker.base import ModelNotFoundError
 
 
 def test_init_default_local_share(monkeypatch):
@@ -545,3 +546,55 @@ def test_load_local_secret(monkeypatch, tmpdir):
     # The remote api is checked first
     assert len(worker.api_client.history) == 1
     assert worker.api_client.history[0].operation == "RetrieveSecret"
+
+
+def test_find_model_directory_ponos(monkeypatch):
+    """Inside a ponos task, the task working directory holds the model."""
+    for variable, value in (("PONOS_TASK", "my_task"), ("PONOS_DATA", "/data")):
+        monkeypatch.setenv(variable, value)
+    assert BaseWorker().find_model_directory() == Path("/data/current")
+
+
+def test_find_model_directory_from_cli(monkeypatch):
+    """Without configuration entry, the `--model-dir` CLI value is returned."""
+    monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", "models"])
+    monkeypatch.setattr("pathlib.Path.exists", lambda _path: True)
+    worker = BaseWorker()
+    worker.args, worker.config = worker.parser.parse_args(), {}
+    assert worker.find_model_directory() == Path("models")
+
+
+def test_find_model_directory_from_config(monkeypatch):
+    """Without CLI argument, the `model_dir` configuration entry is returned."""
+    monkeypatch.setattr(sys, "argv", ["worker"])
+    monkeypatch.setattr("pathlib.Path.exists", lambda _path: True)
+    worker = BaseWorker()
+    worker.args, worker.config = worker.parser.parse_args(), {"model_dir": "models"}
+    assert worker.find_model_directory() == Path("models")
+
+
+@pytest.mark.parametrize(
+    "model_path, exists, error",
+    (
+        [
+            None,
+            True,
+            "No path to the model was provided. Please provide model_dir either through configuration or as CLI argument.",
+        ],
+        ["models", False, "The path models does not link to any directory"],
+    ),
+)
+def test_find_model_directory_not_found(monkeypatch, model_path, exists, error):
+    """A missing or non-existent model path raises ModelNotFoundError."""
+    argv = ["worker", "--model-dir", model_path] if model_path else ["worker"]
+    monkeypatch.setattr(sys, "argv", argv)
+    monkeypatch.setattr("pathlib.Path.exists", lambda x: exists)
+
+    worker = BaseWorker()
+    worker.args = worker.parser.parse_args()
+    worker.config = {"model_dir": model_path}
+
+    # Compare the exact message: `match=` would read the "." as regex wildcards
+    with pytest.raises(ModelNotFoundError) as excinfo:
+        worker.find_model_directory()
+    assert str(excinfo.value) == error
-- 
GitLab