diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 31ec96b63200a321a29c8628f200ef68e72ab98e..738de363046afc4582a2fd1cd448912fb306f387 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -10,6 +10,8 @@ from arkindex_worker.models import Element from . import BASE_API_CALLS +DELETE_PARAMETER = "DELETE_PARAMETER" + def test_get_ml_class_id_load_classes(responses, mock_elements_worker): corpus_id = "11111111-1111-1111-1111-111111111111" @@ -191,108 +193,79 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker): @pytest.mark.parametrize( - ("data", "error_message"), + ("arg_name", "data", "error_message"), [ # Wrong element ( - { - "element": None, - "ml_class": "a_class", - "confidence": 0.42, - "high_confidence": True, - }, + "element", + None, "element shouldn't be null and should be an Element or CachedElement", ), ( - { - "element": "not element type", - "ml_class": "a_class", - "confidence": 0.42, - "high_confidence": True, - }, + "element", + "not element type", "element shouldn't be null and should be an Element or CachedElement", ), # Wrong ml_class ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": None, - "confidence": 0.42, - "high_confidence": True, - }, + "ml_class", + None, "ml_class shouldn't be null and should be of type str", ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": 1234, - "confidence": 0.42, - "high_confidence": True, - }, + "ml_class", + 1234, "ml_class shouldn't be null and should be of type str", ), # Wrong confidence ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": "a_class", - "confidence": None, - "high_confidence": True, - }, + "confidence", + None, "confidence shouldn't be null and should be a float in [0..1] range", ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": "a_class", - "confidence": "wrong confidence", - "high_confidence": True, - }, + "confidence", + "wrong confidence", "confidence shouldn't be null and should be a float in [0..1] range", ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": "a_class", - "confidence": 0, - "high_confidence": True, - }, + "confidence", + 0, "confidence shouldn't be null and should be a float in [0..1] range", ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": "a_class", - "confidence": 2.00, - "high_confidence": True, - }, + "confidence", + 2.00, "confidence shouldn't be null and should be a float in [0..1] range", ), # Wrong high_confidence ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": "a_class", - "confidence": 0.42, - "high_confidence": None, - }, + "high_confidence", + None, "high_confidence shouldn't be null and should be of type bool", ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "ml_class": "a_class", - "confidence": 0.42, - "high_confidence": "wrong high_confidence", - }, + "high_confidence", + "wrong high_confidence", "high_confidence shouldn't be null and should be of type bool", ), ], ) -def test_create_classification_wrong_data(data, error_message, mock_elements_worker): +def test_create_classification_wrong_data( + arg_name, data, error_message, mock_elements_worker +): mock_elements_worker.classes = {"a_class": "0000"} with pytest.raises(AssertionError, match=re.escape(error_message)): - mock_elements_worker.create_classification(**data) + mock_elements_worker.create_classification( + **{ + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "ml_class": "a_class", + "confidence": 0.42, + "high_confidence": True, + # Overwrite with wrong data + arg_name: data, + } + ) def test_create_classification_api_error(responses, mock_elements_worker): @@ -498,262 +471,153 @@ def test_create_classification_duplicate_worker_run(responses, mock_elements_wor @pytest.mark.parametrize( - ("data", "error_message", "error_type"), + ("arg_name", "data", "error_message"), [ - # Wrong element ( - { - "element": None, - "classifications": [ - { - "ml_class_id": "uuid1", - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": "uuid2", - "confidence": 0.25, - "high_confidence": False, - }, - ], - }, + "element", + None, "element shouldn't be null and should be an Element or CachedElement", - AssertionError, ), ( - { - "element": "not element type", - "classifications": [ - { - "ml_class_id": "uuid1", - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": "uuid2", - "confidence": 0.25, - "high_confidence": False, - }, - ], - }, + "element", + "not element type", "element shouldn't be null and should be an Element or CachedElement", - AssertionError, ), - # Wrong classifications ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": None, - }, + "classifications", + None, "classifications shouldn't be null and should be of type list", - AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": 1234, - }, + "classifications", + 1234, "classifications shouldn't be null and should be of type list", - AssertionError, ), - # Wrong classifications > ml_class_id - ( - { + ], +) +def test_create_classifications_wrong_data( + arg_name, data, error_message, mock_elements_worker +): + with pytest.raises(AssertionError, match=error_message): + mock_elements_worker.create_classifications( + **{ "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "classifications": [ { - "ml_class_id": str(uuid4()), + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "ml_class_id": 0.25, + "ml_class_id": "uuid2", + "confidence": 0.25, "high_confidence": False, }, ], + # Overwrite with wrong data + arg_name: data, }, - "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str", + ) + + +@pytest.mark.parametrize( + ("arg_name", "data", "error_message", "error_type"), + [ + # Wrong classifications > ml_class_id + ( + "ml_class_id", + DELETE_PARAMETER, + "ml_class_id shouldn't be null and should be of type str", AssertionError, - ), + ), # Updated ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": None, - "confidence": 0.25, - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str", + "ml_class_id", + None, + "ml_class_id shouldn't be null and should be of type str", AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": 1234, - "confidence": 0.25, - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str", + "ml_class_id", + 1234, + "ml_class_id shouldn't be null and should be of type str", AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": "not_an_uuid", - "confidence": 0.25, - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: ml_class_id is not a valid uuid.", + "ml_class_id", + "not_an_uuid", + "ml_class_id is not a valid uuid.", ValueError, ), # Wrong classifications > confidence ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range", + "confidence", + DELETE_PARAMETER, + "confidence shouldn't be null and should be a float in [0..1] range", AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": None, - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range", + "confidence", + None, + "confidence shouldn't be null and should be a float in [0..1] range", AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": "wrong confidence", - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range", + "confidence", + "wrong confidence", + "confidence shouldn't be null and should be a float in [0..1] range", AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": 0, - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range", + "confidence", + 0, + "confidence shouldn't be null and should be a float in [0..1] range", AssertionError, ), ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": 2.00, - "high_confidence": False, - }, - ], - }, - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range", + "confidence", + 2.00, + "confidence shouldn't be null and should be a float in [0..1] range", AssertionError, ), # Wrong classifications > high_confidence ( - { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), - "classifications": [ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": 0.25, - "high_confidence": "wrong high_confidence", - }, - ], - }, - "Classification at index 1 in classifications: high_confidence should be of type bool", + "high_confidence", + "wrong high_confidence", + "high_confidence should be of type bool", AssertionError, ), ], ) -def test_create_classifications_wrong_data( - data, error_message, error_type, mock_elements_worker +def test_create_classifications_wrong_classifications_data( + arg_name, data, error_message, error_type, mock_elements_worker ): - with pytest.raises(error_type, match=re.escape(error_message)): - mock_elements_worker.create_classifications(**data) + all_data = { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "classifications": [ + { + "ml_class_id": str(uuid4()), + "confidence": 0.75, + "high_confidence": False, + }, + { + "ml_class_id": str(uuid4()), + "confidence": 0.25, + "high_confidence": False, + # Overwrite with wrong data + arg_name: data, + }, + ], + } + if data == DELETE_PARAMETER: + del all_data["classifications"][1][arg_name] + + with pytest.raises( + error_type, + match=re.escape( + f"Classification at index 1 in classifications: {error_message}" + ), + ): + mock_elements_worker.create_classifications(**all_data) def test_create_classifications_api_error(responses, mock_elements_worker):