diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 002baf8206540b3bc6e5dd3816079ed0a004f15e..2532c3591f0927ca2ed5d3b1c066bb9c3b04a41c 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -501,3 +501,373 @@ def test_create_classification_duplicate(responses, mock_elements_worker): # Classification has NOT been created assert mock_elements_worker.report.report_data["elements"] == {} + + +def test_create_classifications_wrong_element(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=None, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 0.25, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element="not element type", + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 0.25, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) + + +def test_create_classifications_wrong_classifications(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=None, + ) + assert ( + str(e.value) == "classifications shouldn't be null and should be of type list" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=1234, + ) + assert ( + str(e.value) == "classifications shouldn't be null and should be of type list" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "confidence": 0.25, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": None, + "confidence": 0.25, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": 1234, + "confidence": 0.25, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": None, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": "wrong confidence", + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 0, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 2.00, + "high_confidence": False, + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 0.25, + "high_confidence": "wrong high_confidence", + }, + ], + ) + assert ( + str(e.value) + == "Classification at index 1 in classifications: high_confidence should be of type bool" + ) + + +def test_create_classifications_api_error(responses, mock_elements_worker): + responses.add( + responses.POST, + "http://testserver/api/v1/classification/bulk/", + status=500, + ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + classes = [ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 0.25, + "high_confidence": False, + }, + ] + + with pytest.raises(ErrorResponse): + mock_elements_worker.create_classifications( + element=elt, classifications=classes + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 5 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + # We retry 5 times the API call + ("POST", "http://testserver/api/v1/classification/bulk/"), + ("POST", "http://testserver/api/v1/classification/bulk/"), + ("POST", "http://testserver/api/v1/classification/bulk/"), + ("POST", "http://testserver/api/v1/classification/bulk/"), + ("POST", "http://testserver/api/v1/classification/bulk/"), + ] + + +def test_create_classifications(responses, mock_elements_worker_with_cache): + elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") + classes = [ + { + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "class_name": "landscape", + "confidence": 0.25, + "high_confidence": False, + }, + ] + + responses.add( + responses.POST, + "http://testserver/api/v1/classification/bulk/", + status=200, + json={ + "parent": str(elt.id), + "worker_version": "12341234-1234-1234-1234-123412341234", + "classifications": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "class_name": "portrait", + "confidence": 0.75, + "high_confidence": False, + "state": "pending", + }, + { + "id": "11111111-1111-1111-1111-111111111111", + "class_name": "landscape", + "confidence": 0.25, + "high_confidence": False, + "state": "pending", + }, + ], + }, + ) + + mock_elements_worker_with_cache.create_classifications( + element=elt, classifications=classes + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ("POST", "http://testserver/api/v1/classification/bulk/"), + ] + + assert json.loads(responses.calls[-1].request.body) == { + "parent": str(elt.id), + "worker_version": "12341234-1234-1234-1234-123412341234", + "classifications": classes, + } + + # Check that created classifications were properly stored in SQLite cache + assert list(CachedClassification.select()) == [ + CachedClassification( + id=UUID("00000000-0000-0000-0000-000000000000"), + element_id=UUID(elt.id), + class_name="portrait", + confidence=0.75, + state="pending", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ), + CachedClassification( + id=UUID("11111111-1111-1111-1111-111111111111"), + element_id=UUID(elt.id), + class_name="landscape", + confidence=0.25, + state="pending", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ), + ]