From a2470c0f8f584d628841a6094f19ab0bb3dbdb64 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 15 Sep 2023 17:49:54 +0200
Subject: [PATCH] Call UpdateElement, update tests, add missing fields

---
 arkindex_worker/worker/element.py           |  74 +++++----
 tests/conftest.py                           |  13 ++
 tests/test_elements_worker/test_elements.py | 175 +++++++++-----------
 3 files changed, 138 insertions(+), 124 deletions(-)

diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py
index a64002d5..30c443ac 100644
--- a/arkindex_worker/worker/element.py
+++ b/arkindex_worker/worker/element.py
@@ -3,6 +3,7 @@
 ElementsWorker methods for elements and element types.
 """
 from typing import Dict, Iterable, List, NamedTuple, Optional, Union
+from uuid import UUID
 
 from peewee import IntegrityError
 
@@ -276,32 +277,25 @@ class ElementMixin(object):
         return created_ids
 
     def update_element(
-        self,
-        element: Union[Element, CachedElement],
-        type: Optional[str] = None,
-        name: Optional[str] = None,
-        polygon: Optional[List[List[Union[int, float]]]] = None,
-        confidence: Optional[float] = None,
+        self, element: Union[Element, CachedElement], type: str, name: str, **kwargs
     ) -> dict:
         """
-        Partially update an element through the API.
+        Updates an element through the API.
 
         :param element: The element to update.
         :param type: Optional new slug type of the element.
         :param name: Optional new name of the element.
         :param polygon: Optional new polygon of the element.
         :param confidence: Optional new confidence score, between 0.0 and 1.0.
