diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 88bebe38f03effe2ef4db603e2e35170d90a4bec..182927f3666801db3fd276efa74f10f8e6d29a1e 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -128,6 +128,7 @@ class TrainingMixin(object): ): logger.error(model_version_details[0]) return + return model_version_details def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None: diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 4661ac130b66c08e688d451af6d1e4b31f6a0606..c0905ae47336ac7f985395fd3482bdb63221011e 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -5,10 +5,15 @@ from pathlib import Path import responses from responses import matchers -from arkindex.mock import MockApiClient +from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive +class TrainingWorker(BaseWorker, TrainingMixin): + def __init__(self): + super().setup_api_client() + + def test_create_archive_folder(): model_file_dir = Path("tests/samples/model_files") @@ -28,15 +33,13 @@ def test_create_archive_folder(): def test_create_model_version(): - api_client = MockApiClient() - """A new model version is returned""" model_id = "fake_model_id" model_version_id = "fake_model_version_id" # Create a model archive and keep its hash and size. model_files_dir = Path("tests/samples/model_files") # model_file_path = model_files_dir / "model_file.pth" - training = TrainingMixin() + training = TrainingWorker() with create_archive(path=model_files_dir) as ( zst_archive_path, hash, @@ -66,7 +69,7 @@ def test_create_model_version(): ) assert ( - training.create_model_version(api_client, model_id, hash, size, archive_hash) + training.create_model_version(model_id, hash, size, archive_hash) == model_version_details )