From 6b7ba8d860cec6e272151afbe71482db60264470 Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Wed, 21 Jul 2021 09:15:34 +0000
Subject: [PATCH] Add create_classifications method calling
 CreateClassifications endpoint

---
 arkindex_worker/worker/classification.py      |  68 ++++
 .../test_classifications.py                   | 370 ++++++++++++++++++
 2 files changed, 438 insertions(+)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index 5e57326e..6d356c39 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -131,3 +131,71 @@ class ClassificationMixin(object):
             raise
 
         self.report.add_classification(element.id, ml_class)
+
+    def create_classifications(self, element, classifications):
+        """
+        Create several classifications on one element with CreateClassifications.
+        Each item of the list is a dict holding ``class_name`` (str), ``confidence``
+        (float in [0..1] range) and, optionally, ``high_confidence`` (bool).
+        """
+        assert element and isinstance(
+            element, (Element, CachedElement)
+        ), "element shouldn't be null and should be an Element or CachedElement"
+        assert classifications and isinstance(
+            classifications, list
+        ), "classifications shouldn't be null and should be of type list"
+
+        for position, entry in enumerate(classifications):
+            prefix = f"Classification at index {position} in classifications"
+            name = entry.get("class_name")
+            assert name and isinstance(
+                name, str
+            ), f"{prefix}: class_name shouldn't be null and should be of type str"
+
+            score = entry.get("confidence")
+            assert (
+                score is not None and isinstance(score, float) and 0 <= score <= 1
+            ), f"{prefix}: confidence shouldn't be null and should be a float in [0..1] range"
+
+            flag = entry.get("high_confidence")
+            assert flag is None or isinstance(
+                flag, bool
+            ), f"{prefix}: high_confidence should be of type bool"
+
+        if self.is_read_only:
+            logger.warning(
+                "Cannot create classifications as this worker is in read-only mode"
+            )
+            return
+
+        payload = {
+            "parent": str(element.id),
+            "worker_version": self.worker_version_id,
+            "classifications": classifications,
+        }
+        response = self.request("CreateClassifications", body=payload)
+        created_cls = response["classifications"]
+        for created in created_cls:
+            self.report.add_classification(element.id, created["class_name"])
+
+        if not self.use_cache:
+            return
+
+        # Mirror the classifications returned by the API in the local cache
+        rows = [
+            {
+                "id": created["id"],
+                "element_id": element.id,
+                "class_name": created["class_name"],
+                "confidence": created["confidence"],
+                "state": created["state"],
+                "worker_version_id": self.worker_version_id,
+            }
+            for created in created_cls
+        ]
+        try:
+            CachedClassification.insert_many(rows).execute()
+        except IntegrityError as e:
+            logger.warning(
+                f"Couldn't save created classifications in local cache: {e}"
+            )
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 002baf82..2532c359 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -501,3 +501,373 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
 
     # Classification has NOT been created
     assert mock_elements_worker.report.report_data["elements"] == {}
+
+
+def test_create_classifications_wrong_element(mock_elements_worker):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=None,  # null element
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element="not element type",  # neither an Element nor a CachedElement
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
+
+
+def test_create_classifications_wrong_classifications(mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=None,  # null classifications
+        )
+    assert (
+        str(e.value) == "classifications shouldn't be null and should be of type list"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=1234,  # not a list
+        )
+    assert (
+        str(e.value) == "classifications shouldn't be null and should be of type list"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {  # class_name key is missing
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": None,  # null class_name
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": 1234,  # not a str
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {  # confidence key is missing
+                    "class_name": "landscape",
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": None,  # null confidence
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": "wrong confidence",  # not a float
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": 0,  # int, not a float
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": 2.00,  # above the [0..1] range
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "class_name": "landscape",
+                    "confidence": 0.25,
+                    "high_confidence": "wrong high_confidence",  # not a bool
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: high_confidence should be of type bool"
+    )
+
+
+def test_create_classifications_api_error(responses, mock_elements_worker):
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/classification/bulk/",
+        status=500,  # mocked server error on the bulk endpoint
+    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    classes = [
+        {
+            "class_name": "portrait",
+            "confidence": 0.75,
+            "high_confidence": False,
+        },
+        {
+            "class_name": "landscape",
+            "confidence": 0.25,
+            "high_confidence": False,
+        },
+    ]
+
+    with pytest.raises(ErrorResponse):
+        mock_elements_worker.create_classifications(
+            element=elt, classifications=classes
+        )
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 5
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        # The failing API call is sent 5 times in total
+        ("POST", "http://testserver/api/v1/classification/bulk/"),
+        ("POST", "http://testserver/api/v1/classification/bulk/"),
+        ("POST", "http://testserver/api/v1/classification/bulk/"),
+        ("POST", "http://testserver/api/v1/classification/bulk/"),
+        ("POST", "http://testserver/api/v1/classification/bulk/"),
+    ]
+
+
+def test_create_classifications(responses, mock_elements_worker_with_cache):
+    elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
+    classes = [
+        {
+            "class_name": "portrait",
+            "confidence": 0.75,
+            "high_confidence": False,
+        },
+        {
+            "class_name": "landscape",
+            "confidence": 0.25,
+            "high_confidence": False,
+        },
+    ]
+
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/classification/bulk/",
+        status=200,
+        json={  # mocked API response
+            "parent": str(elt.id),
+            "worker_version": "12341234-1234-1234-1234-123412341234",
+            "classifications": [
+                {
+                    "id": "00000000-0000-0000-0000-000000000000",
+                    "class_name": "portrait",
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                    "state": "pending",
+                },
+                {
+                    "id": "11111111-1111-1111-1111-111111111111",
+                    "class_name": "landscape",
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                    "state": "pending",
+                },
+            ],
+        },
+    )
+
+    mock_elements_worker_with_cache.create_classifications(
+        element=elt, classifications=classes
+    )
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1  # one bulk call only
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        ("POST", "http://testserver/api/v1/classification/bulk/"),
+    ]
+
+    assert json.loads(responses.calls[-1].request.body) == {  # sent payload
+        "parent": str(elt.id),
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "classifications": classes,
+    }
+
+    # Check that created classifications were properly stored in SQLite cache
+    assert list(CachedClassification.select()) == [
+        CachedClassification(
+            id=UUID("00000000-0000-0000-0000-000000000000"),
+            element_id=UUID(elt.id),
+            class_name="portrait",
+            confidence=0.75,
+            state="pending",
+            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+        ),
+        CachedClassification(
+            id=UUID("11111111-1111-1111-1111-111111111111"),
+            element_id=UUID(elt.id),
+            class_name="landscape",
+            confidence=0.25,
+            state="pending",
+            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+        ),
+    ]
-- 
GitLab