diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 2ab535aec82885c1518b9d443a9f1711236c0c6e..690c3db71ad979d1a9a850d030bc911ceb6f562c 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -1,17 +1,13 @@ # -*- coding: utf-8 -*- -# import json import os import pytest import responses -from apistar.exceptions import ErrorResponse from arkindex.mock import MockApiClient from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive -# from responses import matchers - class TrainingWorker(BaseWorker, TrainingMixin): pass @@ -65,27 +61,24 @@ def test_create_model_version(): @pytest.mark.parametrize( - "process_exception, status_code", + "model_version_details, status_code", [ ( - ErrorResponse( - title="Mock error response", - status_code=400, - content="Mock error response", - ), + { + "id": "fake_model_version_id", + "model_id": "fake_model_id", + "hash": "hash", + "archive_hash": "archive_hash", + "size": "size", + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + }, 400, ), - ( - ErrorResponse( - title="Mock error response", - status_code=403, - content="Mock error response", - ), - 403, - ), + ({"hash": ["A version for this model with this hash already exists."]}, 403), ], ) -def test_retrieve_created_model_version(process_exception, status_code): +def test_retrieve_created_model_version(model_version_details, status_code): """There is an existing model version in Created mode, A 400 was raised. But the model is still returned in error content """ @@ -95,70 +88,24 @@ def test_retrieve_created_model_version(process_exception, status_code): hash = "hash" archive_hash = "archive_hash" size = "30" - model_version_id = "fake_model_version_id" - model_version_details = { - "id": model_version_id, - "model_id": model_id, - "hash": hash, - "archive_hash": archive_hash, - "size": size, - "s3_url": "http://hehehe.com", - "s3_put_url": "http://hehehe.com", - } training.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, body={"hash": hash, "archive_hash": archive_hash, "size": size}, - response={"hash": model_version_details}, + content={"hash": model_version_details}, ) - # training.create_model_version(model_id, hash, size, archive_hash) - # assert training.api_client.responses[0][1] == {"hash": model_version_details} - - with pytest.raises(Exception) as e: - training.create_model_version(model_id, hash, size, archive_hash) - assert e.value - - # if process_exception.status_code == status_code: - # assert ( - # training.create_model_version(model_id, hash, size, archive_hash) - # == process_exception - # ) - # try: - # training.create_model_version(model_id, hash, size, archive_hash) - # except Exception as e: - # assert e == process_exception - - -# def test_retrieve_available_model_version(): -# """Raise error when there is an existing model version in Available mode""" -# model_id = "fake_model_id" -# model_version_id = "fake_model_version_id" -# training = TrainingWorker() -# training.api_client = MockApiClient() -# hash = "hash" -# archive_hash = "archive_hash" -# size = "30" -# model_version_details = { -# "id": model_version_id, -# "model_id": model_id, -# "hash": hash, -# "archive_hash": archive_hash, -# "size": size, -# "s3_url": "http://hehehe.com", -# "s3_put_url": "http://hehehe.com", -# } - -# training.api_client.add_error_response( -# "CreateModelVersion", -# id=model_id, -# status_code=403, -# body={"hash": hash, "archive_hash": archive_hash, "size": size} -# ) - -# with pytest.raises(Exception): -# training.create_model_version(model_id, hash, size, archive_hash) + if status_code == 400: + assert ( + training.create_model_version(model_id, hash, size, archive_hash) + == model_version_details + ) + elif status_code == 403: + assert ( + training.create_model_version(model_id, hash, size, archive_hash) + == model_version_details + ) def test_handle_s3_uploading_errors(model_file_dir):