From 5664a432545422b5d8d95e11283dc974dc705f1b Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 20 Sep 2023 17:28:11 +0200 Subject: [PATCH] use partial update element --- arkindex_worker/worker/element.py | 45 +++++++++++---------- tests/test_elements_worker/test_elements.py | 44 +++++++++----------- 2 files changed, 42 insertions(+), 47 deletions(-) diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index a553c75a..fb42902f 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -276,17 +276,17 @@ class ElementMixin(object): return created_ids - def update_element( - self, element: Union[Element, CachedElement], type: str, name: str, **kwargs + def partial_update_element( + self, element: Union[Element, CachedElement], **kwargs ) -> dict: """ Updates an element through the API. :param element: The element to update. - :param type: Slug type of the element. - :param name: Name of the element. :param **kwargs: + * *type* (``str``): Optional slug type of the element. + * *name* (``str``): Optional name of the element. * *polygon* (``list``): Optional polygon for this element * *confidence* (``float``): Optional confidence score of this element * *rotation_angle* (``int``): Optional rotation angle of this element @@ -299,9 +299,14 @@ class ElementMixin(object): assert element and isinstance( element, (Element, CachedElement) ), "element shouldn't be null and should be an Element or CachedElement" - assert isinstance(type, str), "type should be a str" - assert isinstance(name, str), "name should be a str" - if polygon := kwargs.get("polygon"): + + if (element_type := kwargs.get("type")) is not None: + assert isinstance(element_type, str), "type should be a str" + + if (name := kwargs.get("name")) is not None: + assert isinstance(name, str), "name should be a str" + + if (polygon := kwargs.get("polygon")) is not None: assert isinstance(polygon, list), "polygon should be a list" assert len(polygon) >= 3, "polygon should have at least three points" assert all( @@ -311,20 +316,20 @@ class ElementMixin(object): isinstance(coord, (int, float)) for point in polygon for coord in point ), "polygon points should be lists of two numbers" - if "confidence" in kwargs and (confidence := kwargs.get("confidence")): + if (confidence := kwargs.get("confidence")) is not None: assert ( isinstance(confidence, float) and 0 <= confidence <= 1 ), "confidence should be None or a float in [0..1] range" - if rotation_angle := kwargs.get("rotation_angle"): + if (rotation_angle := kwargs.get("rotation_angle")) is not None: assert ( isinstance(rotation_angle, int) and rotation_angle >= 0 ), "rotation_angle should be a positive integer" - if mirrored := kwargs.get("mirrored"): + if (mirrored := kwargs.get("mirrored")) is not None: assert isinstance(mirrored, bool), "mirrored should be a boolean" - if image := kwargs.get("image"): + if (image := kwargs.get("image")) is not None: assert isinstance(image, UUID), "image should be a UUID" # Cast to string kwargs["image"] = str(kwargs["image"]) @@ -333,27 +338,25 @@ class ElementMixin(object): logger.warning("Cannot update element as this worker is in read-only mode") return - payload = {"type": type, "name": name, **kwargs} - updated_element = self.request( - "UpdateElement", + "PartialUpdateElement", id=element.id, - body=payload, + body=kwargs, ) if self.use_cache: # Name is not present in CachedElement model - payload.pop("name") + kwargs.pop("name", None) # Stringify polygon if present - if "polygon" in payload: - payload["polygon"] = str(payload["polygon"]) + if "polygon" in kwargs: + kwargs["polygon"] = str(kwargs["polygon"]) # Retrieve the right image - if "image" in payload: - payload["image"] = CachedImage.get_by_id(payload["image"]) + if "image" in kwargs: + kwargs["image"] = CachedImage.get_by_id(kwargs["image"]) - CachedElement.update(**payload).where( + CachedElement.update(**kwargs).where( CachedElement.id == element.id ).execute() diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index e86bdcf5..1547bed2 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -1262,31 +1262,29 @@ def test_create_elements_integrity_error( ({"image": 1234}, "image should be a UUID"), ), ) -def test_update_element_wrong_param(mock_elements_worker, payload, error): +def test_partial_update_element_wrong_param(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), - "type": "type", - "name": "name", **payload, } with pytest.raises(AssertionError) as e: - mock_elements_worker.update_element( + mock_elements_worker.partial_update_element( **api_payload, ) assert str(e.value) == error -def test_update_element_api_error(responses, mock_elements_worker): +def test_partial_update_element_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( - responses.PUT, + responses.PATCH, f"http://testserver/api/v1/element/{elt.id}/", status=500, ) with pytest.raises(ErrorResponse): - mock_elements_worker.update_element( + mock_elements_worker.partial_update_element( element=elt, type="something", name="0", @@ -1298,11 +1296,11 @@ def test_update_element_api_error(responses, mock_elements_worker): (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We retry 5 times the API call - ("PUT", f"http://testserver/api/v1/element/{elt.id}/"), - ("PUT", f"http://testserver/api/v1/element/{elt.id}/"), - ("PUT", f"http://testserver/api/v1/element/{elt.id}/"), - ("PUT", f"http://testserver/api/v1/element/{elt.id}/"), - ("PUT", f"http://testserver/api/v1/element/{elt.id}/"), + ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), + ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), + ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), + ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), + ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), ] @@ -1331,7 +1329,7 @@ def test_update_element_api_error(responses, mock_elements_worker): ), ), ) -def test_update_element( +def test_partial_update_element( responses, mock_elements_worker_with_cache, mock_cached_elements, @@ -1342,20 +1340,18 @@ def test_update_element( new_image = CachedImage.select().first() elt_response = { - "type": "new type", - "name": "new name", "image": str(new_image.id), **payload, } responses.add( - responses.PUT, + responses.PATCH, f"http://testserver/api/v1/element/{elt.id}/", status=200, # UUID not allowed in JSON json=elt_response, ) - element_update_response = mock_elements_worker_with_cache.update_element( + element_update_response = mock_elements_worker_with_cache.partial_update_element( element=elt, **{**elt_response, "image": new_image.id}, ) @@ -1365,7 +1361,7 @@ def test_update_element( (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( - "PUT", + "PATCH", f"http://testserver/api/v1/element/{elt.id}/", ), ] @@ -1374,7 +1370,6 @@ def test_update_element( cached_element = CachedElement.get(CachedElement.id == elt.id) # Always present in payload - assert cached_element.type == elt_response["type"] assert str(cached_element.image_id) == elt_response["image"] # Optional params if "polygon" in payload: @@ -1385,24 +1380,22 @@ def test_update_element( assert getattr(cached_element, param) == elt_response[param] -def test_update_element_confidence( +def test_partial_update_element_confidence( responses, mock_elements_worker_with_cache, mock_cached_elements ): elt = CachedElement.select().first() elt_response = { - "type": "new type", - "name": "new name", "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]], "confidence": 0.42, } responses.add( - responses.PUT, + responses.PATCH, f"http://testserver/api/v1/element/{elt.id}/", status=200, json=elt_response, ) - element_update_response = mock_elements_worker_with_cache.update_element( + element_update_response = mock_elements_worker_with_cache.partial_update_element( element=elt, **elt_response, ) @@ -1412,7 +1405,7 @@ def test_update_element_confidence( (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( - "PUT", + "PATCH", f"http://testserver/api/v1/element/{elt.id}/", ), ] @@ -1420,7 +1413,6 @@ def test_update_element_confidence( assert element_update_response == elt_response cached_element = CachedElement.get(CachedElement.id == elt.id) - assert cached_element.type == elt_response["type"] assert cached_element.polygon == str(elt_response["polygon"]) assert cached_element.confidence == elt_response["confidence"] -- GitLab