From 144866aee63966d1c7733a0caae90d96ee8b3875 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 23 Aug 2022 15:43:28 +0000
Subject: [PATCH] Add slim_output argument when creating sub element

---
 arkindex_worker/worker/element.py           |  5 ++-
 tests/test_elements_worker/test_elements.py | 41 ++++++++++++++-------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py
index 82f53c28..a22dbd0e 100644
--- a/arkindex_worker/worker/element.py
+++ b/arkindex_worker/worker/element.py
@@ -60,6 +60,7 @@ class ElementMixin(object):
         name: str,
         polygon: List[List[Union[int, float]]],
         confidence: Optional[float] = None,
+        slim_output: bool = True,
     ) -> str:
         """
         Create a child element on the given element through the API.
@@ -95,6 +96,7 @@ class ElementMixin(object):
         assert confidence is None or (
             isinstance(confidence, float) and 0 <= confidence <= 1
         ), "confidence should be None or a float in [0..1] range"
+        assert isinstance(slim_output, bool), "slim_output should be of type bool"
 
         if self.is_read_only:
             logger.warning("Cannot create element as this worker is in read-only mode")
@@ -102,6 +104,7 @@ class ElementMixin(object):
 
         sub_element = self.request(
             "CreateElement",
+            slim_output=slim_output,
             body={
                 "type": type,
                 "name": name,
@@ -115,7 +118,7 @@ class ElementMixin(object):
         )
         self.report.add_element(element.id, type)
 
-        return sub_element["id"]
+        return sub_element["id"] if slim_output else sub_element
 
     def create_elements(
         self,
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index be2d56dd..7ff16855 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -447,7 +447,7 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
     )
     responses.add(
         responses.POST,
-        "http://testserver/api/v1/elements/create/",
+        "http://testserver/api/v1/elements/create/?slim_output=True",
         status=500,
     )
 
@@ -464,15 +464,16 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         # We retry 5 times the API call
-        ("POST", "http://testserver/api/v1/elements/create/"),
-        ("POST", "http://testserver/api/v1/elements/create/"),
-        ("POST", "http://testserver/api/v1/elements/create/"),
-        ("POST", "http://testserver/api/v1/elements/create/"),
-        ("POST", "http://testserver/api/v1/elements/create/"),
+        ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
+        ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
+        ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
+        ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
+        ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
     ]
 
 
-def test_create_sub_element(responses, mock_elements_worker):
+@pytest.mark.parametrize("slim_output", [True, False])
+def test_create_sub_element(responses, mock_elements_worker, slim_output):
     elt = Element(
         {
             "id": "12341234-1234-1234-1234-123412341234",
@@ -480,25 +481,34 @@ def test_create_sub_element(responses, mock_elements_worker):
             "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}},
         }
     )
+    child_elt = {
+        "id": "12345678-1234-1234-1234-123456789123",
+        "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
+        "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}},
+    }
     responses.add(
         responses.POST,
-        "http://testserver/api/v1/elements/create/",
+        f"http://testserver/api/v1/elements/create/?slim_output={slim_output}",
         status=200,
-        json={"id": "12345678-1234-1234-1234-123456789123"},
+        json=child_elt,
     )
 
-    sub_element_id = mock_elements_worker.create_sub_element(
+    element_creation_response = mock_elements_worker.create_sub_element(
         element=elt,
         type="something",
         name="0",
         polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
+        slim_output=slim_output,
     )
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
-        ("POST", "http://testserver/api/v1/elements/create/"),
+        (
+            "POST",
+            f"http://testserver/api/v1/elements/create/?slim_output={slim_output}",
+        ),
     ]
     assert json.loads(responses.calls[-1].request.body) == {
         "type": "something",
@@ -510,7 +520,10 @@ def test_create_sub_element(responses, mock_elements_worker):
         "worker_version": "12341234-1234-1234-1234-123412341234",
         "confidence": None,
     }
-    assert sub_element_id == "12345678-1234-1234-1234-123456789123"
+    if slim_output:
+        assert element_creation_response == "12345678-1234-1234-1234-123456789123"
+    else:
+        assert Element(element_creation_response) == Element(child_elt)
 
 
 def test_create_sub_element_confidence(responses, mock_elements_worker):
@@ -523,7 +536,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker):
     )
     responses.add(
         responses.POST,
-        "http://testserver/api/v1/elements/create/",
+        "http://testserver/api/v1/elements/create/?slim_output=True",
         status=200,
         json={"id": "12345678-1234-1234-1234-123456789123"},
     )
@@ -540,7 +553,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker):
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
-        ("POST", "http://testserver/api/v1/elements/create/"),
+        ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"),
     ]
     assert json.loads(responses.calls[-1].request.body) == {
         "type": "something",
-- 
GitLab