From c7afdc56cfd9a4ec0c218606da4308a6552989bb Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 7 Sep 2022 15:51:03 +0200
Subject: [PATCH] do not add null field to the payload

---
 arkindex_worker/worker/training.py          |  31 +++---
 tests/test_elements_worker/test_training.py | 103 +++++++++++---------
 2 files changed, 75 insertions(+), 59 deletions(-)

diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py
index 4f78a52f..ca6f52a0 100644
--- a/arkindex_worker/worker/training.py
+++ b/arkindex_worker/worker/training.py
@@ -81,7 +81,11 @@ class TrainingMixin(object):
     """
 
     def publish_model_version(
-        self, model_path: DirPath, model_id: str, tag: str = None, description: str = ""
+        self,
+        model_path: DirPath,
+        model_id: str,
+        tag: str = None,
+        description: str = None,
     ):
         """
         This method creates a model archive and its associated hash,
@@ -144,30 +148,29 @@ class TrainingMixin(object):
 
         # Create a new model version with hash and size
         try:
+            payload = {"hash": hash, "size": size, "archive_hash": archive_hash}
+            if tag:
+                payload["tag"] = tag
+            if description:
+                payload["description"] = description
             model_version_details = self.request(
                 "CreateModelVersion",
                 id=model_id,
-                body={
-                    "hash": hash,
-                    "size": size,
-                    "archive_hash": archive_hash,
-                    "tag": tag,
-                    "description": description,
-                },
+                body=payload,
             )
             logger.info(
                 f"Model version ({model_version_details['id']}) was created successfully"
             )
         except ErrorResponse as e:
-            if e.status_code >= 500:
+            model_version_details = (
+                e.content.get("hash") if hasattr(e, "content") else None
+            )
+            if e.status_code >= 500 or model_version_details is None:
                 logger.error(f"Failed to create model version: {e.content}")
-
-            model_version_details = e.content.get("hash")
+                raise e
             # If the existing model is in Created state, this model is returned as a dict.
             # Else an error in a list is returned.
-            if model_version_details and isinstance(
-                model_version_details, (list, tuple)
-            ):
+            if isinstance(model_version_details, (list, tuple)):
                 logger.error(model_version_details[0])
                 return
 
diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py
index 739982d9..e8f41962 100644
--- a/tests/test_elements_worker/test_training.py
+++ b/tests/test_elements_worker/test_training.py
@@ -53,7 +53,7 @@ def test_create_archive(model_file_dir):
         ("", "description"),
         ("tag", ""),
         ("", ""),
-        (None, ""),
+        (None, None),
     ],
 )
 def test_create_model_version(mock_training_worker, tag, description):
@@ -76,17 +76,21 @@ def test_create_model_version(mock_training_worker, tag, description):
         "s3_put_url": "http://hehehe.com",
     }
 
+    expected_payload = {
+        "hash": model_hash,
+        "archive_hash": archive_hash,
+        "size": size,
+    }
+    if description:
+        expected_payload["description"] = description
+    if tag:
+        expected_payload["tag"] = tag
+
     mock_training_worker.api_client.add_response(
         "CreateModelVersion",
         id=model_id,
         response=model_version_details,
-        body={
-            "hash": model_hash,
-            "archive_hash": archive_hash,
-            "size": size,
-            "tag": tag,
-            "description": description,
-        },
+        body=expected_payload,
     )
     assert (
         mock_training_worker.create_model_version(
@@ -101,37 +105,19 @@ def test_create_model_version(mock_training_worker, tag, description):
     [
         (
             {
-                "hash": {
-                    "id": "fake_model_version_id",
-                    "model_id": "fake_model_id",
-                    "hash": "hash",
-                    "archive_hash": "archive_hash",
-                    "size": "size",
-                    "tag": "tag",
-                    "description": "description",
-                    "s3_url": "http://hehehe.com",
-                    "s3_put_url": "http://hehehe.com",
-                }
+                "id": "fake_model_version_id",
+                "model_id": "fake_model_id",
+                "hash": "hash",
+                "archive_hash": "archive_hash",
+                "size": "size",
+                "tag": "tag",
+                "description": "description",
+                "s3_url": "http://hehehe.com",
+                "s3_put_url": "http://hehehe.com",
             },
             400,
         ),
-        (
-            {
-                "hash": {
-                    "id": "fake_model_version_id",
-                    "model_id": "fake_model_id",
-                    "hash": "hash",
-                    "archive_hash": "archive_hash",
-                    "size": "size",
-                    "tag": None,
-                    "description": "",
-                    "s3_url": "http://hehehe.com",
-                    "s3_put_url": "http://hehehe.com",
-                }
-            },
-            400,
-        ),
-        ({"hash": ["A version for this model with this hash already exists."]}, 403),
+        (["A version for this model with this hash already exists."], 403),
     ],
 )
 def test_retrieve_created_model_version(mock_training_worker, content, status_code):
@@ -141,13 +127,10 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
     Else if an existing model version in Available mode,
     403 was raised, but None will be returned
     """
-
     model_id = "fake_model_id"
     model_hash = "hash"
     archive_hash = "archive_hash"
     size = "30"
-    tag = "tag"
-    description = "description"
     mock_training_worker.api_client.add_error_response(
         "CreateModelVersion",
         id=model_id,
@@ -156,27 +139,57 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
             "hash": model_hash,
             "archive_hash": archive_hash,
             "size": size,
-            "tag": tag,
-            "description": description,
         },
-        content=content,
+        content={"hash": content},
     )
     if status_code == 400:
         assert (
             mock_training_worker.create_model_version(
-                model_id, model_hash, size, archive_hash, tag, description
+                model_id, model_hash, size, archive_hash, tag=None, description=None
             )
-            == content["hash"]
+            == content
         )
     elif status_code == 403:
         assert (
             mock_training_worker.create_model_version(
-                model_id, model_hash, size, archive_hash, tag, description
+                model_id, model_hash, size, archive_hash, tag=None, description=None
             )
             is None
         )
 
 
+@pytest.mark.parametrize(
+    "content, status_code",
+    (
+        # error 500
+        ({"id": "fake_id"}, 500),
+        # model_version details is None
+        ({}, 403),
+        (None, 403),
+    ),
+)
+def test_handle_500_create_model_version(mock_training_worker, content, status_code):
+    model_id = "fake_model_id"
+    model_hash = "hash"
+    archive_hash = "archive_hash"
+    size = "30"
+    mock_training_worker.api_client.add_error_response(
+        "CreateModelVersion",
+        id=model_id,
+        status_code=status_code,
+        body={
+            "hash": model_hash,
+            "archive_hash": archive_hash,
+            "size": size,
+        },
+        content=content,
+    )
+    with pytest.raises(Exception):
+        mock_training_worker.create_model_version(
+            model_id, model_hash, size, archive_hash, tag=None, description=None
+        )
+
+
 def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
     s3_endpoint_url = "http://s3.localhost.com"
     responses.add_passthru(s3_endpoint_url)
-- 
GitLab