-        :returns: A dict from the ``PartialUpdateElement`` API endpoint,
+        :returns: A dict from the ``UpdateElement`` API endpoint,
         """
         assert element and isinstance(
             element, (Element, CachedElement)
         ), "element shouldn't be null and should be an Element or CachedElement"
-        assert type is None or isinstance(type, str), "type should be None or a str"
-        assert name is None or isinstance(name, str), "name should be None or a str"
-        assert polygon is None or isinstance(
-            polygon, list
-        ), "polygon should be None or a list"
-        if polygon:
+        assert isinstance(type, str), "type should be a str"
+        assert isinstance(name, str), "name should be a str"
+        if polygon := kwargs.get("polygon"):
+            assert isinstance(polygon, list), "polygon should be a list"
             assert len(polygon) >= 3, "polygon should have at least three points"
             assert all(
                 isinstance(point, list) and len(point) == 2 for point in polygon
@@ -309,33 +303,51 @@ class ElementMixin(object):
             assert all(
                 isinstance(coord, (int, float)) for point in polygon for coord in point
             ), "polygon points should be lists of two numbers"
-        assert confidence is None or (
-            isinstance(confidence, float) and 0 <= confidence <= 1
-        ), "confidence should be None or a float in [0..1] range"
+
+        if "confidence" in kwargs and (confidence := kwargs.get("confidence")):
+            assert (
+                isinstance(confidence, float) and 0 <= confidence <= 1
+            ), "confidence should be None or a float in [0..1] range"
+
+        if rotation_angle := kwargs.get("rotation_angle"):
+            assert (
+                isinstance(rotation_angle, int) and rotation_angle >= 0
+            ), "rotation_angle should be a positive integer"
+
+        if mirrored := kwargs.get("mirrored"):
+            assert isinstance(mirrored, bool), "mirrored should be a boolean"
+
+        if image := kwargs.get("image"):
+            assert isinstance(image, UUID), "image should be a UUID"
+            # Cast to string
+            kwargs["image"] = str(kwargs["image"])
 
         if self.is_read_only:
             logger.warning("Cannot update element as this worker is in read-only mode")
             return
 
+        payload = {"type": type, "name": name, **kwargs}
+
         updated_element = self.request(
-            "PartialUpdateElement",
+            "UpdateElement",
             id=element.id,
-            body={
-                "type": type,
-                "name": name,
-                "polygon": polygon,
-                "confidence": confidence,
-            },
+            body=payload,
         )
 
         if self.use_cache:
-            CachedElement.update(
-                {
-                    CachedElement.type: type,
-                    CachedElement.polygon: str(polygon),
-                    CachedElement.confidence: confidence,
-                }
-            ).where(CachedElement.id == element.id).execute()
+            # Name is not present in CachedElement model
+            payload.pop("name")
+
+            # Stringify polygon if present
+            if "polygon" in payload:
+                payload["polygon"] = str(payload["polygon"])
+
+            # Retrieve the right image
+            payload["image"] = CachedImage.get_by_id(payload["image"])
+
+            CachedElement.update(**payload).where(
+                CachedElement.id == element.id
+            ).execute()
 
         return updated_element
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 7b613c76..d41e242d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,6 +16,7 @@ from arkindex_worker.cache import (
     MODELS,
     SQL_VERSION,
     CachedElement,
+    CachedImage,
     CachedTranscription,
     Version,
     create_version_table,
@@ -363,6 +364,18 @@ def mock_cached_elements():
     assert CachedElement.select().count() == 5
 
 
+@pytest.fixture
+def mock_cached_images():
+    """Insert few elements in local cache"""
+    CachedImage.create(
+        id=UUID("99999999-9999-9999-9999-999999999999"),
+        width=1250,
+        height=2500,
+        url="http://testserver/iiif/3/image",
+    )
+    assert CachedImage.select().count() == 1
+
+
 @pytest.fixture
 def mock_cached_transcriptions():
     """Insert few transcriptions in local cache, on a shared element"""
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index 273e41ba..7f80bd45 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -1210,97 +1210,77 @@ def test_create_elements_integrity_error(
     assert list(CachedElement.select()) == []
 
 
-def test_update_element_wrong_element(mock_elements_worker):
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=None,
-        )
-    assert (
-        str(e.value)
-        == "element shouldn't be null and should be an Element or CachedElement"
-    )
-
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element="not element type",
-        )
-    assert (
-        str(e.value)
-        == "element shouldn't be null and should be an Element or CachedElement"
-    )
-
-
-def test_update_element_wrong_type(mock_elements_worker):
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=Element({"zone": None}),
-            type=1234,
-        )
-    assert str(e.value) == "type should be None or a str"
-
-
-def test_update_element_wrong_name(mock_elements_worker):
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=Element({"zone": None}),
-            name=1234,
-        )
-    assert str(e.value) == "name should be None or a str"
-
-
-def test_update_element_wrong_polygon(mock_elements_worker):
-    elt = Element({"zone": None})
-
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=elt,
-            polygon="not a polygon",
-        )
-    assert str(e.value) == "polygon should be None or a list"
-
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=elt,
-            polygon=[[1, 1], [2, 2]],
-        )
-    assert str(e.value) == "polygon should have at least three points"
-
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=elt,
-            polygon=[[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]],
-        )
-    assert str(e.value) == "polygon points should be lists of two items"
-
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=elt,
-            polygon=[[1], [2], [2], [1]],
-        )
-    assert str(e.value) == "polygon points should be lists of two items"
-
-    with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
-            element=elt,
-            polygon=[["not a coord", 1], [2, 2], [2, 1], [1, 2]],
-        )
-    assert str(e.value) == "polygon points should be lists of two numbers"
-
+@pytest.mark.parametrize(
+    "payload, error",
+    (
+        # Element
+        (
+            {"element": None},
+            "element shouldn't be null and should be an Element or CachedElement",
+        ),
+        (
+            {"element": "not element type"},
+            "element shouldn't be null and should be an Element or CachedElement",
+        ),
+        # Type
+        ({"type": 1234}, "type should be a str"),
+        # Name
+        ({"name": 1234}, "name should be a str"),
+        # Polygon
+        ({"polygon": "not a polygon"}, "polygon should be a list"),
+        ({"polygon": [[1, 1], [2, 2]]}, "polygon should have at least three points"),
+        (
+            {"polygon": [[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]]},
+            "polygon points should be lists of two items",
+        ),
+        (
+            {"polygon": [[1], [2], [2], [1]]},
+            "polygon points should be lists of two items",
+        ),
+        (
+            {"polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]]},
+            "polygon points should be lists of two numbers",
+        ),
+        # Confidence
+        ({"confidence": "lol"}, "confidence should be None or a float in [0..1] range"),
+        ({"confidence": "0.2"}, "confidence should be None or a float in [0..1] range"),
+        ({"confidence": -1.0}, "confidence should be None or a float in [0..1] range"),
+        ({"confidence": 1.42}, "confidence should be None or a float in [0..1] range"),
+        (
+            {"confidence": float("inf")},
+            "confidence should be None or a float in [0..1] range",
+        ),
+        # Rotation angle
+        ({"rotation_angle": "lol"}, "rotation_angle should be a positive integer"),
+        ({"rotation_angle": -1}, "rotation_angle should be a positive integer"),
+        ({"rotation_angle": 0.5}, "rotation_angle should be a positive integer"),
+        # Mirrored
+        ({"mirrored": "lol"}, "mirrored should be a boolean"),
+        ({"mirrored": 1234}, "mirrored should be a boolean"),
+        # Image
+        ({"image": "lol"}, "image should be a UUID"),
+        ({"image": 1234}, "image should be a UUID"),
+    ),
+)
+def test_update_element_wrong_param(mock_elements_worker, payload, error):
+    api_payload = {
+        "element": Element({"zone": None}),
+        "type": "type",
+        "name": "name",
+        **payload,
+    }
 
-@pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")])
-def test_update_element_wrong_confidence(mock_elements_worker, confidence):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.update_element(
-            element=Element({"zone": None}),
-            confidence=confidence,
+            **api_payload,
         )
-    assert str(e.value) == "confidence should be None or a float in [0..1] range"
+    assert str(e.value) == error
 
 
 def test_update_element_api_error(responses, mock_elements_worker):
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
-        responses.PATCH,
+        responses.PUT,
         f"http://testserver/api/v1/element/{elt.id}/",
         status=500,
     )
@@ -1318,34 +1298,40 @@ def test_update_element_api_error(responses, mock_elements_worker):
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         # We retry 5 times the API call
-        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
     ]
 
 
 def test_update_element(
-    responses, mock_elements_worker_with_cache, mock_cached_elements
+    responses, mock_elements_worker_with_cache, mock_cached_elements, mock_cached_images
 ):
     elt = CachedElement.select().first()
+    new_image = CachedImage.select().first()
+
     elt_response = {
         "type": "new type",
         "name": "new name",
         "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]],
         "confidence": None,
+        "rotation_angle": 45,
+        "mirrored": False,
+        "image": str(new_image.id),
     }
     responses.add(
-        responses.PATCH,
+        responses.PUT,
         f"http://testserver/api/v1/element/{elt.id}/",
         status=200,
+        # UUID not allowed in JSON
         json=elt_response,
     )
 
     element_update_response = mock_elements_worker_with_cache.update_element(
         element=elt,
-        **elt_response,
+        **{**elt_response, "image": new_image.id},
     )
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 1
@@ -1353,7 +1339,7 @@ def test_update_element(
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         (
-            "PATCH",
+            "PUT",
             f"http://testserver/api/v1/element/{elt.id}/",
         ),
     ]
@@ -1364,6 +1350,9 @@ def test_update_element(
     assert cached_element.type == elt_response["type"]
     assert cached_element.polygon == str(elt_response["polygon"])
     assert cached_element.confidence == elt_response["confidence"]
+    assert cached_element.rotation_angle == elt_response["rotation_angle"]
+    assert cached_element.mirrored == elt_response["mirrored"]
+    assert str(cached_element.image_id) == elt_response["image"]
 
 
 def test_update_element_confidence(
@@ -1377,7 +1366,7 @@ def test_update_element_confidence(
         "confidence": 0.42,
     }
     responses.add(
-        responses.PATCH,
+        responses.PUT,
         f"http://testserver/api/v1/element/{elt.id}/",
         status=200,
         json=elt_response,
@@ -1393,7 +1382,7 @@ def test_update_element_confidence(
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         (
-            "PATCH",
+            "PUT",
             f"http://testserver/api/v1/element/{elt.id}/",
         ),
     ]
-- 
GitLab