Skip to content
Snippets Groups Projects
Commit bdde8927 authored by Eva Bardou's avatar Eva Bardou :frog: Committed by Yoann Schneider
Browse files

Rename `--model-dir` to `--extras-dir`

parent 5df83b6d
No related branches found
No related tags found
1 merge request!419Rename `--model-dir` to `--extras-dir`
Pipeline #138171 passed
...@@ -46,9 +46,9 @@ def _is_500_error(exc: Exception) -> bool: ...@@ -46,9 +46,9 @@ def _is_500_error(exc: Exception) -> bool:
return 500 <= exc.status_code < 600 return 500 <= exc.status_code < 600
class ModelNotFoundError(Exception): class ExtrasDirNotFoundError(Exception):
""" """
Exception raised when the path towards the model is invalid Exception raised when the path towards the extras directory is invalid
""" """
...@@ -101,10 +101,12 @@ class BaseWorker(object): ...@@ -101,10 +101,12 @@ class BaseWorker(object):
action="store_true", action="store_true",
default=False, default=False,
) )
# To load models locally # To load models, datasets, etc, locally
self.parser.add_argument( self.parser.add_argument(
"--model-dir", "--extras-dir",
help=("The path to a local model's directory (development only)."), help=(
"The path to a local directory to store extra files like models, datasets, etc (development only)."
),
type=Path, type=Path,
) )
...@@ -371,39 +373,39 @@ class BaseWorker(object): ...@@ -371,39 +373,39 @@ class BaseWorker(object):
# By default give raw secret payload # By default give raw secret payload
return secret return secret
def find_model_directory(self) -> Path: def find_extras_directory(self) -> Path:
""" """
Find the local path to the model. This supports two modes: Find the local path to the directory to store extra files. This supports two modes:
- the worker runs in ponos, the model is available at `/data/extra_files` (first try) or `/data/current`. - the worker runs in ponos, the directory is available at `/data/extra_files` (first try) or `/data/current`.
- the worker runs locally, the developer may specify it using either - the worker runs locally, the developer may specify it using either
- the `model_dir` configuration parameter - the `extras_dir` configuration parameter
- the `--model-dir` CLI parameter - the `--extras-dir` CLI parameter
:return: Path to the model on disk :return: Path to the directory for extra files on disk
""" """
if self.task_id: if self.task_id:
# When running in production with ponos, the agent # When running in production with ponos, the agent
# downloads the model and set it either in # downloads the model and set it either in
# - `/data/extra_files` # - `/data/extra_files`
# - the current task work dir # - the current task work dir
extra_dir = self.task_data_dir / "extra_files" extras_dir = self.task_data_dir / "extra_files"
if extra_dir.exists(): if extras_dir.exists():
return extra_dir return extras_dir
return self.work_dir return self.work_dir
else: else:
model_dir = self.config.get("model_dir", self.args.model_dir) extras_dir = self.config.get("extras_dir", self.args.extras_dir)
if model_dir is None: if extras_dir is None:
raise ModelNotFoundError( raise ExtrasDirNotFoundError(
"No path to the model was provided. " "No path to the directory for extra files was provided. "
"Please provide model_dir either through configuration " "Please provide extras_dir either through configuration "
"or as CLI argument." "or as CLI argument."
) )
model_dir = Path(model_dir) extras_dir = Path(extras_dir)
if not model_dir.exists(): if not extras_dir.exists():
raise ModelNotFoundError( raise ExtrasDirNotFoundError(
f"The path {model_dir} does not link to any directory" f"The path {extras_dir} does not link to any directory"
) )
return model_dir return extras_dir
def find_parents_file_paths(self, filename: Path) -> List[Path]: def find_parents_file_paths(self, filename: Path) -> List[Path]:
""" """
......
...@@ -11,7 +11,7 @@ import pytest ...@@ -11,7 +11,7 @@ import pytest
from arkindex.mock import MockApiClient from arkindex.mock import MockApiClient
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.worker import BaseWorker, ElementsWorker from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.base import ModelNotFoundError from arkindex_worker.worker.base import ExtrasDirNotFoundError
from tests.conftest import FIXTURES_DIR from tests.conftest import FIXTURES_DIR
...@@ -602,55 +602,55 @@ def test_load_local_secret(monkeypatch, tmp_path): ...@@ -602,55 +602,55 @@ def test_load_local_secret(monkeypatch, tmp_path):
assert worker.api_client.history[0].operation == "RetrieveSecret" assert worker.api_client.history[0].operation == "RetrieveSecret"
def test_find_model_directory_ponos_no_extra_files(monkeypatch): def test_find_extras_directory_ponos_no_extra_files(monkeypatch):
monkeypatch.setenv("PONOS_TASK", "my_task") monkeypatch.setenv("PONOS_TASK", "my_task")
monkeypatch.setenv("PONOS_DATA", "/data") monkeypatch.setenv("PONOS_DATA", "/data")
worker = BaseWorker() worker = BaseWorker()
assert worker.find_model_directory() == Path("/data/current") assert worker.find_extras_directory() == Path("/data/current")
def test_find_model_directory_ponos_with_extra_files(monkeypatch): def test_find_extras_directory_ponos_with_extra_files(monkeypatch):
monkeypatch.setenv("PONOS_TASK", "my_task") monkeypatch.setenv("PONOS_TASK", "my_task")
monkeypatch.setenv("PONOS_DATA", "/data") monkeypatch.setenv("PONOS_DATA", "/data")
# Make the `extra_files` folder exist # Make the `extra_files` folder exist
monkeypatch.setattr("pathlib.Path.exists", lambda x: True) monkeypatch.setattr("pathlib.Path.exists", lambda x: True)
worker = BaseWorker() worker = BaseWorker()
assert worker.find_model_directory() == Path("/data/extra_files") assert worker.find_extras_directory() == Path("/data/extra_files")
def test_find_model_directory_from_cli(monkeypatch): def test_find_extras_directory_from_cli(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", "models"]) monkeypatch.setattr(sys, "argv", ["worker", "--extras-dir", "extra_files"])
monkeypatch.setattr("pathlib.Path.exists", lambda x: True) monkeypatch.setattr("pathlib.Path.exists", lambda x: True)
worker = BaseWorker() worker = BaseWorker()
worker.args = worker.parser.parse_args() worker.args = worker.parser.parse_args()
worker.config = {} worker.config = {}
assert worker.find_model_directory() == Path("models") assert worker.find_extras_directory() == Path("extra_files")
def test_find_model_directory_from_config(monkeypatch): def test_find_extras_directory_from_config(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setattr("pathlib.Path.exists", lambda x: True) monkeypatch.setattr("pathlib.Path.exists", lambda x: True)
worker = BaseWorker() worker = BaseWorker()
worker.args = worker.parser.parse_args() worker.args = worker.parser.parse_args()
worker.config = {"model_dir": "models"} worker.config = {"extras_dir": "extra_files"}
assert worker.find_model_directory() == Path("models") assert worker.find_extras_directory() == Path("extra_files")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_path, exists, error", "extras_path, exists, error",
( (
[ [
None, None,
True, True,
"No path to the model was provided. Please provide model_dir either through configuration or as CLI argument.", "No path to the directory for extra files was provided. Please provide extras_dir either through configuration or as CLI argument.",
], ],
["models", False, "The path models does not link to any directory"], ["extra_files", False, "The path extra_files does not link to any directory"],
), ),
) )
def test_find_model_directory_not_found(monkeypatch, model_path, exists, error): def test_find_extras_directory_not_found(monkeypatch, extras_path, exists, error):
if model_path: if extras_path:
monkeypatch.setattr(sys, "argv", ["worker", "--model-dir", model_path]) monkeypatch.setattr(sys, "argv", ["worker", "--extras-dir", extras_path])
else: else:
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
...@@ -658,10 +658,10 @@ def test_find_model_directory_not_found(monkeypatch, model_path, exists, error): ...@@ -658,10 +658,10 @@ def test_find_model_directory_not_found(monkeypatch, model_path, exists, error):
worker = BaseWorker() worker = BaseWorker()
worker.args = worker.parser.parse_args() worker.args = worker.parser.parse_args()
worker.config = {"model_dir": model_path} worker.config = {"extras_dir": extras_path}
with pytest.raises(ModelNotFoundError, match=error): with pytest.raises(ExtrasDirNotFoundError, match=error):
worker.find_model_directory() worker.find_extras_directory()
def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_path): def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_path):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment