# ---------------------------------------------------------------------------
# File: arkindex_worker/worker/classification.py
#
# Method added at the end of ``ClassificationMixin``. Only the added method is
# shown here; the other members of the class are outside this excerpt. It uses
# names expected to be available in that module: Element, CachedElement,
# CachedClassification, IntegrityError and logger.
# ---------------------------------------------------------------------------
class ClassificationMixin(object):
    def create_classifications(self, element, classifications):
        """
        Create multiple classifications at once on the given element through the API

        :param element: Element or CachedElement to classify; must be truthy.
        :param classifications: Non-empty list of dicts. Each dict must hold
            ``class_name`` (non-empty str) and ``confidence`` (float between
            0 and 1 included), and may hold ``high_confidence`` (bool).
        :raises AssertionError: When any argument fails the checks above. The
            message names the index of the offending classification.

        Validation runs before the read-only check, so invalid input raises
        even when the worker is in read-only mode. In read-only mode a warning
        is logged and nothing is sent. Otherwise the ``CreateClassifications``
        operation is called through ``self.request`` (errors it raises are not
        caught here), every returned classification is added to the report
        and, when ``self.use_cache`` is set, stored in the local cache.
        """
        assert element and isinstance(
            element, (Element, CachedElement)
        ), "element shouldn't be null and should be an Element or CachedElement"
        assert classifications and isinstance(
            classifications, list
        ), "classifications shouldn't be null and should be of type list"

        for index, classification in enumerate(classifications):
            # Without this check a non-dict entry would fail on `.get` with a
            # bare AttributeError instead of a descriptive AssertionError.
            assert isinstance(
                classification, dict
            ), f"Classification at index {index} in classifications: should be of type dict"

            class_name = classification.get("class_name")
            assert class_name and isinstance(
                class_name, str
            ), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"

            # isinstance() is already False for None, so a missing value is
            # rejected here too. Integers (e.g. 0) are rejected as well, which
            # is what the accompanying tests expect.
            confidence = classification.get("confidence")
            assert (
                isinstance(confidence, float) and 0 <= confidence <= 1
            ), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"

            # high_confidence is optional: only its type is checked when set.
            high_confidence = classification.get("high_confidence")
            if high_confidence is not None:
                assert isinstance(
                    high_confidence, bool
                ), f"Classification at index {index} in classifications: high_confidence should be of type bool"

        if self.is_read_only:
            logger.warning(
                "Cannot create classifications as this worker is in read-only mode"
            )
            return

        created_cls = self.request(
            "CreateClassifications",
            body={
                "parent": str(element.id),
                "worker_version": self.worker_version_id,
                "classifications": classifications,
            },
        )["classifications"]

        for created_cl in created_cls:
            self.report.add_classification(element.id, created_cl["class_name"])

        if self.use_cache:
            # Store classifications in local cache
            try:
                to_insert = [
                    {
                        "id": created_cl["id"],
                        "element_id": element.id,
                        "class_name": created_cl["class_name"],
                        "confidence": created_cl["confidence"],
                        "state": created_cl["state"],
                        "worker_version_id": self.worker_version_id,
                    }
                    for created_cl in created_cls
                ]
                CachedClassification.insert_many(to_insert).execute()
            except IntegrityError as e:
                # Best effort: the classifications were already created
                # through the API, so a cache failure is only reported.
                logger.warning(
                    f"Couldn't save created classifications in local cache: {e}"
                )


# ---------------------------------------------------------------------------
# File: tests/test_elements_worker/test_classifications.py
#
# Tests added at the end of the module. They use names expected to be imported
# or defined earlier in that module (pytest, json, UUID, Element,
# CachedElement, CachedClassification, ErrorResponse, BASE_API_CALLS) and the
# `responses`, `mock_elements_worker` and `mock_elements_worker_with_cache`
# fixtures.
# ---------------------------------------------------------------------------
def test_create_classifications_wrong_element(mock_elements_worker):
    """A null or wrongly typed element is rejected with an AssertionError."""
    classes = [
        {
            "class_name": "portrait",
            "confidence": 0.75,
            "high_confidence": False,
        },
        {
            "class_name": "landscape",
            "confidence": 0.25,
            "high_confidence": False,
        },
    ]

    for wrong_element in (None, "not element type"):
        with pytest.raises(AssertionError) as e:
            mock_elements_worker.create_classifications(
                element=wrong_element,
                classifications=classes,
            )
        assert (
            str(e.value)
            == "element shouldn't be null and should be an Element or CachedElement"
        )


