diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 182927f3666801db3fd276efa74f10f8e6d29a1e..bb7048a5e6cc5ac177de6b4473524dcc89586cd3 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -120,6 +120,7 @@ class TrainingMixin(object): except ErrorResponse as e: if e.status_code >= 500: logger.error(f"Failed to create model version: {e.content}") + model_version_details = e.content.get("hash") # If the existing model is in Created state, this model is returned as a dict. # Else an error in a list is returned. diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index e1b64f4228bd824c7c9e009a6af0d53090b6fce2..2ab535aec82885c1518b9d443a9f1711236c0c6e 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -67,16 +67,19 @@ def test_create_model_version(): @pytest.mark.parametrize( "process_exception, status_code", [ - # (None, 200) ( ErrorResponse( - title="Mock error response", status_code=400, content="Bad gateway" + title="Mock error response", + status_code=400, + content="Mock error response", ), 400, ), ( ErrorResponse( - title="Mock error response", status_code=403, content="Bad gateway" + title="Mock error response", + status_code=403, + content="Mock error response", ), 403, ), @@ -92,21 +95,40 @@ def test_retrieve_created_model_version(process_exception, status_code): hash = "hash" archive_hash = "archive_hash" size = "30" + model_version_id = "fake_model_version_id" + model_version_details = { + "id": model_version_id, + "model_id": model_id, + "hash": hash, + "archive_hash": archive_hash, + "size": size, + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } training.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, body={"hash": hash, "archive_hash": archive_hash, "size": size}, + response={"hash": model_version_details}, ) - if process_exception.status_code == status_code: - assert ( - training.create_model_version(model_id, hash, size, archive_hash) - == process_exception - ) + # training.create_model_version(model_id, hash, size, archive_hash) + # assert training.api_client.responses[0][1] == {"hash": model_version_details} - # with pytest.raises(Exception): + with pytest.raises(Exception) as e: + training.create_model_version(model_id, hash, size, archive_hash) + assert e.value + + # if process_exception.status_code == status_code: + # assert ( + # training.create_model_version(model_id, hash, size, archive_hash) + # == process_exception + # ) + # try: # training.create_model_version(model_id, hash, size, archive_hash) + # except Exception as e: + # assert e == process_exception # def test_retrieve_available_model_version():