Skip to content
Snippets Groups Projects
Commit 92aada0c authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Load model version configuration from RetrieveWorkerRun

parent c8c61a90
No related branches found
No related tags found
1 merge request!222Load model version configuration from RetrieveWorkerRun
Pipeline #79635 passed
......@@ -144,6 +144,7 @@ class BaseWorker(object):
# or in configure_for_developers() from the environment
self.corpus_id = None
self.user_configuration = {}
self.model_configuration = {}
self.support_cache = support_cache
# use_cache will be updated in configure() if the cache is supported and if there
# is at least one available sqlite database either given or in the parent tasks
......@@ -252,6 +253,12 @@ class BaseWorker(object):
logger.info("Loaded user configuration from WorkerRun")
self.user_configuration.update(worker_configuration.get("configuration"))
# Load model version configuration when available
model_version = worker_run.get("model_version")
if model_version and model_version.get("configuration"):
logger.info("Loaded model version configuration from WorkerRun")
self.model_configuration.update(model_version.get("configuration"))
# if debug mode is set to true activate debug mode in logger
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
......
......@@ -150,6 +150,7 @@ def mock_worker_run_api(responses):
"name": "string",
"configuration": {},
},
"model_version": None,
"process": {
"name": None,
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
......
......@@ -195,6 +195,7 @@ def test_configure_worker_run(mocker, monkeypatch, responses):
"configuration": {"configuration": {}},
},
"configuration": user_configuration,
"model_version": None,
"process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
......@@ -269,6 +270,7 @@ def test_configure_user_configuration_defaults(
"param_5": True,
},
},
"model_version": None,
"process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
......@@ -321,6 +323,7 @@ def test_configure_user_config_debug(mocker, monkeypatch, responses, debug):
"revision": {"hash": "deadbeef1234"},
"configuration": {"configuration": {}},
},
"model_version": None,
"configuration": {
"id": "af0daaf4-983e-4703-a7ed-a10f146d6684",
"name": "BBB",
......@@ -375,6 +378,7 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
"revision": {"hash": "deadbeef1234"},
"configuration": {"configuration": {}},
},
"model_version": None,
"configuration": {"id": "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "name": "BBB"},
"process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
......@@ -425,6 +429,7 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses)
"revision": {"hash": "deadbeef1234"},
"configuration": {},
},
"model_version": None,
"configuration": None,
"process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
......@@ -444,6 +449,71 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses)
assert worker.user_configuration == {}
def test_configure_load_model_configuration(mocker, monkeypatch, responses):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
payload = {
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": "12341234-1234-1234-1234-123412341234",
"dataimport_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"configuration_id": None,
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"revision": {"hash": "deadbeef1234"},
"configuration": {"configuration": {}},
},
"configuration": None,
"model_version": {
"id": "12341234-1234-1234-1234-123412341234",
"name": "Model version 1337",
"configuration": {
"param1": "value1",
"param2": 2,
"param3": None,
},
},
"process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
},
}
responses.add(
responses.GET,
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
worker.args = worker.parser.parse_args()
assert worker.is_read_only is False
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
assert worker.model_configuration == {}
worker.configure()
assert worker.model_configuration == {
"param1": "value1",
"param2": 2,
"param3": None,
}
def test_load_missing_secret():
worker = BaseWorker()
worker.api_client = MockApiClient()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment