Skip to content
Snippets Groups Projects
Commit f99050e9 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Find local model directory method

parent 74bda6d6
No related branches found
No related tags found
1 merge request!217Find local model directory method
Pipeline #79616 passed
...@@ -45,6 +45,12 @@ def _is_500_error(exc: Exception) -> bool: ...@@ -45,6 +45,12 @@ def _is_500_error(exc: Exception) -> bool:
return 500 <= exc.status_code < 600 return 500 <= exc.status_code < 600
class ModelNotFoundError(Exception):
"""
Exception raised when the path towards the model is invalid
"""
class BaseWorker(object): class BaseWorker(object):
""" """
Base class for Arkindex workers. Base class for Arkindex workers.
...@@ -94,6 +100,12 @@ class BaseWorker(object): ...@@ -94,6 +100,12 @@ class BaseWorker(object):
action="store_true", action="store_true",
default=False, default=False,
) )
# To load models locally
self.parser.add_argument(
"--model-dir",
help=("The path to a local model's directory (development only)."),
type=Path,
)
# Call potential extra arguments # Call potential extra arguments
self.add_arguments() self.add_arguments()
...@@ -110,6 +122,10 @@ class BaseWorker(object): ...@@ -110,6 +122,10 @@ class BaseWorker(object):
self.work_dir = os.path.join(xdg_data_home, "arkindex") self.work_dir = os.path.join(xdg_data_home, "arkindex")
os.makedirs(self.work_dir, exist_ok=True) os.makedirs(self.work_dir, exist_ok=True)
# Store task ID. This is only available when running in production
# through a ponos agent
self.task_id = os.environ.get("PONOS_TASK")
self.worker_version_id = os.environ.get("WORKER_VERSION_ID") self.worker_version_id = os.environ.get("WORKER_VERSION_ID")
if not self.worker_version_id: if not self.worker_version_id:
logger.warning( logger.warning(
...@@ -245,12 +261,11 @@ class BaseWorker(object): ...@@ -245,12 +261,11 @@ class BaseWorker(object):
""" """
Setup the necessary attribute when using the cache system of `Base-Worker`. Setup the necessary attribute when using the cache system of `Base-Worker`.
""" """
task_id = os.environ.get("PONOS_TASK")
paths = None paths = None
if self.support_cache and self.args.database is not None: if self.support_cache and self.args.database is not None:
self.use_cache = True self.use_cache = True
elif self.support_cache and task_id: elif self.support_cache and self.task_id:
task = self.request("RetrieveTaskFromAgent", id=task_id) task = self.request("RetrieveTaskFromAgent", id=self.task_id)
paths = retrieve_parents_cache_path( paths = retrieve_parents_cache_path(
task["parents"], task["parents"],
data_dir=os.environ.get("PONOS_DATA", "/data"), data_dir=os.environ.get("PONOS_DATA", "/data"),
...@@ -265,7 +280,9 @@ class BaseWorker(object): ...@@ -265,7 +280,9 @@ class BaseWorker(object):
), f"Database in {self.args.database} does not exist" ), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database self.cache_path = self.args.database
else: else:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id) cache_dir = os.path.join(
os.environ.get("PONOS_DATA", "/data"), self.task_id
)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite") self.cache_path = os.path.join(cache_dir, "db.sqlite")
...@@ -335,6 +352,35 @@ class BaseWorker(object): ...@@ -335,6 +352,35 @@ class BaseWorker(object):
# By default give raw secret payload # By default give raw secret payload
return secret return secret
def find_model_directory(self) -> Path:
"""
Find the local path to the model. This supports two modes:
- the worker runs in ponos, the model is available at `/data/current`
- the worker runs locally, the developer may specify it using either
- the `model_dir` configuration parameter
- the `--model-dir` CLI parameter
:return: Path to the model on disk
"""
if self.task_id:
# When running in production with ponos, the agent
# downloads the model and set it in the current task work dir
return Path(self.work_dir)
else:
model_dir = self.config.get("model_dir", self.args.model_dir)
if model_dir is None:
raise ModelNotFoundError(
"No path to the model was provided. "
"Please provide model_dir either through configuration "
"or as CLI argument."
)
model_dir = Path(model_dir)
if not model_dir.exists():
raise ModelNotFoundError(
f"The path {model_dir} does not link to any directory"
)
return model_dir
@retry( @retry(
retry=retry_if_exception(_is_500_error), retry=retry_if_exception(_is_500_error),
wait=wait_exponential(multiplier=2, min=3), wait=wait_exponential(multiplier=2, min=3),
......
...@@ -237,9 +237,9 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api): ...@@ -237,9 +237,9 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api):
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK""" """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("PONOS_TASK", "my_task")
worker = BaseWorker(support_cache=True) worker = BaseWorker(support_cache=True)
worker.setup_api_client() worker.setup_api_client()
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker return worker
......
...@@ -11,6 +11,7 @@ import pytest ...@@ -11,6 +11,7 @@ import pytest
from arkindex.mock import MockApiClient from arkindex.mock import MockApiClient
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.worker import BaseWorker from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.base import ModelNotFoundError
def test_init_default_local_share(monkeypatch): def test_init_default_local_share(monkeypatch):
...@@ -545,3 +546,55 @@ def test_load_local_secret(monkeypatch, tmpdir): ...@@ -545,3 +546,55 @@ def test_load_local_secret(monkeypatch, tmpdir):
# The remote api is checked first # The remote api is checked first
assert len(worker.api_client.history) == 1 assert len(worker.api_client.history) == 1
assert worker.api_client.history[0].operation == "RetrieveSecret" assert worker.api_client.history[0].operation == "RetrieveSecret"
def test_find_model_directory_ponos(monkeypatch):
monkeypatch.setenv("PONOS_TASK", "my_task")
monkeypatch.setenv("PONOS_DATA", "/data")
worker = BaseWorker()
assert worker.find_model_directory() == Path("/data/current")
def test_find_model_directory_from_cli(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", "models"])
monkeypatch.setattr("pathlib.Path.exists", lambda x: True)
worker = BaseWorker()
worker.args = worker.parser.parse_args()
worker.config = {}
assert worker.find_model_directory() == Path("models")
def test_find_model_directory_from_config(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setattr("pathlib.Path.exists", lambda x: True)
worker = BaseWorker()
worker.args = worker.parser.parse_args()
worker.config = {"model_dir": "models"}
assert worker.find_model_directory() == Path("models")
@pytest.mark.parametrize(
"model_path, exists, error",
(
[
None,
True,
"No path to the model was provided. Please provide model_dir either through configuration or as CLI argument.",
],
["models", False, "The path models does not link to any directory"],
),
)
def test_find_model_directory_not_found(monkeypatch, model_path, exists, error):
if model_path:
monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", model_path])
else:
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setattr("pathlib.Path.exists", lambda x: exists)
worker = BaseWorker()
worker.args = worker.parser.parse_args()
worker.config = {"model_dir": model_path}
with pytest.raises(ModelNotFoundError, match=error):
worker.find_model_directory()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment