diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index bc9257b83ba0b3392624a9bdf732cc952b3b2e46..7e857e98743897ffc01bf15533a0ab56e4dfb86b 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -2,8 +2,6 @@ ElementsWorker methods for classifications and ML classes. """ -from uuid import UUID - from apistar.exceptions import ErrorResponse from peewee import IntegrityError @@ -178,10 +176,14 @@ class ClassificationMixin: Create multiple classifications at once on the given element through the API. :param element: The element to create classifications on. - :param classifications: The classifications to create, a list of dicts. Each of them contains - a **ml_class_id** (str), the ID of the MLClass for this classification; - a **confidence** (float), the confidence score, between 0 and 1; - a **high_confidence** (bool), the high confidence state of the classification. + :param classifications: A list of dicts representing a classification each, with the following keys: + + ml_class (str) + Required. Name of the MLClass to use. + confidence (float) + Required. Confidence score for the classification. Must be between 0 and 1. + high_confidence (bool) + Optional. Whether or not the classification is of high confidence. :returns: List of created classifications, as returned in the ``classifications`` field by the ``CreateClassifications`` API endpoint. @@ -194,18 +196,10 @@ class ClassificationMixin: ), "classifications shouldn't be null and should be of type list" for index, classification in enumerate(classifications): - ml_class_id = classification.get("ml_class_id") + ml_class = classification.get("ml_class") assert ( - ml_class_id and isinstance(ml_class_id, str) - ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str" - - # Make sure it's a valid UUID - try: - UUID(ml_class_id) - except ValueError as e: - raise ValueError( - f"Classification at index {index} in classifications: ml_class_id is not a valid uuid." - ) from e + ml_class and isinstance(ml_class, str) + ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str" confidence = classification.get("confidence") assert ( @@ -231,7 +225,13 @@ class ClassificationMixin: body={ "parent": str(element.id), "worker_run_id": self.worker_run_id, - "classifications": classifications, + "classifications": [ + { + **classification, + "ml_class": self.get_ml_class_id(classification["ml_class"]), + } + for classification in classifications + ], }, )["classifications"] diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 289af78bd67041773f993244ab7c4b60c67fbc4c..451fd63452a0a9382f63cefdab5bba06efb876ac 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -1,6 +1,6 @@ import json import re -from uuid import UUID, uuid4 +from uuid import UUID import pytest from apistar.exceptions import ErrorResponse @@ -325,9 +325,6 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses): ) # Check a class & classification has been created - for call in responses.calls: - print(call.request.url, call.request.body) - assert [ (call.request.url, json.loads(call.request.body)) for call in responses.calls[-2:] @@ -506,12 +503,12 @@ def test_create_classifications_wrong_data( "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "classifications": [ { - "ml_class_id": "uuid1", + "ml_class": "cat", "confidence": 0.75, "high_confidence": False, }, { - "ml_class_id": "uuid2", + "ml_class": "dog", "confidence": 0.25, "high_confidence": False, }, @@ -523,86 +520,71 @@ def test_create_classifications_wrong_data( @pytest.mark.parametrize( - ("arg_name", "data", "error_message", "error_type"), + ("arg_name", "data", "error_message"), [ - # Wrong classifications > ml_class_id + # Wrong classifications > ml_class ( - "ml_class_id", + "ml_class", DELETE_PARAMETER, - "ml_class_id shouldn't be null and should be of type str", - AssertionError, - ), # Updated + "ml_class shouldn't be null and should be of type str", + ), ( - "ml_class_id", + "ml_class", None, - "ml_class_id shouldn't be null and should be of type str", - AssertionError, + "ml_class shouldn't be null and should be of type str", ), ( - "ml_class_id", + "ml_class", 1234, - "ml_class_id shouldn't be null and should be of type str", - AssertionError, - ), - ( - "ml_class_id", - "not_an_uuid", - "ml_class_id is not a valid uuid.", - ValueError, + "ml_class shouldn't be null and should be of type str", ), # Wrong classifications > confidence ( "confidence", DELETE_PARAMETER, "confidence shouldn't be null and should be a float in [0..1] range", - AssertionError, ), ( "confidence", None, "confidence shouldn't be null and should be a float in [0..1] range", - AssertionError, ), ( "confidence", "wrong confidence", "confidence shouldn't be null and should be a float in [0..1] range", - AssertionError, ), ( "confidence", 0, "confidence shouldn't be null and should be a float in [0..1] range", - AssertionError, ), ( "confidence", 2.00, "confidence shouldn't be null and should be a float in [0..1] range", - AssertionError, ), # Wrong classifications > high_confidence ( "high_confidence", "wrong high_confidence", "high_confidence should be of type bool", - AssertionError, ), ], ) def test_create_classifications_wrong_classifications_data( - arg_name, data, error_message, error_type, mock_elements_worker + arg_name, data, error_message, mock_elements_worker ): all_data = { "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "classifications": [ { - "ml_class_id": str(uuid4()), + "ml_class": "cat", "confidence": 0.75, "high_confidence": False, }, { - "ml_class_id": str(uuid4()), + "ml_class": "dog", "confidence": 0.25, "high_confidence": False, # Overwrite with wrong data @@ -614,7 +596,7 @@ def test_create_classifications_wrong_classifications_data( del all_data["classifications"][1][arg_name] with pytest.raises( - error_type, + AssertionError, match=re.escape( f"Classification at index 1 in classifications: {error_message}" ), @@ -623,6 +605,7 @@ def test_create_classifications_wrong_classifications_data( def test_create_classifications_api_error(responses, mock_elements_worker): + mock_elements_worker.classes = {"cat": "0000", "dog": "1111"} responses.add( responses.POST, "http://testserver/api/v1/classification/bulk/", @@ -631,12 +614,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) classes = [ { - "ml_class_id": str(uuid4()), + "ml_class": "cat", "confidence": 0.75, "high_confidence": False, }, { - "ml_class_id": str(uuid4()), + "ml_class": "dog", "confidence": 0.25, "high_confidence": False, }, @@ -660,57 +643,96 @@ def test_create_classifications_api_error(responses, mock_elements_worker): ] -def test_create_classifications(responses, mock_elements_worker_with_cache): - # Set MLClass in cache - portrait_uuid = str(uuid4()) - landscape_uuid = str(uuid4()) - mock_elements_worker_with_cache.classes = { - "portrait": portrait_uuid, - "landscape": landscape_uuid, - } - - elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") - classes = [ - { - "ml_class_id": portrait_uuid, - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": landscape_uuid, - "confidence": 0.25, - "high_confidence": False, - }, - ] +def test_create_classifications_create_ml_class(mock_elements_worker, responses): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + # Automatically create a missing class! + responses.add( + responses.POST, + "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", + status=201, + json={"id": "new-ml-class-1234"}, + ) responses.add( responses.POST, "http://testserver/api/v1/classification/bulk/", - status=200, + status=201, json={ "parent": str(elt.id), "worker_run_id": "56785678-5678-5678-5678-567856785678", "classifications": [ { "id": "00000000-0000-0000-0000-000000000000", - "ml_class": portrait_uuid, + "ml_class": "new-ml-class-1234", "confidence": 0.75, "high_confidence": False, "state": "pending", }, - { - "id": "11111111-1111-1111-1111-111111111111", - "ml_class": landscape_uuid, - "confidence": 0.25, - "high_confidence": False, - "state": "pending", - }, ], }, ) + mock_elements_worker.classes = {"another_class": "0000"} + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "ml_class": "a_class", + "confidence": 0.75, + "high_confidence": False, + } + ], + ) - mock_elements_worker_with_cache.create_classifications( - element=elt, classifications=classes + # Check a class & classification has been created + assert len(responses.calls) == len(BASE_API_CALLS) + 2 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", + ), + ("POST", "http://testserver/api/v1/classification/bulk/"), + ] + + assert json.loads(responses.calls[-2].request.body) == {"name": "a_class"} + assert json.loads(responses.calls[-1].request.body) == { + "parent": "12341234-1234-1234-1234-123412341234", + "worker_run_id": "56785678-5678-5678-5678-567856785678", + "classifications": [ + { + "ml_class": "new-ml-class-1234", + "confidence": 0.75, + "high_confidence": False, + } + ], + } + + +def test_create_classifications(responses, mock_elements_worker): + mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"} + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + "http://testserver/api/v1/classification/bulk/", + status=200, + json={"classifications": []}, + ) + + mock_elements_worker.create_classifications( + element=elt, + classifications=[ + { + "ml_class": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "ml_class": "landscape", + "confidence": 0.25, + "high_confidence": False, + }, + ], ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 @@ -723,52 +745,24 @@ def test_create_classifications(responses, mock_elements_worker_with_cache): assert json.loads(responses.calls[-1].request.body) == { "parent": str(elt.id), "worker_run_id": "56785678-5678-5678-5678-567856785678", - "classifications": classes, + "classifications": [ + { + "confidence": 0.75, + "high_confidence": False, + "ml_class": "0000", + }, + { + "confidence": 0.25, + "high_confidence": False, + "ml_class": "1111", + }, + ], } - # Check that created classifications were properly stored in SQLite cache - assert list(CachedClassification.select()) == [ - CachedClassification( - id=UUID("00000000-0000-0000-0000-000000000000"), - element_id=UUID(elt.id), - class_name="portrait", - confidence=0.75, - state="pending", - worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), - ), - CachedClassification( - id=UUID("11111111-1111-1111-1111-111111111111"), - element_id=UUID(elt.id), - class_name="landscape", - confidence=0.25, - state="pending", - worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), - ), - ] - -def test_create_classifications_not_in_cache( - responses, mock_elements_worker_with_cache -): - """ - CreateClassifications using ID that are not in `.classes` attribute. - Will load corpus MLClass to insert the corresponding name in Cache. - """ - portrait_uuid = str(uuid4()) - landscape_uuid = str(uuid4()) +def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache): + mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"} elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") - classes = [ - { - "ml_class_id": portrait_uuid, - "confidence": 0.75, - "high_confidence": False, - }, - { - "ml_class_id": landscape_uuid, - "confidence": 0.25, - "high_confidence": False, - }, - ] responses.add( responses.POST, @@ -780,14 +774,14 @@ def test_create_classifications_not_in_cache( "classifications": [ { "id": "00000000-0000-0000-0000-000000000000", - "ml_class": portrait_uuid, + "ml_class": "0000", "confidence": 0.75, "high_confidence": False, "state": "pending", }, { "id": "11111111-1111-1111-1111-111111111111", - "ml_class": landscape_uuid, + "ml_class": "1111", "confidence": 0.25, "high_confidence": False, "state": "pending", @@ -795,42 +789,45 @@ def test_create_classifications_not_in_cache( ], }, ) - responses.add( - responses.GET, - f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/", - status=200, - json={ - "count": 2, - "next": None, - "results": [ - { - "id": portrait_uuid, - "name": "portrait", - }, - {"id": landscape_uuid, "name": "landscape"}, - ], - }, - ) mock_elements_worker_with_cache.create_classifications( - element=elt, classifications=classes + element=elt, + classifications=[ + { + "ml_class": "portrait", + "confidence": 0.75, + "high_confidence": False, + }, + { + "ml_class": "landscape", + "confidence": 0.25, + "high_confidence": False, + }, + ], ) - assert len(responses.calls) == len(BASE_API_CALLS) + 2 + assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ("POST", "http://testserver/api/v1/classification/bulk/"), - ( - "GET", - f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/", - ), ] - assert json.loads(responses.calls[-2].request.body) == { + assert json.loads(responses.calls[-1].request.body) == { "parent": str(elt.id), "worker_run_id": "56785678-5678-5678-5678-567856785678", - "classifications": classes, + "classifications": [ + { + "confidence": 0.75, + "high_confidence": False, + "ml_class": "0000", + }, + { + "confidence": 0.25, + "high_confidence": False, + "ml_class": "1111", + }, + ], } # Check that created classifications were properly stored in SQLite cache