From c7afdc56cfd9a4ec0c218606da4308a6552989bb Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 7 Sep 2022 15:51:03 +0200 Subject: [PATCH] do not add null field to the payload --- arkindex_worker/worker/training.py | 31 +++--- tests/test_elements_worker/test_training.py | 103 +++++++++++--------- 2 files changed, 75 insertions(+), 59 deletions(-) diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 4f78a52f..ca6f52a0 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -81,7 +81,11 @@ class TrainingMixin(object): """ def publish_model_version( - self, model_path: DirPath, model_id: str, tag: str = None, description: str = "" + self, + model_path: DirPath, + model_id: str, + tag: str = None, + description: str = None, ): """ This method creates a model archive and its associated hash, @@ -144,30 +148,29 @@ class TrainingMixin(object): # Create a new model version with hash and size try: + payload = {"hash": hash, "size": size, "archive_hash": archive_hash} + if tag: + payload["tag"] = tag + if description: + payload["description"] = description model_version_details = self.request( "CreateModelVersion", id=model_id, - body={ - "hash": hash, - "size": size, - "archive_hash": archive_hash, - "tag": tag, - "description": description, - }, + body=payload, ) logger.info( f"Model version ({model_version_details['id']}) was created successfully" ) except ErrorResponse as e: - if e.status_code >= 500: + model_version_details = ( + e.content.get("hash") if hasattr(e, "content") else None + ) + if e.status_code >= 500 or model_version_details is None: logger.error(f"Failed to create model version: {e.content}") - - model_version_details = e.content.get("hash") + raise e # If the existing model is in Created state, this model is returned as a dict. # Else an error in a list is returned. - if model_version_details and isinstance( - model_version_details, (list, tuple) - ): + if isinstance(model_version_details, (list, tuple)): logger.error(model_version_details[0]) return diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 739982d9..e8f41962 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -53,7 +53,7 @@ def test_create_archive(model_file_dir): ("", "description"), ("tag", ""), ("", ""), - (None, ""), + (None, None), ], ) def test_create_model_version(mock_training_worker, tag, description): @@ -76,17 +76,21 @@ def test_create_model_version(mock_training_worker, tag, description): "s3_put_url": "http://hehehe.com", } + expected_payload = { + "hash": model_hash, + "archive_hash": archive_hash, + "size": size, + } + if description: + expected_payload["description"] = description + if tag: + expected_payload["tag"] = tag + mock_training_worker.api_client.add_response( "CreateModelVersion", id=model_id, response=model_version_details, - body={ - "hash": model_hash, - "archive_hash": archive_hash, - "size": size, - "tag": tag, - "description": description, - }, + body=expected_payload, ) assert ( mock_training_worker.create_model_version( @@ -101,37 +105,19 @@ def test_create_model_version(mock_training_worker, tag, description): [ ( { - "hash": { - "id": "fake_model_version_id", - "model_id": "fake_model_id", - "hash": "hash", - "archive_hash": "archive_hash", - "size": "size", - "tag": "tag", - "description": "description", - "s3_url": "http://hehehe.com", - "s3_put_url": "http://hehehe.com", - } + "id": "fake_model_version_id", + "model_id": "fake_model_id", + "hash": "hash", + "archive_hash": "archive_hash", + "size": "size", + "tag": "tag", + "description": "description", + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", }, 400, ), - ( - { - "hash": { - "id": "fake_model_version_id", - "model_id": "fake_model_id", - "hash": "hash", - "archive_hash": "archive_hash", - "size": "size", - "tag": None, - "description": "", - "s3_url": "http://hehehe.com", - "s3_put_url": "http://hehehe.com", - } - }, - 400, - ), - ({"hash": ["A version for this model with this hash already exists."]}, 403), + (["A version for this model with this hash already exists."], 403), ], ) def test_retrieve_created_model_version(mock_training_worker, content, status_code): @@ -141,13 +127,10 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co Else if an existing model version in Available mode, 403 was raised, but None will be returned """ - model_id = "fake_model_id" model_hash = "hash" archive_hash = "archive_hash" size = "30" - tag = "tag" - description = "description" mock_training_worker.api_client.add_error_response( "CreateModelVersion", id=model_id, @@ -156,27 +139,57 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co "hash": model_hash, "archive_hash": archive_hash, "size": size, - "tag": tag, - "description": description, }, - content=content, + content={"hash": content}, ) if status_code == 400: assert ( mock_training_worker.create_model_version( - model_id, model_hash, size, archive_hash, tag, description + model_id, model_hash, size, archive_hash, tag=None, description=None ) - == content["hash"] + == content ) elif status_code == 403: assert ( mock_training_worker.create_model_version( - model_id, model_hash, size, archive_hash, tag, description + model_id, model_hash, size, archive_hash, tag=None, description=None ) is None ) +@pytest.mark.parametrize( + "content, status_code", + ( + # error 500 + ({"id": "fake_id"}, 500), + # model_version details is None + ({}, 403), + (None, 403), + ), +) +def test_handle_500_create_model_version(mock_training_worker, content, status_code): + model_id = "fake_model_id" + model_hash = "hash" + archive_hash = "archive_hash" + size = "30" + mock_training_worker.api_client.add_error_response( + "CreateModelVersion", + id=model_id, + status_code=status_code, + body={ + "hash": model_hash, + "archive_hash": archive_hash, + "size": size, + }, + content=content, + ) + with pytest.raises(Exception): + mock_training_worker.create_model_version( + model_id, model_hash, size, archive_hash, tag=None, description=None + ) + + def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir): s3_endpoint_url = "http://s3.localhost.com" responses.add_passthru(s3_endpoint_url) -- GitLab