def test_create_classifications_wrong_classifications(mock_elements_worker):
    """Every malformed `classifications` payload raises a descriptive AssertionError."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    # Valid first entry: the invalid one always sits at index 1, so the
    # messages also prove that the reported index is the right one.
    valid = {
        "class_name": "portrait",
        "confidence": 0.75,
        "high_confidence": False,
    }

    list_error = "classifications shouldn't be null and should be of type list"
    dict_error = (
        "Classification at index 1 in classifications: should be of type dict"
    )
    class_name_error = "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
    confidence_error = "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
    high_confidence_error = "Classification at index 1 in classifications: high_confidence should be of type bool"

    cases = [
        # The container itself is wrong
        (None, list_error),
        (1234, list_error),
        # An entry is not a dict
        ([valid, "landscape"], dict_error),
        # class_name: missing, null, wrong type
        ([valid, {"confidence": 0.25, "high_confidence": False}], class_name_error),
        (
            [valid, {"class_name": None, "confidence": 0.25, "high_confidence": False}],
            class_name_error,
        ),
        (
            [valid, {"class_name": 1234, "confidence": 0.25, "high_confidence": False}],
            class_name_error,
        ),
        # confidence: missing, null, wrong type, int instead of float, out of range
        ([valid, {"class_name": "landscape", "high_confidence": False}], confidence_error),
        (
            [
                valid,
                {"class_name": "landscape", "confidence": None, "high_confidence": False},
            ],
            confidence_error,
        ),
        (
            [
                valid,
                {
                    "class_name": "landscape",
                    "confidence": "wrong confidence",
                    "high_confidence": False,
                },
            ],
            confidence_error,
        ),
        (
            [
                valid,
                {"class_name": "landscape", "confidence": 0, "high_confidence": False},
            ],
            confidence_error,
        ),
        (
            [
                valid,
                {"class_name": "landscape", "confidence": 2.00, "high_confidence": False},
            ],
            confidence_error,
        ),
        # high_confidence: wrong type
        (
            [
                valid,
                {
                    "class_name": "landscape",
                    "confidence": 0.25,
                    "high_confidence": "wrong high_confidence",
                },
            ],
            high_confidence_error,
        ),
    ]

    for wrong_classifications, expected_message in cases:
        with pytest.raises(AssertionError) as e:
            mock_elements_worker.create_classifications(
                element=elt,
                classifications=wrong_classifications,
            )
        assert str(e.value) == expected_message


def test_create_classifications_api_error(responses, mock_elements_worker):
    """A failing API endpoint surfaces as an ErrorResponse after the retries."""
    responses.add(
        responses.POST,
        "http://testserver/api/v1/classification/bulk/",
        status=500,
    )
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    classes = [
        {
            "class_name": "portrait",
            "confidence": 0.75,
            "high_confidence": False,
        },
        {
            "class_name": "landscape",
            "confidence": 0.25,
            "high_confidence": False,
        },
    ]

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_classifications(
            element=elt, classifications=classes
        )

    assert len(responses.calls) == len(BASE_API_CALLS) + 5
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        # We retry 5 times the API call
        ("POST", "http://testserver/api/v1/classification/bulk/"),
        ("POST", "http://testserver/api/v1/classification/bulk/"),
        ("POST", "http://testserver/api/v1/classification/bulk/"),
        ("POST", "http://testserver/api/v1/classification/bulk/"),
        ("POST", "http://testserver/api/v1/classification/bulk/"),
    ]


def test_create_classifications(responses, mock_elements_worker_with_cache):
    """A successful bulk creation sends one POST and fills the local cache."""
    elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
    classes = [
        {
            "class_name": "portrait",
            "confidence": 0.75,
            "high_confidence": False,
        },
        {
            "class_name": "landscape",
            "confidence": 0.25,
            "high_confidence": False,
        },
    ]

    responses.add(
        responses.POST,
        "http://testserver/api/v1/classification/bulk/",
        status=200,
        json={
            "parent": str(elt.id),
            "worker_version": "12341234-1234-1234-1234-123412341234",
            "classifications": [
                {
                    "id": "00000000-0000-0000-0000-000000000000",
                    "class_name": "portrait",
                    "confidence": 0.75,
                    "high_confidence": False,
                    "state": "pending",
                },
                {
                    "id": "11111111-1111-1111-1111-111111111111",
                    "class_name": "landscape",
                    "confidence": 0.25,
                    "high_confidence": False,
                    "state": "pending",
                },
            ],
        },
    )

    mock_elements_worker_with_cache.create_classifications(
        element=elt, classifications=classes
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("POST", "http://testserver/api/v1/classification/bulk/"),
    ]

    assert json.loads(responses.calls[-1].request.body) == {
        "parent": str(elt.id),
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "classifications": classes,
    }

    # Check that created classifications were properly stored in SQLite cache
    assert list(CachedClassification.select()) == [
        CachedClassification(
            id=UUID("00000000-0000-0000-0000-000000000000"),
            element_id=UUID(elt.id),
            class_name="portrait",
            confidence=0.75,
            state="pending",
            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
        ),
        CachedClassification(
            id=UUID("11111111-1111-1111-1111-111111111111"),
            element_id=UUID(elt.id),
            class_name="landscape",
            confidence=0.25,
            state="pending",
            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
        ),
    ]