Skip to content
Snippets Groups Projects
Commit 174b7545 authored by Manon Blanco's avatar Manon Blanco Committed by Yoann Schneider
Browse files

Move `configure` method to `BaseWorker`

parent 22abe1f8
No related branches found
No related tags found
1 merge request!583Move `configure` method to `BaseWorker`
Pipeline #187313 passed
...@@ -173,29 +173,6 @@ class ElementsWorker( ...@@ -173,29 +173,6 @@ class ElementsWorker(
), "Worker must be configured to access its process activity state" ), "Worker must be configured to access its process activity state"
return self.process_information.get("activity_state") == "ready" return self.process_information.get("activity_state") == "ready"
def configure(self):
    """
    Setup the worker using CLI arguments and environment variables.

    Merges, in order, the model version configuration and then the user
    configuration into ``self.config`` — so user configuration wins over
    model configuration, which wins over the defaults loaded by the
    ``BaseWorker`` configuration helpers.
    """
    # CLI args are stored on the instance so that implementations can access them
    self.args = self.parser.parse_args()

    if self.is_read_only:
        # Developer / read-only mode: configure without a WorkerRun on Arkindex
        super().configure_for_developers()
    else:
        super().configure()
        # NOTE(review): indentation was lost in this view — the cache is
        # presumably only configured for real (non read-only) runs; confirm.
        super().configure_cache()

    # Retrieve the model configuration
    if self.model_configuration:
        self.config.update(self.model_configuration)
        logger.info("Model version configuration retrieved")

    # Retrieve the user configuration
    if self.user_configuration:
        self.config.update(self.user_configuration)
        logger.info("User configuration retrieved")
def run(self): def run(self):
""" """
Implements an Arkindex worker that goes through each element returned by Implements an Arkindex worker that goes through each element returned by
...@@ -397,29 +374,6 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): ...@@ -397,29 +374,6 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
default=[], default=[],
) )
def configure(self):
    """
    Setup the worker using CLI arguments and environment variables.

    After the ``BaseWorker`` configuration helpers have run, the model
    version configuration and then the user configuration are merged into
    ``self.config``; the user configuration is applied last and therefore
    takes precedence.
    """
    # CLI args are stored on the instance so that implementations can access them
    self.args = self.parser.parse_args()

    if self.is_read_only:
        # Developer / read-only mode: configure without a WorkerRun on Arkindex
        super().configure_for_developers()
    else:
        super().configure()
        # NOTE(review): indentation was lost in this view — the cache is
        # presumably only configured for real (non read-only) runs; confirm.
        super().configure_cache()

    # Retrieve the model configuration
    if self.model_configuration:
        self.config.update(self.model_configuration)
        logger.info("Model version configuration retrieved")

    # Retrieve the user configuration
    if self.user_configuration:
        self.config.update(self.user_configuration)
        logger.info("User configuration retrieved")
def cleanup_downloaded_artifact(self) -> None: def cleanup_downloaded_artifact(self) -> None:
""" """
Cleanup the downloaded dataset artifact if any Cleanup the downloaded dataset artifact if any
......
...@@ -219,7 +219,7 @@ class BaseWorker: ...@@ -219,7 +219,7 @@ class BaseWorker:
# Load all required secrets # Load all required secrets
self.secrets = {name: self.load_secret(Path(name)) for name in required_secrets} self.secrets = {name: self.load_secret(Path(name)) for name in required_secrets}
def configure(self): def configure_worker_run(self):
""" """
Setup the necessary configuration needed using CLI args and environment variables. Setup the necessary configuration needed using CLI args and environment variables.
This is the method called when running a worker on Arkindex. This is the method called when running a worker on Arkindex.
...@@ -320,6 +320,29 @@ class BaseWorker: ...@@ -320,6 +320,29 @@ class BaseWorker:
else: else:
logger.debug("Cache is disabled") logger.debug("Cache is disabled")
def configure(self):
    """
    Setup the worker using CLI arguments and environment variables.

    Entry point shared by all workers: parses CLI arguments, picks the
    developer or WorkerRun configuration path, then merges the model
    version configuration and the user configuration into ``self.config``
    (applied in that order, so user configuration has the final word).
    """
    # CLI args are stored on the instance so that implementations can access them
    self.args = self.parser.parse_args()

    if self.is_read_only:
        # Developer / read-only mode: no WorkerRun is available on Arkindex
        self.configure_for_developers()
    else:
        self.configure_worker_run()
        # NOTE(review): indentation was lost in this view — the cache is
        # presumably only configured for real (non read-only) runs; confirm.
        self.configure_cache()

    # Retrieve the model configuration
    if self.model_configuration:
        self.config.update(self.model_configuration)
        logger.info("Model version configuration retrieved")

    # Retrieve the user configuration
    if self.user_configuration:
        self.config.update(self.user_configuration)
        logger.info("User configuration retrieved")
