From 5664a432545422b5d8d95e11283dc974dc705f1b Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 20 Sep 2023 17:28:11 +0200
Subject: [PATCH] use partial update element

---
 arkindex_worker/worker/element.py           | 45 +++++++++++----------
 tests/test_elements_worker/test_elements.py | 44 +++++++++-----------
 2 files changed, 42 insertions(+), 47 deletions(-)

diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py
index a553c75a..fb42902f 100644
--- a/arkindex_worker/worker/element.py
+++ b/arkindex_worker/worker/element.py
@@ -276,17 +276,17 @@ class ElementMixin(object):
 
         return created_ids
 
-    def update_element(
-        self, element: Union[Element, CachedElement], type: str, name: str, **kwargs
+    def partial_update_element(
+        self, element: Union[Element, CachedElement], **kwargs
     ) -> dict:
         """
         Updates an element through the API.
 
         :param element: The element to update.
-        :param type: Slug type of the element.
-        :param name: Name of the element.
         :param **kwargs:
 
+            * *type* (``str``): Optional slug type of the element.
+            * *name* (``str``): Optional name of the element.
             * *polygon* (``list``): Optional polygon for this element
             * *confidence* (``float``): Optional confidence score of this element
             * *rotation_angle* (``int``): Optional rotation angle of this element
@@ -299,9 +299,14 @@ class ElementMixin(object):
         assert element and isinstance(
             element, (Element, CachedElement)
         ), "element shouldn't be null and should be an Element or CachedElement"
-        assert isinstance(type, str), "type should be a str"
-        assert isinstance(name, str), "name should be a str"
-        if polygon := kwargs.get("polygon"):
+
+        if (element_type := kwargs.get("type")) is not None:
+            assert isinstance(element_type, str), "type should be a str"
+
+        if (name := kwargs.get("name")) is not None:
+            assert isinstance(name, str), "name should be a str"
+
+        if (polygon := kwargs.get("polygon")) is not None:
             assert isinstance(polygon, list), "polygon should be a list"
             assert len(polygon) >= 3, "polygon should have at least three points"
             assert all(
@@ -311,20 +316,20 @@ class ElementMixin(object):
                 isinstance(coord, (int, float)) for point in polygon for coord in point
             ), "polygon points should be lists of two numbers"
 
-        if "confidence" in kwargs and (confidence := kwargs.get("confidence")):
+        if (confidence := kwargs.get("confidence")) is not None:
             assert (
                 isinstance(confidence, float) and 0 <= confidence <= 1
             ), "confidence should be None or a float in [0..1] range"
 
-        if rotation_angle := kwargs.get("rotation_angle"):
+        if (rotation_angle := kwargs.get("rotation_angle")) is not None:
             assert (
                 isinstance(rotation_angle, int) and rotation_angle >= 0
             ), "rotation_angle should be a positive integer"
 
-        if mirrored := kwargs.get("mirrored"):
+        if (mirrored := kwargs.get("mirrored")) is not None:
             assert isinstance(mirrored, bool), "mirrored should be a boolean"
 
-        if image := kwargs.get("image"):
+        if (image := kwargs.get("image")) is not None:
             assert isinstance(image, UUID), "image should be a UUID"
             # Cast to string
             kwargs["image"] = str(kwargs["image"])
@@ -333,27 +338,25 @@ class ElementMixin(object):
             logger.warning("Cannot update element as this worker is in read-only mode")
             return
 
-        payload = {"type": type, "name": name, **kwargs}
-
         updated_element = self.request(
-            "UpdateElement",
+            "PartialUpdateElement",
             id=element.id,
-            body=payload,
+            body=kwargs,
         )
 
         if self.use_cache:
             # Name is not present in CachedElement model
-            payload.pop("name")
+            kwargs.pop("name", None)
 
             # Stringify polygon if present
-            if "polygon" in payload:
-                payload["polygon"] = str(payload["polygon"])
+            if "polygon" in kwargs:
+                kwargs["polygon"] = str(kwargs["polygon"])
 
             # Retrieve the right image
-            if "image" in payload:
-                payload["image"] = CachedImage.get_by_id(payload["image"])
+            if "image" in kwargs:
+                kwargs["image"] = CachedImage.get_by_id(kwargs["image"])
 
-            CachedElement.update(**payload).where(
+            CachedElement.update(**kwargs).where(
                 CachedElement.id == element.id
             ).execute()
 
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index e86bdcf5..1547bed2 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -1262,31 +1262,29 @@ def test_create_elements_integrity_error(
         ({"image": 1234}, "image should be a UUID"),
     ),
 )
-def test_update_element_wrong_param(mock_elements_worker, payload, error):
+def test_partial_update_element_wrong_param(mock_elements_worker, payload, error):
     api_payload = {
         "element": Element({"zone": None}),
-        "type": "type",
-        "name": "name",
         **payload,
     }
 
     with pytest.raises(AssertionError) as e:
-        mock_elements_worker.update_element(
+        mock_elements_worker.partial_update_element(
             **api_payload,
         )
     assert str(e.value) == error
 
 
-def test_update_element_api_error(responses, mock_elements_worker):
+def test_partial_update_element_api_error(responses, mock_elements_worker):
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
-        responses.PUT,
+        responses.PATCH,
         f"http://testserver/api/v1/element/{elt.id}/",
         status=500,
     )
 
     with pytest.raises(ErrorResponse):
-        mock_elements_worker.update_element(
+        mock_elements_worker.partial_update_element(
             element=elt,
             type="something",
             name="0",
@@ -1298,11 +1296,11 @@ def test_update_element_api_error(responses, mock_elements_worker):
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         # We retry 5 times the API call
-        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
-        ("PUT", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
+        ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"),
     ]
 
 
@@ -1331,7 +1329,7 @@ def test_update_element_api_error(responses, mock_elements_worker):
         ),
     ),
 )
-def test_update_element(
+def test_partial_update_element(
     responses,
     mock_elements_worker_with_cache,
     mock_cached_elements,
@@ -1342,20 +1340,18 @@ def test_update_element(
     new_image = CachedImage.select().first()
 
     elt_response = {
-        "type": "new type",
-        "name": "new name",
         "image": str(new_image.id),
         **payload,
     }
     responses.add(
-        responses.PUT,
+        responses.PATCH,
         f"http://testserver/api/v1/element/{elt.id}/",
         status=200,
         # UUID not allowed in JSON
         json=elt_response,
     )
 
-    element_update_response = mock_elements_worker_with_cache.update_element(
+    element_update_response = mock_elements_worker_with_cache.partial_update_element(
         element=elt,
         **{**elt_response, "image": new_image.id},
     )
@@ -1365,7 +1361,7 @@ def test_update_element(
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         (
-            "PUT",
+            "PATCH",
             f"http://testserver/api/v1/element/{elt.id}/",
         ),
     ]
@@ -1374,7 +1370,6 @@ def test_update_element(
 
     cached_element = CachedElement.get(CachedElement.id == elt.id)
     # Always present in payload
-    assert cached_element.type == elt_response["type"]
     assert str(cached_element.image_id) == elt_response["image"]
     # Optional params
     if "polygon" in payload:
@@ -1385,24 +1380,22 @@ def test_update_element(
         assert getattr(cached_element, param) == elt_response[param]
 
 
-def test_update_element_confidence(
+def test_partial_update_element_confidence(
     responses, mock_elements_worker_with_cache, mock_cached_elements
 ):
     elt = CachedElement.select().first()
     elt_response = {
-        "type": "new type",
-        "name": "new name",
         "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]],
         "confidence": 0.42,
     }
     responses.add(
-        responses.PUT,
+        responses.PATCH,
         f"http://testserver/api/v1/element/{elt.id}/",
         status=200,
         json=elt_response,
     )
 
-    element_update_response = mock_elements_worker_with_cache.update_element(
+    element_update_response = mock_elements_worker_with_cache.partial_update_element(
         element=elt,
         **elt_response,
     )
@@ -1412,7 +1405,7 @@ def test_update_element_confidence(
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         (
-            "PUT",
+            "PATCH",
             f"http://testserver/api/v1/element/{elt.id}/",
         ),
     ]
@@ -1420,7 +1413,6 @@ def test_update_element_confidence(
     assert element_update_response == elt_response
 
     cached_element = CachedElement.get(CachedElement.id == elt.id)
-    assert cached_element.type == elt_response["type"]
     assert cached_element.polygon == str(elt_response["polygon"])
     assert cached_element.confidence == elt_response["confidence"]
 
-- 
GitLab