diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index e3beb8cc9e8929a9f23856a3f645e5bfb8729907..4f78a52f0af684e490b436b44ba2d5c28793ef0a 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -87,6 +87,11 @@ class TrainingMixin(object): This method creates a model archive and its associated hash, to create a unique version that will be stored on a bucket and published on arkindex. """ + if self.is_read_only: + logger.warning( + "Cannot publish a new model version as this worker is in read-only mode" + ) + return # Create the zst archive, get its hash and size with create_archive(path=model_path) as ( @@ -131,6 +136,11 @@ class TrainingMixin(object): - The version is in `Created` state: this version's details is used - The version is in `Available` state: you cannot create twice the same version, an error is raised """ + if self.is_read_only: + logger.warning( + "Cannot create a new model version as this worker is in read-only mode" + ) + return # Create a new model version with hash and size try: @@ -171,6 +181,11 @@ class TrainingMixin(object): """ Upload the archive of the model's files to an Amazon s3 compatible storage """ + if self.is_read_only: + logger.warning( + "Cannot upload this archive as this worker is in read-only mode" + ) + return s3_put_url = model_version_details.get("s3_put_url") logger.info("Uploading to s3...") @@ -191,6 +206,11 @@ class TrainingMixin(object): """ Update the specified model version to the state `Available` and use the given information" """ + if self.is_read_only: + logger.warning( + "Cannot update this model version as this worker is in read-only mode" + ) + return model_version_id = model_version_details.get("id") logger.info(f"Updating model version ({model_version_id})") diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 9b6f014a8fc19c1a5bd5ba13425f3cbb5fbb82d6..739982d99fe4a71895dee31214bf22217f46eb15 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import os +import sys import pytest import responses @@ -17,6 +18,15 @@ class TrainingWorker(BaseWorker, TrainingMixin): pass +@pytest.fixture +def mock_training_worker(monkeypatch): + monkeypatch.setattr(sys, "argv", ["worker"]) + training_worker = TrainingWorker() + training_worker.api_client = MockApiClient() + training_worker.args = training_worker.parser.parse_args() + return training_worker + + def test_create_archive(model_file_dir): """Create an archive when the model's file is in a folder""" @@ -46,13 +56,11 @@ def test_create_archive(model_file_dir): (None, ""), ], ) -def test_create_model_version(tag, description): +def test_create_model_version(mock_training_worker, tag, description): """A new model version is returned""" model_version_id = "fake_model_version_id" model_id = "fake_model_id" - training = TrainingWorker() - training.api_client = MockApiClient() model_hash = "hash" archive_hash = "archive_hash" size = "30" @@ -68,7 +76,7 @@ def test_create_model_version(tag, description): "s3_put_url": "http://hehehe.com", } - training.api_client.add_response( + mock_training_worker.api_client.add_response( "CreateModelVersion", id=model_id, response=model_version_details, @@ -81,7 +89,7 @@ def test_create_model_version(tag, description): }, ) assert ( - training.create_model_version( + mock_training_worker.create_model_version( model_id, model_hash, size, archive_hash, tag, description ) == model_version_details @@ -126,7 +134,7 @@ def test_create_model_version(tag, description): ({"hash": ["A version for this model with this hash already exists."]}, 403), ], ) -def test_retrieve_created_model_version(content, status_code): +def test_retrieve_created_model_version(mock_training_worker, content, status_code): """ If there is an existing model version in Created mode, A 400 was raised, but the model is still returned in error content. @@ -135,14 +143,12 @@ def test_retrieve_created_model_version(content, status_code): """ model_id = "fake_model_id" - training = TrainingWorker() - training.api_client = MockApiClient() model_hash = "hash" archive_hash = "archive_hash" size = "30" tag = "tag" description = "description" - training.api_client.add_error_response( + mock_training_worker.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, @@ -157,26 +163,24 @@ def test_retrieve_created_model_version(content, status_code): ) if status_code == 400: assert ( - training.create_model_version( + mock_training_worker.create_model_version( model_id, model_hash, size, archive_hash, tag, description ) == content["hash"] ) elif status_code == 403: assert ( - training.create_model_version( + mock_training_worker.create_model_version( model_id, model_hash, size, archive_hash, tag, description ) is None ) -def test_handle_s3_uploading_errors(model_file_dir): - training = TrainingWorker() - training.api_client = MockApiClient() +def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir): s3_endpoint_url = "http://s3.localhost.com" responses.add_passthru(s3_endpoint_url) responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) file_path = model_file_dir / "model_file.pth" with pytest.raises(Exception): - training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url}) + mock_training_worker.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})