diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index 520e9435db5f6042e47b844d18aaa9aeaf861385..72eb646a5690145dfe7e26e6d4fd17f53f66afa8 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -64,6 +64,30 @@ class ClassificationMixin(object): return ml_class_id + def retrieve_ml_class(self, ml_class_id: str) -> str: + """ + Retrieve the name of the MLClass from its ID. + + :param ml_class_id: ID of the searched MLClass. + :return: The MLClass's name + """ + # Load the corpus' MLclasses if they are not available yet + if self.corpus_id not in self.classes: + self.load_corpus_classes() + + # Filter classes by this ml_class_id + ml_class_name = next( + filter( + lambda x: self.classes[self.corpus_id][x] == ml_class_id, + self.classes[self.corpus_id], + ), + None, + ) + assert ( + ml_class_name is not None + ), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})" + return ml_class_name + def create_classification( self, element: Union[Element, CachedElement], @@ -97,7 +121,6 @@ class ClassificationMixin(object): "Cannot create classification as this worker is in read-only mode" ) return - try: created = self.request( "CreateClassification", @@ -166,7 +189,7 @@ class ClassificationMixin(object): :param element: The element to create classifications on. :param classifications: The classifications to create, a list of dicts. Each of them contains - a **class_name** (str), the name of the MLClass for this classification; + a **ml_class_id** (str), the ID of the MLClass for this classification; a **confidence** (float), the confidence score, between 0 and 1; a **high_confidence** (bool), the high confidence state of the classification. @@ -181,10 +204,10 @@ class ClassificationMixin(object): ), "classifications shouldn't be null and should be of type list" for index, classification in enumerate(classifications): - class_name = classification.get("class_name") - assert class_name and isinstance( - class_name, str - ), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str" + ml_class_id = classification.get("ml_class_id") + assert ml_class_id and isinstance( + ml_class_id, str + ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str" confidence = classification.get("confidence") assert ( @@ -215,6 +238,7 @@ class ClassificationMixin(object): )["classifications"] for created_cl in created_cls: + created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"]) self.report.add_classification(element.id, created_cl["class_name"]) if self.use_cache: @@ -224,7 +248,7 @@ class ClassificationMixin(object): { "id": created_cl["id"], "element_id": element.id, - "class_name": created_cl["class_name"], + "class_name": created_cl.pop("class_name"), "confidence": created_cl["confidence"], "state": created_cl["state"], "worker_run_id": self.worker_run_id, diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 05b9cf1f51285f67de48a85c8e0126f1862efdc4..8a1089971def9d2470c3e4f45aae8c136f90b51d 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -162,6 +162,46 @@ def test_get_ml_class_reload(responses, mock_elements_worker): ] +def test_retrieve_ml_class_in_cache(mock_elements_worker): + """ + Look for a class that exists in cache -> No API Call + """ + mock_elements_worker.classes[mock_elements_worker.corpus_id] = {"class1": "uuid1"} + + assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1" + + +def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker): + """ + Retrieve class not in cache -> Retrieve corpus ml classes via API + """ + responses.add( + responses.GET, + f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/", + status=200, + json={ + "count": 1, + "next": None, + "results": [ + { + "id": "uuid1", + "name": "class1", + }, + ], + }, + ) + assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1" + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/", + ), + ] + + def test_create_classification_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( @@ -520,12 +560,12 @@ def test_create_classifications_wrong_element(mock_elements_worker): element=None, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 0.25, "high_confidence": False, }, @@ -541,12 +581,12 @@ def test_create_classifications_wrong_element(mock_elements_worker): element="not element type", classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 0.25, "high_confidence": False, }, @@ -584,19 +624,19 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "confidence": 0.25, + "ml_class_id": 0.25, "high_confidence": False, }, ], ) assert ( str(e.value) - == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str" + == "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str" ) with pytest.raises(AssertionError) as e: @@ -604,12 +644,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": None, + "ml_class_id": None, "confidence": 0.25, "high_confidence": False, }, @@ -617,7 +657,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): ) assert ( str(e.value) - == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str" + == "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str" ) with pytest.raises(AssertionError) as e: @@ -625,12 +665,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": 1234, + "ml_class_id": 1234, "confidence": 0.25, "high_confidence": False, }, @@ -638,7 +678,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): ) assert ( str(e.value) - == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str" + == "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str" ) with pytest.raises(AssertionError) as e: @@ -646,12 +686,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "high_confidence": False, }, ], @@ -666,12 +706,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": None, "high_confidence": False, }, @@ -687,12 +727,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": "wrong confidence", "high_confidence": False, }, @@ -708,12 +748,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 0, "high_confidence": False, }, @@ -729,12 +769,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 2.00, "high_confidence": False, }, @@ -750,12 +790,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker): element=elt, classifications=[ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 0.25, "high_confidence": "wrong high_confidence", }, @@ -776,12 +816,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) classes = [ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 0.25, "high_confidence": False, }, @@ -806,15 +846,20 @@ def test_create_classifications_api_error(responses, mock_elements_worker): def test_create_classifications(responses, mock_elements_worker_with_cache): + # Set MLClass in cache + mock_elements_worker_with_cache.classes[ + mock_elements_worker_with_cache.corpus_id + ] = {"portrait": "uuid1", "landscape": "uuid2"} + elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") classes = [ { - "class_name": "portrait", + "ml_class_id": "uuid1", "confidence": 0.75, "high_confidence": False, }, { - "class_name": "landscape", + "ml_class_id": "uuid2", "confidence": 0.25, "high_confidence": False, }, @@ -830,14 +875,14 @@ def test_create_classifications(responses, mock_elements_worker_with_cache): "classifications": [ { "id": "00000000-0000-0000-0000-000000000000", - "class_name": "portrait", + "ml_class": "uuid1", "confidence": 0.75, "high_confidence": False, "state": "pending", }, { "id": "11111111-1111-1111-1111-111111111111", - "class_name": "landscape", + "ml_class": "uuid2", "confidence": 0.25, "high_confidence": False, "state": "pending", @@ -882,3 +927,108 @@ def test_create_classifications(responses, mock_elements_worker_with_cache): worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), ), ] + + +def test_create_classifications_not_in_cache( + responses, mock_elements_worker_with_cache +): + """ + CreateClassifications using ID that are not in `.classes` attribute. + Will load corpus MLClass to insert the corresponding name in Cache. + """ + elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") + classes = [ + { + "ml_class_id": "uuid1", + "confidence": 0.75, + "high_confidence": False, + }, + { + "ml_class_id": "uuid2", + "confidence": 0.25, + "high_confidence": False, + }, + ] + + responses.add( + responses.POST, + "http://testserver/api/v1/classification/bulk/", + status=200, + json={ + "parent": str(elt.id), + "worker_run_id": "56785678-5678-5678-5678-567856785678", + "classifications": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "ml_class": "uuid1", + "confidence": 0.75, + "high_confidence": False, + "state": "pending", + }, + { + "id": "11111111-1111-1111-1111-111111111111", + "ml_class": "uuid2", + "confidence": 0.25, + "high_confidence": False, + "state": "pending", + }, + ], + }, + ) + responses.add( + responses.GET, + f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/", + status=200, + json={ + "count": 2, + "next": None, + "results": [ + { + "id": "uuid1", + "name": "portrait", + }, + {"id": "uuid2", "name": "landscape"}, + ], + }, + ) + + mock_elements_worker_with_cache.create_classifications( + element=elt, classifications=classes + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 2 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ("POST", "http://testserver/api/v1/classification/bulk/"), + ( + "GET", + f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/", + ), + ] + + assert json.loads(responses.calls[-2].request.body) == { + "parent": str(elt.id), + "worker_run_id": "56785678-5678-5678-5678-567856785678", + "classifications": classes, + } + + # Check that created classifications were properly stored in SQLite cache + assert list(CachedClassification.select()) == [ + CachedClassification( + id=UUID("00000000-0000-0000-0000-000000000000"), + element_id=UUID(elt.id), + class_name="portrait", + confidence=0.75, + state="pending", + worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), + ), + CachedClassification( + id=UUID("11111111-1111-1111-1111-111111111111"), + element_id=UUID(elt.id), + class_name="landscape", + confidence=0.25, + state="pending", + worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), + ), + ]