diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 690c3db71ad979d1a9a850d030bc911ceb6f562c..7e76cc707df9157e3f9e86569fe3399fb9b4e4de 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -61,24 +61,26 @@ def test_create_model_version(): @pytest.mark.parametrize( - "model_version_details, status_code", + "content, status_code", [ ( { - "id": "fake_model_version_id", - "model_id": "fake_model_id", - "hash": "hash", - "archive_hash": "archive_hash", - "size": "size", - "s3_url": "http://hehehe.com", - "s3_put_url": "http://hehehe.com", + "hash": { + "id": "fake_model_version_id", + "model_id": "fake_model_id", + "hash": "hash", + "archive_hash": "archive_hash", + "size": "size", + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } }, 400, ), ({"hash": ["A version for this model with this hash already exists."]}, 403), ], ) -def test_retrieve_created_model_version(model_version_details, status_code): +def test_retrieve_created_model_version(content, status_code): """There is an existing model version in Created mode, A 400 was raised. But the model is still returned in error content """ @@ -93,18 +95,18 @@ def test_retrieve_created_model_version(model_version_details, status_code): id=model_id, status_code=status_code, body={"hash": hash, "archive_hash": archive_hash, "size": size}, - content={"hash": model_version_details}, + content=content, ) - if status_code == 400: assert ( training.create_model_version(model_id, hash, size, archive_hash) - == model_version_details + == content["hash"] ) elif status_code == 403: + none_value = None assert ( training.create_model_version(model_id, hash, size, archive_hash) - == model_version_details + == none_value )