From 0b2ee59611c7df3218ceeccb57eca00205a23fc4 Mon Sep 17 00:00:00 2001
From: NolanB <nboukachab@teklia.com>
Date: Wed, 17 Aug 2022 17:51:09 +0200
Subject: [PATCH] Commit for help fixing tests after add TrainingWorker class

---
 arkindex_worker/worker/training.py          |  1 +
 tests/test_elements_worker/test_training.py | 13 ++++++++-----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py
index 88bebe38..182927f3 100644
--- a/arkindex_worker/worker/training.py
+++ b/arkindex_worker/worker/training.py
@@ -128,6 +128,7 @@ class TrainingMixin(object):
             ):
                 logger.error(model_version_details[0])
                 return
+
         return model_version_details
 
     def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py
index 4661ac13..c0905ae4 100644
--- a/tests/test_elements_worker/test_training.py
+++ b/tests/test_elements_worker/test_training.py
@@ -5,10 +5,15 @@ from pathlib import Path
 import responses
 from responses import matchers
 
-from arkindex.mock import MockApiClient
+from arkindex_worker.worker import BaseWorker
 from arkindex_worker.worker.training import TrainingMixin, create_archive
 
 
+class TrainingWorker(BaseWorker, TrainingMixin):
+    def __init__(self):
+        super().setup_api_client()
+
+
 def test_create_archive_folder():
     model_file_dir = Path("tests/samples/model_files")
 
@@ -28,15 +33,13 @@ def test_create_archive_folder():
 
 
 def test_create_model_version():
-    api_client = MockApiClient()
-
     """A new model version is returned"""
     model_id = "fake_model_id"
     model_version_id = "fake_model_version_id"
     # Create a model archive and keep its hash and size.
     model_files_dir = Path("tests/samples/model_files")
     # model_file_path = model_files_dir / "model_file.pth"
-    training = TrainingMixin()
+    training = TrainingWorker()
     with create_archive(path=model_files_dir) as (
         zst_archive_path,
         hash,
@@ -66,7 +69,7 @@ def test_create_model_version():
     )
 
     assert (
-        training.create_model_version(api_client, model_id, hash, size, archive_hash)
+        training.create_model_version(model_id, hash, size, archive_hash)
         == model_version_details
     )
 
-- 
GitLab