From 2b22fe696fe80a0e04888d19410ec5e020a93eb6 Mon Sep 17 00:00:00 2001 From: NolanB <nboukachab@teklia.com> Date: Thu, 18 Aug 2022 16:42:56 +0200 Subject: [PATCH] Commit for help fixing tests --- arkindex_worker/worker/training.py | 1 + tests/test_elements_worker/test_training.py | 40 ++++++++++++++++----- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 182927f3..bb7048a5 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -120,6 +120,7 @@ class TrainingMixin(object): except ErrorResponse as e: if e.status_code >= 500: logger.error(f"Failed to create model version: {e.content}") + model_version_details = e.content.get("hash") # If the existing model is in Created state, this model is returned as a dict. # Else an error in a list is returned. diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index e1b64f42..2ab535ae 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -67,16 +67,19 @@ def test_create_model_version(): @pytest.mark.parametrize( "process_exception, status_code", [ - # (None, 200) ( ErrorResponse( - title="Mock error response", status_code=400, content="Bad gateway" + title="Mock error response", + status_code=400, + content="Mock error response", ), 400, ), ( ErrorResponse( - title="Mock error response", status_code=403, content="Bad gateway" + title="Mock error response", + status_code=403, + content="Mock error response", ), 403, ), @@ -92,21 +95,40 @@ def test_retrieve_created_model_version(process_exception, status_code): hash = "hash" archive_hash = "archive_hash" size = "30" + model_version_id = "fake_model_version_id" + model_version_details = { + "id": model_version_id, + "model_id": model_id, + "hash": hash, + "archive_hash": archive_hash, + "size": size, + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } training.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, body={"hash": hash, "archive_hash": archive_hash, "size": size}, + response={"hash": model_version_details}, ) - if process_exception.status_code == status_code: - assert ( - training.create_model_version(model_id, hash, size, archive_hash) - == process_exception - ) + # training.create_model_version(model_id, hash, size, archive_hash) + # assert training.api_client.responses[0][1] == {"hash": model_version_details} - # with pytest.raises(Exception): + with pytest.raises(Exception) as e: + training.create_model_version(model_id, hash, size, archive_hash) + assert e.value + + # if process_exception.status_code == status_code: + # assert ( + # training.create_model_version(model_id, hash, size, archive_hash) + # == process_exception + # ) + # try: # training.create_model_version(model_id, hash, size, archive_hash) + # except Exception as e: + # assert e == process_exception # def test_retrieve_available_model_version(): -- GitLab