diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index aec222298765cd4d6d09408bb84fb036704752ee..289af78bd67041773f993244ab7c4b60c67fbc4c 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -10,6 +10,10 @@ from arkindex_worker.models import Element from . import BASE_API_CALLS +# Special string used to know if the `arg_name` passed in +# `pytest.mark.parametrize` should be removed from the payload +DELETE_PARAMETER = "DELETE_PARAMETER" + def test_get_ml_class_id_load_classes(responses, mock_elements_worker): corpus_id = "11111111-1111-1111-1111-111111111111" @@ -190,54 +194,116 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker): ] -def test_create_classification_wrong_element(mock_elements_worker): - with pytest.raises( - AssertionError, - match="element shouldn't be null and should be an Element or CachedElement", - ): - mock_elements_worker.create_classification( - element=None, - ml_class="a_class", - confidence=0.42, - high_confidence=True, - ) - - with pytest.raises( - AssertionError, - match="element shouldn't be null and should be an Element or CachedElement", - ): +@pytest.mark.parametrize( + ("arg_name", "data", "error_message"), + [ + # Wrong element + ( + "element", + None, + "element shouldn't be null and should be an Element or CachedElement", + ), + ( + "element", + "not element type", + "element shouldn't be null and should be an Element or CachedElement", + ), + # Wrong ml_class + ( + "ml_class", + None, + "ml_class shouldn't be null and should be of type str", + ), + ( + "ml_class", + 1234, + "ml_class shouldn't be null and should be of type str", + ), + # Wrong confidence + ( + "confidence", + None, + "confidence shouldn't be null and should be a float in [0..1] range", + ), + ( + "confidence", + "wrong confidence", + "confidence shouldn't be null and should be a float in [0..1] range", + ), + ( + "confidence", + 0, + "confidence shouldn't be null and should be a float in [0..1] range", + ), + ( + "confidence", + 2.00, + "confidence shouldn't be null and should be a float in [0..1] range", + ), + # Wrong high_confidence + ( + "high_confidence", + None, + "high_confidence shouldn't be null and should be of type bool", + ), + ( + "high_confidence", + "wrong high_confidence", + "high_confidence shouldn't be null and should be of type bool", + ), + ], +) +def test_create_classification_wrong_data( + arg_name, data, error_message, mock_elements_worker +): + mock_elements_worker.classes = {"a_class": "0000"} + with pytest.raises(AssertionError, match=re.escape(error_message)): mock_elements_worker.create_classification( - element="not element type", - ml_class="a_class", - confidence=0.42, - high_confidence=True, + **{ + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "ml_class": "a_class", + "confidence": 0.42, + "high_confidence": True, + # Overwrite with wrong data + arg_name: data, + } ) -def test_create_classification_wrong_ml_class(mock_elements_worker, responses): +def test_create_classification_api_error(responses, mock_elements_worker): + mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + "http://testserver/api/v1/classifications/", + status=500, + ) - with pytest.raises( - AssertionError, match="ml_class shouldn't be null and should be of type str" - ): + with pytest.raises(ErrorResponse): mock_elements_worker.create_classification( element=elt, - ml_class=None, + ml_class="a_class", confidence=0.42, high_confidence=True, ) - with pytest.raises( - AssertionError, match="ml_class shouldn't be null and should be of type str" - ): - mock_elements_worker.create_classification( - element=elt, - ml_class=1234, - confidence=0.42, - high_confidence=True, - ) + assert len(responses.calls) == len(BASE_API_CALLS) + 5 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + # We retry 5 times the API call + ("POST", "http://testserver/api/v1/classifications/"), + ("POST", "http://testserver/api/v1/classifications/"), + ("POST", "http://testserver/api/v1/classifications/"), + ("POST", "http://testserver/api/v1/classifications/"), + ("POST", "http://testserver/api/v1/classifications/"), + ] + - # Automatically create a missing class ! +def test_create_classification_create_ml_class(mock_elements_worker, responses): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + # Automatically create a missing class! responses.add( responses.POST, "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", @@ -283,119 +349,6 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses): ] -def test_create_classification_wrong_confidence(mock_elements_worker): - mock_elements_worker.classes = {"a_class": "0000"} - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - with pytest.raises( - AssertionError, - match=re.escape( - "confidence shouldn't be null and should be a float in [0..1] range" - ), - ): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence=None, - high_confidence=True, - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "confidence shouldn't be null and should be a float in [0..1] range" - ), - ): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence="wrong confidence", - high_confidence=True, - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "confidence shouldn't be null and should be a float in [0..1] range" - ), - ): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence=0, - high_confidence=True, - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "confidence shouldn't be null and should be a float in [0..1] range" - ), - ): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence=2.00, - high_confidence=True, - ) - - -def test_create_classification_wrong_high_confidence(mock_elements_worker): - mock_elements_worker.classes = {"a_class": "0000"} - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - - with pytest.raises( - AssertionError, - match="high_confidence shouldn't be null and should be of type bool", - ): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence=0.42, - high_confidence=None, - ) - - with pytest.raises( - AssertionError, - match="high_confidence shouldn't be null and should be of type bool", - ): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence=0.42, - high_confidence="wrong high_confidence", - ) - - -def test_create_classification_api_error(responses, mock_elements_worker): - mock_elements_worker.classes = {"a_class": "0000"} - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - responses.add( - responses.POST, - "http://testserver/api/v1/classifications/", - status=500, - ) - - with pytest.raises(ErrorResponse): - mock_elements_worker.create_classification( - element=elt, - ml_class="a_class", - confidence=0.42, - high_confidence=True, - ) - - assert len(responses.calls) == len(BASE_API_CALLS) + 5 - assert [ - (call.request.method, call.request.url) for call in responses.calls - ] == BASE_API_CALLS + [ - # We retry 5 times the API call - ("POST", "http://testserver/api/v1/classifications/"), - ("POST", "http://testserver/api/v1/classifications/"), - ("POST", "http://testserver/api/v1/classifications/"), - ("POST", "http://testserver/api/v1/classifications/"), - ("POST", "http://testserver/api/v1/classifications/"), - ] - - def test_create_classification(responses, mock_elements_worker): mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) @@ -519,278 +472,154 @@ def test_create_classification_duplicate_worker_run(responses, mock_elements_wor } -def test_create_classifications_wrong_element(mock_elements_worker): - with pytest.raises( - AssertionError, - match="element shouldn't be null and should be an Element or CachedElement", - ): - mock_elements_worker.create_classifications( - element=None, - classifications=[ - { - "ml_class_id": "uuid1", - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": "uuid2", - "confidence": 0.25, - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match="element shouldn't be null and should be an Element or CachedElement", - ): - mock_elements_worker.create_classifications( - element="not element type", - classifications=[ - { - "ml_class_id": "uuid1", - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": "uuid2", - "confidence": 0.25, - "high_confidence": False, - }, - ], - ) - - -def test_create_classifications_wrong_classifications(mock_elements_worker): - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - - with pytest.raises( - AssertionError, - match="classifications shouldn't be null and should be of type list", - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=None, - ) - - with pytest.raises( - AssertionError, - match="classifications shouldn't be null and should be of type list", - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=1234, - ) - - with pytest.raises( - AssertionError, - match="Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str", - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": 0.25, - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match="Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str", - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": None, - "confidence": 0.25, - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match="Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str", - ): +@pytest.mark.parametrize( + ("arg_name", "data", "error_message"), + [ + ( + "element", + None, + "element shouldn't be null and should be an Element or CachedElement", + ), + ( + "element", + "not element type", + "element shouldn't be null and should be an Element or CachedElement", + ), + ( + "classifications", + None, + "classifications shouldn't be null and should be of type list", + ), + ( + "classifications", + 1234, + "classifications shouldn't be null and should be of type list", + ), + ], +) +def test_create_classifications_wrong_data( + arg_name, data, error_message, mock_elements_worker +): + with pytest.raises(AssertionError, match=error_message): mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": 1234, - "confidence": 0.25, - "high_confidence": False, - }, - ], + **{ + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "classifications": [ + { + "ml_class_id": "uuid1", + "confidence": 0.75, + "high_confidence": False, + }, + { + "ml_class_id": "uuid2", + "confidence": 0.25, + "high_confidence": False, + }, + ], + # Overwrite with wrong data + arg_name: data, + }, ) - with pytest.raises( - ValueError, - match="Classification at index 1 in classifications: ml_class_id is not a valid uuid.", - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": "not_an_uuid", - "confidence": 0.25, - "high_confidence": False, - }, - ], - ) - with pytest.raises( - AssertionError, - match=re.escape( - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" +@pytest.mark.parametrize( + ("arg_name", "data", "error_message", "error_type"), + [ + # Wrong classifications > ml_class_id + ( + "ml_class_id", + DELETE_PARAMETER, + "ml_class_id shouldn't be null and should be of type str", + AssertionError, + ), # Updated + ( + "ml_class_id", + None, + "ml_class_id shouldn't be null and should be of type str", + AssertionError, ), - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ( + "ml_class_id", + 1234, + "ml_class_id shouldn't be null and should be of type str", + AssertionError, ), - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": None, - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ( + "ml_class_id", + "not_an_uuid", + "ml_class_id is not a valid uuid.", + ValueError, ), - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": "wrong confidence", - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + # Wrong classifications > confidence + ( + "confidence", + DELETE_PARAMETER, + "confidence shouldn't be null and should be a float in [0..1] range", + AssertionError, ), - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": 0, - "high_confidence": False, - }, - ], - ) - - with pytest.raises( - AssertionError, - match=re.escape( - "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ( + "confidence", + None, + "confidence shouldn't be null and should be a float in [0..1] range", + AssertionError, ), - ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": 2.00, - "high_confidence": False, - }, - ], - ) + ( + "confidence", + "wrong confidence", + "confidence shouldn't be null and should be a float in [0..1] range", + AssertionError, + ), + ( + "confidence", + 0, + "confidence shouldn't be null and should be a float in [0..1] range", + AssertionError, + ), + ( + "confidence", + 2.00, + "confidence shouldn't be null and should be a float in [0..1] range", + AssertionError, + ), + # Wrong classifications > high_confidence + ( + "high_confidence", + "wrong high_confidence", + "high_confidence should be of type bool", + AssertionError, + ), + ], +) +def test_create_classifications_wrong_classifications_data( + arg_name, data, error_message, error_type, mock_elements_worker +): + all_data = { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "classifications": [ + { + "ml_class_id": str(uuid4()), + "confidence": 0.75, + "high_confidence": False, + }, + { + "ml_class_id": str(uuid4()), + "confidence": 0.25, + "high_confidence": False, + # Overwrite with wrong data + arg_name: data, + }, + ], + } + if data == DELETE_PARAMETER: + del all_data["classifications"][1][arg_name] with pytest.raises( - AssertionError, + error_type, match=re.escape( - "Classification at index 1 in classifications: high_confidence should be of type bool" + f"Classification at index 1 in classifications: {error_message}" ), ): - mock_elements_worker.create_classifications( - element=elt, - classifications=[ - { - "ml_class_id": str(uuid4()), - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": str(uuid4()), - "confidence": 0.25, - "high_confidence": "wrong high_confidence", - }, - ], - ) + mock_elements_worker.create_classifications(**all_data) def test_create_classifications_api_error(responses, mock_elements_worker):