def load_secret(self, name: Path): def load_secret(self, name: Path):
""" """
Load a Ponos secret by name. Load a Ponos secret by name.
......
...@@ -206,6 +206,7 @@ def test_configure_worker_run(mocker, responses, caplog): ...@@ -206,6 +206,7 @@ def test_configure_worker_run(mocker, responses, caplog):
"Loaded Worker Fake worker @ 123412 from API", "Loaded Worker Fake worker @ 123412 from API",
), ),
("arkindex_worker", logging.INFO, "Loaded user configuration from WorkerRun"), ("arkindex_worker", logging.INFO, "Loaded user configuration from WorkerRun"),
("arkindex_worker", logging.INFO, "User configuration retrieved"),
] ]
assert worker.user_configuration == {"a": "b"} assert worker.user_configuration == {"a": "b"}
...@@ -284,12 +285,21 @@ def test_configure_user_configuration_defaults(mocker, responses): ...@@ -284,12 +285,21 @@ def test_configure_user_configuration_defaults(mocker, responses):
worker.configure() worker.configure()
assert worker.config == {"param_1": "/some/path/file.pth", "param_2": 12}
assert worker.user_configuration == { assert worker.user_configuration == {
"integer_parameter": 0, "integer_parameter": 0,
"param_3": "Animula vagula blandula", "param_3": "Animula vagula blandula",
"param_5": True, "param_5": True,
} }
# All configurations are merged
assert worker.config == {
# Default config
"param_1": "/some/path/file.pth",
"param_2": 12,
# User config
"integer_parameter": 0,
"param_3": "Animula vagula blandula",
"param_5": True,
}
@pytest.mark.parametrize("debug", [True, False]) @pytest.mark.parametrize("debug", [True, False])
...@@ -676,7 +686,6 @@ def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_pat ...@@ -676,7 +686,6 @@ def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_pat
mock_base_worker_with_cache.args = mock_base_worker_with_cache.parser.parse_args() mock_base_worker_with_cache.args = mock_base_worker_with_cache.parser.parse_args()
mock_base_worker_with_cache.configure() mock_base_worker_with_cache.configure()
mock_base_worker_with_cache.configure_cache()
assert mock_base_worker_with_cache.find_parents_file_paths(filename) == [ assert mock_base_worker_with_cache.find_parents_file_paths(filename) == [
tmp_path / "first" / filename, tmp_path / "first" / filename,
...@@ -753,3 +762,195 @@ def test_corpus_id_set_read_only_mode( ...@@ -753,3 +762,195 @@ def test_corpus_id_set_read_only_mode(
mock_elements_worker_read_only.configure() mock_elements_worker_read_only.configure()
assert mock_elements_worker_read_only.corpus_id == corpus_id assert mock_elements_worker_read_only.corpus_id == corpus_id
@pytest.mark.parametrize(
    (
        "wk_version_config",
        "wk_version_user_config",
        "frontend_user_config",
        "model_config",
        "expected_config",
    ),
    [
        ({}, {}, {}, {}, {}),
        # Keep parameters from worker version configuration
        ({"parameter": 0}, {}, {}, {}, {"parameter": 0}),
        # Keep parameters from worker version configuration + user_config defaults
        (
            {"parameter": 0},
            {
                "parameter2": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 0,
                    "required": False,
                }
            },
            {},
            {},
            {"parameter": 0, "parameter2": 0},
        ),
        # Keep parameters from worker version configuration + user_config no defaults
        (
            {"parameter": 0},
            {
                "parameter2": {
                    "type": "int",
                    "title": "Lambda",
                    "required": False,
                }
            },
            {},
            {},
            {"parameter": 0},
        ),
        # Keep parameters from worker version configuration but user_config defaults overrides
        (
            {"parameter": 0},
            {
                "parameter": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 1,
                    "required": False,
                }
            },
            {},
            {},
            {"parameter": 1},
        ),
        # Keep parameters from worker version configuration + frontend config
        (
            {"parameter": 0},
            {},
            {"parameter2": 0},
            {},
            {"parameter": 0, "parameter2": 0},
        ),
        # Keep parameters from worker version configuration + frontend config overrides
        ({"parameter": 0}, {}, {"parameter": 1}, {}, {"parameter": 1}),
        # Keep parameters from worker version configuration + model config
        (
            {"parameter": 0},
            {},
            {},
            {"parameter2": 0},
            {"parameter": 0, "parameter2": 0},
        ),
        # Keep parameters from worker version configuration + model config overrides
        ({"parameter": 0}, {}, {}, {"parameter": 1}, {"parameter": 1}),
        # Keep parameters from worker version configuration + user_config default + model config overrides
        (
            {"parameter": 0},
            {
                "parameter": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 1,
                    "required": False,
                }
            },
            {},
            {"parameter": 2},
            {"parameter": 2},
        ),
        # Keep parameters from worker version configuration + model config + frontend config overrides
        ({"parameter": 0}, {}, {"parameter": 2}, {"parameter": 1}, {"parameter": 2}),
        # Keep parameters from worker version configuration + user_config default + model config + frontend config overrides all
        (
            {"parameter": 0},
            {
                "parameter": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 1,
                    "required": False,
                }
            },
            {"parameter": 3},
            {"parameter": 2},
            {"parameter": 3},
        ),
    ],
)
def test_worker_config_multiple_source(
    monkeypatch,
    responses,
    wk_version_config,
    wk_version_user_config,
    frontend_user_config,
    model_config,
    expected_config,
):
    """
    ``BaseWorker.configure`` must merge every configuration source into
    ``worker.config`` with the precedence exercised by the parametrized
    cases above: worker version config < user_config defaults < model
    version config < frontend (user-entered) config.
    """
    # Compute WorkerRun info: this payload mimics the Arkindex
    # RetrieveWorkerRun API response, with each configuration source
    # injected from the parametrized arguments.
    payload = {
        "id": "56785678-5678-5678-5678-567856785678",
        "parents": [],
        "worker_version": {
            "id": "12341234-1234-1234-1234-123412341234",
            "configuration": {
                "docker": {"image": "python:3"},
                "configuration": wk_version_config,
                "secrets": [],
                "user_configuration": wk_version_user_config,
            },
            "revision": {
                "hash": "deadbeef1234",
                "name": "some git revision",
            },
            "docker_image": "python:3",
            "docker_image_name": "python:3",
            "state": "created",
            "worker": {
                "id": "deadbeef-1234-5678-1234-worker",
                "name": "Fake worker",
                "slug": "fake_worker",
                "type": "classifier",
            },
        },
        # Configuration selected by the user in the frontend
        "configuration": {
            "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
            "name": "Configuration entered by user",
            "configuration": frontend_user_config,
        },
        "model_version": {
            "id": "12341234-1234-1234-1234-123412341234",
            "name": "Model version 1337",
            "configuration": model_config,
            "model": {
                "id": "hahahaha-haha-haha-haha-hahahahahaha",
                "name": "My model",
            },
        },
        "process": {
            "name": None,
            "id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
            "state": "running",
            "mode": "workers",
            "corpus": CORPUS_ID,
            "use_cache": False,
            "activity_state": "ready",
            "model_id": None,
            "train_folder_id": None,
            "validation_folder_id": None,
            "test_folder_id": None,
        },
        "summary": "Worker Fake worker @ 123412",
    }
    # Mock the WorkerRun retrieval performed during configure()
    responses.add(
        responses.GET,
        "http://testserver/api/v1/process/workers/56785678-5678-5678-5678-567856785678/",
        status=200,
        body=json.dumps(payload),
        content_type="application/json",
    )

    # Create and configure a worker with no extra CLI arguments
    monkeypatch.setattr(sys, "argv", ["worker"])
    worker = BaseWorker()
    worker.configure()

    # Check final config: configure() alone must have produced the merge
    assert worker.config == expected_config
