From ec3904de1b9de91cbdd41c1132ba52b0b90ea10e Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Thu, 21 Dec 2023 12:26:21 +0000
Subject: [PATCH] Add an API helper to link two elements

---
 arkindex_worker/worker/element.py           | 31 +++++++-
 tests/test_elements_worker/test_elements.py | 88 +++++++++++++++++++++
 2 files changed, 118 insertions(+), 1 deletion(-)

diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py
index 7ab59eec..a1f09849 100644
--- a/arkindex_worker/worker/element.py
+++ b/arkindex_worker/worker/element.py
@@ -283,6 +283,35 @@ class ElementMixin:
 
         return created_ids
 
+    def create_element_parent(
+        self,
+        parent: Element,
+        child: Element,
+    ) -> dict[str, str]:
+        """
+        Link an element to a parent through the API.
+
+        :param parent: Parent element.
+        :param child: Child element.
+        :returns: A dict from the ``CreateElementParent`` API endpoint.
+        """
+        assert parent and isinstance(
+            parent, Element
+        ), "parent shouldn't be null and should be of type Element"
+        assert child and isinstance(
+            child, Element
+        ), "child shouldn't be null and should be of type Element"
+
+        if self.is_read_only:
+            logger.warning("Cannot link elements as this worker is in read-only mode")
+            return
+
+        return self.request(
+            "CreateElementParent",
+            parent=parent.id,
+            child=child.id,
+        )
+
     def partial_update_element(
         self, element: Element | CachedElement, **kwargs
     ) -> dict:
@@ -301,7 +330,7 @@ class ElementMixin:
             * *image* (``UUID``): Optional ID of the image of this element
 
 
-        :returns: A dict from the ``PartialUpdateElement`` API endpoint,
+        :returns: A dict from the ``PartialUpdateElement`` API endpoint.
         """
         assert element and isinstance(
             element, Element | CachedElement
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index 47928f1e..2445f230 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -1241,6 +1241,94 @@ def test_create_elements_integrity_error(
     assert list(CachedElement.select()) == []
 
 
+@pytest.mark.parametrize(
+    ("params", "error_message"),
+    [
+        (
+            {"parent": None, "child": None},
+            "parent shouldn't be null and should be of type Element",
+        ),
+        (
+            {"parent": "not an element", "child": None},
+            "parent shouldn't be null and should be of type Element",
+        ),
+        (
+            {"parent": Element(zone=None), "child": None},
+            "child shouldn't be null and should be of type Element",
+        ),
+        (
+            {"parent": Element(zone=None), "child": "not an element"},
+            "child shouldn't be null and should be of type Element",
+        ),
+    ],
+)
+def test_create_element_parent_invalid_params(
+    mock_elements_worker, params, error_message
+):
+    with pytest.raises(AssertionError, match=re.escape(error_message)):
+        mock_elements_worker.create_element_parent(**params)
+
+
+def test_create_element_parent_api_error(responses, mock_elements_worker):
+    parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    child = Element({"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"})
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
+        status=500,
+    )
+
+    with pytest.raises(ErrorResponse):
+        mock_elements_worker.create_element_parent(
+            parent=parent,
+            child=child,
+        )
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 5
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        # We retry 5 times the API call
+        (
+            "POST",
+            "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
+        ),
+    ] * 5
+
+
+def test_create_element_parent(responses, mock_elements_worker):
+    parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    child = Element({"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"})
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
+        status=200,
+        json={
+            "parent": "12341234-1234-1234-1234-123412341234",
+            "child": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
+        },
+    )
+
+    created_element_parent = mock_elements_worker.create_element_parent(
+        parent=parent,
+        child=child,
+    )
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        (
+            "POST",
+            "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
+        ),
+    ]
+    assert created_element_parent == {
+        "parent": "12341234-1234-1234-1234-123412341234",
+        "child": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
+    }
+
+
 @pytest.mark.parametrize(
     ("payload", "error"),
     [
-- 
GitLab