From 0b2ee59611c7df3218ceeccb57eca00205a23fc4 Mon Sep 17 00:00:00 2001 From: NolanB <nboukachab@teklia.com> Date: Wed, 17 Aug 2022 17:51:09 +0200 Subject: [PATCH] Commit for help fixing tests after add TrainingWorker class --- arkindex_worker/worker/training.py | 1 + tests/test_elements_worker/test_training.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 88bebe38..182927f3 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -128,6 +128,7 @@ class TrainingMixin(object): ): logger.error(model_version_details[0]) return + return model_version_details def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None: diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 4661ac13..c0905ae4 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -5,10 +5,15 @@ from pathlib import Path import responses from responses import matchers -from arkindex.mock import MockApiClient +from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive +class TrainingWorker(BaseWorker, TrainingMixin): + def __init__(self): + super().setup_api_client() + + def test_create_archive_folder(): model_file_dir = Path("tests/samples/model_files") @@ -28,15 +33,13 @@ def test_create_archive_folder(): def test_create_model_version(): - api_client = MockApiClient() - """A new model version is returned""" model_id = "fake_model_id" model_version_id = "fake_model_version_id" # Create a model archive and keep its hash and size. model_files_dir = Path("tests/samples/model_files") # model_file_path = model_files_dir / "model_file.pth" - training = TrainingMixin() + training = TrainingWorker() with create_archive(path=model_files_dir) as ( zst_archive_path, hash, @@ -66,7 +69,7 @@ def test_create_model_version(): ) assert ( - training.create_model_version(api_client, model_id, hash, size, archive_hash) + training.create_model_version(model_id, hash, size, archive_hash) == model_version_details ) -- GitLab