...@@ -6,7 +6,6 @@ from apistar.exceptions import ErrorResponse ...@@ -6,7 +6,6 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement from arkindex_worker.cache import CachedElement
from arkindex_worker.worker import ActivityState, ElementsWorker from arkindex_worker.worker import ActivityState, ElementsWorker
from tests import CORPUS_ID
from . import BASE_API_CALLS from . import BASE_API_CALLS
...@@ -313,202 +312,3 @@ def test_start_activity_error( ...@@ -313,202 +312,3 @@ def test_start_activity_error(
assert logger.error.call_args_list == [ assert logger.error.call_args_list == [
mocker.call("Ran on 1 element: 0 completed, 1 failed") mocker.call("Ran on 1 element: 0 completed, 1 failed")
] ]
@pytest.mark.parametrize(
    (
        "wk_version_config",
        "wk_version_user_config",
        "frontend_user_config",
        "model_config",
        "expected_config",
    ),
    [
        ({}, {}, {}, {}, {}),
        # Keep parameters from worker version configuration
        ({"parameter": 0}, {}, {}, {}, {"parameter": 0}),
        # Keep parameters from worker version configuration + user_config defaults
        (
            {"parameter": 0},
            {
                "parameter2": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 0,
                    "required": False,
                }
            },
            {},
            {},
            {"parameter": 0, "parameter2": 0},
        ),
        # Keep parameters from worker version configuration + user_config no defaults
        (
            {"parameter": 0},
            {
                "parameter2": {
                    "type": "int",
                    "title": "Lambda",
                    "required": False,
                }
            },
            {},
            {},
            {"parameter": 0},
        ),
        # Keep parameters from worker version configuration but user_config defaults overrides
        (
            {"parameter": 0},
            {
                "parameter": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 1,
                    "required": False,
                }
            },
            {},
            {},
            {"parameter": 1},
        ),
        # Keep parameters from worker version configuration + frontend config
        (
            {"parameter": 0},
            {},
            {"parameter2": 0},
            {},
            {"parameter": 0, "parameter2": 0},
        ),
        # Keep parameters from worker version configuration + frontend config overrides
        ({"parameter": 0}, {}, {"parameter": 1}, {}, {"parameter": 1}),
        # Keep parameters from worker version configuration + model config
        (
            {"parameter": 0},
            {},
            {},
            {"parameter2": 0},
            {"parameter": 0, "parameter2": 0},
        ),
        # Keep parameters from worker version configuration + model config overrides
        ({"parameter": 0}, {}, {}, {"parameter": 1}, {"parameter": 1}),
        # Keep parameters from worker version configuration + user_config default + model config overrides
        (
            {"parameter": 0},
            {
                "parameter": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 1,
                    "required": False,
                }
            },
            {},
            {"parameter": 2},
            {"parameter": 2},
        ),
        # Keep parameters from worker version configuration + model config + frontend config overrides
        ({"parameter": 0}, {}, {"parameter": 2}, {"parameter": 1}, {"parameter": 2}),
        # Keep parameters from worker version configuration + user_config default + model config + frontend config overrides all
        (
            {"parameter": 0},
            {
                "parameter": {
                    "type": "int",
                    "title": "Lambda",
                    "default": 1,
                    "required": False,
                }
            },
            {"parameter": 3},
            {"parameter": 2},
            {"parameter": 3},
        ),
    ],
)
def test_worker_config_multiple_source(
    monkeypatch,
    responses,
    wk_version_config,
    wk_version_user_config,
    frontend_user_config,
    model_config,
    expected_config,
):
    """
    Check the precedence of all configuration sources on an
    ``ElementsWorker``: worker version config < user_config defaults <
    model version config < frontend (user-entered) config.

    Here the model and user configurations are merged manually after
    ``configure()``, mirroring what worker implementations did at the time.
    """
    # Compute WorkerRun info: this payload mimics the Arkindex
    # RetrieveWorkerRun API response, with each configuration source
    # injected from the parametrized arguments.
    payload = {
        "id": "56785678-5678-5678-5678-567856785678",
        "parents": [],
        "worker_version": {
            "id": "12341234-1234-1234-1234-123412341234",
            "configuration": {
                "docker": {"image": "python:3"},
                "configuration": wk_version_config,
                "secrets": [],
                "user_configuration": wk_version_user_config,
            },
            "revision": {
                "hash": "deadbeef1234",
                "name": "some git revision",
            },
            "docker_image": "python:3",
            "docker_image_name": "python:3",
            "state": "created",
            "worker": {
                "id": "deadbeef-1234-5678-1234-worker",
                "name": "Fake worker",
                "slug": "fake_worker",
                "type": "classifier",
            },
        },
        # Configuration selected by the user in the frontend
        "configuration": {
            "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
            "name": "Configuration entered by user",
            "configuration": frontend_user_config,
        },
        "model_version": {
            "id": "12341234-1234-1234-1234-123412341234",
            "name": "Model version 1337",
            "configuration": model_config,
            "model": {
                "id": "hahahaha-haha-haha-haha-hahahahahaha",
                "name": "My model",
            },
        },
        "process": {
            "name": None,
            "id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
            "state": "running",
            "mode": "workers",
            "corpus": CORPUS_ID,
            "use_cache": False,
            "activity_state": "ready",
            "model_id": None,
            "train_folder_id": None,
            "validation_folder_id": None,
            "test_folder_id": None,
        },
        "summary": "Worker Fake worker @ 123412",
    }
    # Mock the WorkerRun retrieval performed during configure()
    responses.add(
        responses.GET,
        "http://testserver/api/v1/process/workers/56785678-5678-5678-5678-567856785678/",
        status=200,
        body=json.dumps(payload),
        content_type="application/json",
    )

    # Create and configure a worker with no extra CLI arguments
    monkeypatch.setattr(sys, "argv", ["worker"])
    worker = ElementsWorker()
    worker.configure()

    # Do what people do with a model configuration
    if worker.model_configuration:
        worker.config.update(worker.model_configuration)
    if worker.user_configuration:
        worker.config.update(worker.user_configuration)

    # Check final config
    assert worker.config == expected_config
...@@ -181,7 +181,6 @@ def test_merge_from_worker( ...@@ -181,7 +181,6 @@ def test_merge_from_worker(
(tmp_path / "my_task").mkdir() (tmp_path / "my_task").mkdir()
mock_base_worker_with_cache.args = mock_base_worker_with_cache.parser.parse_args() mock_base_worker_with_cache.args = mock_base_worker_with_cache.parser.parse_args()
mock_base_worker_with_cache.configure() mock_base_worker_with_cache.configure()
mock_base_worker_with_cache.configure_cache()
# Store parent tasks IDs as attribute # Store parent tasks IDs as attribute
assert mock_base_worker_with_cache.task_parents == ["first", "second"] assert mock_base_worker_with_cache.task_parents == ["first", "second"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment