diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index e8c85381dd8d7e3291242e76129e7a878613224d..664de9465a97f1157dd44b41626d63b8b540961e 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -350,16 +350,34 @@ class ElementsWorker(BaseWorker): ) return - self.api_client.request( - "CreateClassification", - body={ - "element": element.id, - "ml_class": self.get_ml_class_id(element.corpus.id, ml_class), - "worker_version": self.worker_version_id, - "confidence": confidence, - "high_confidence": high_confidence, - }, - ) + try: + self.api_client.request( + "CreateClassification", + body={ + "element": element.id, + "ml_class": self.get_ml_class_id(element.corpus.id, ml_class), + "worker_version": self.worker_version_id, + "confidence": confidence, + "high_confidence": high_confidence, + }, + ) + except ErrorResponse as e: + + # Detect already existing classification + if ( + e.status_code == 400 + and "non_field_errors" in e.content + and "The fields element, worker_version, ml_class must make a unique set." + in e.content["non_field_errors"] + ): + logger.warning( + f"This worker version has already set {ml_class} on element {element.id}" + ) + return + + # Propagate any other API error + raise + self.report.add_classification(element.id, ml_class) def create_entity(self, element, name, type, corpus, metas=None, validated=None): diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py index cd1c8cfca5066960b8b327b83141c68a45ad73ec..8d0bc4114537bc55e37e6cb9aeadeaaac83ad2d0 100644 --- a/tests/test_elements_worker.py +++ b/tests/test_elements_worker.py @@ -991,6 +991,57 @@ def test_create_classification(responses, mock_elements_worker): "high_confidence": True, } + # Classification has been created and reported + assert mock_elements_worker.report.report_data["elements"][elt.id][ + "classifications" + ] == {"a_class": 1} + + +def test_create_classification_duplicate(responses, mock_elements_worker): + mock_elements_worker.classes = { + "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} + } + elt = Element( + { + "id": "12341234-1234-1234-1234-123412341234", + "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, + } + ) + responses.add( + responses.POST, + "http://testserver/api/v1/classifications/", + status=400, + json={ + "non_field_errors": [ + "The fields element, worker_version, ml_class must make a unique set." + ] + }, + ) + + mock_elements_worker.create_classification( + element=elt, + ml_class="a_class", + confidence=0.42, + high_confidence=True, + ) + + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/classifications/", + ] + + assert json.loads(responses.calls[1].request.body) == { + "element": "12341234-1234-1234-1234-123412341234", + "ml_class": "0000", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, + "high_confidence": True, + } + + # Classification has NOT been created + assert mock_elements_worker.report.report_data["elements"] == {} + def test_create_entity_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: