diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index 82f53c281bbfa0525ef75356e8bbe278831ddd84..a22dbd0e54646c1c20223f27d2445fc01cee4940 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -60,6 +60,7 @@ class ElementMixin(object): name: str, polygon: List[List[Union[int, float]]], confidence: Optional[float] = None, + slim_output: bool = True, ) -> str: """ Create a child element on the given element through the API. @@ -95,6 +96,7 @@ class ElementMixin(object): assert confidence is None or ( isinstance(confidence, float) and 0 <= confidence <= 1 ), "confidence should be None or a float in [0..1] range" + assert isinstance(slim_output, bool), "slim_output should be of type bool" if self.is_read_only: logger.warning("Cannot create element as this worker is in read-only mode") @@ -102,6 +104,7 @@ class ElementMixin(object): sub_element = self.request( "CreateElement", + slim_output=slim_output, body={ "type": type, "name": name, @@ -115,7 +118,7 @@ class ElementMixin(object): ) self.report.add_element(element.id, type) - return sub_element["id"] + return sub_element["id"] if slim_output else sub_element def create_elements( self, diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index be2d56dd12e214c20e5b4cba8d16d8e9ad6b12b4..7ff1685547d830dc3784e69bb54c52e6c64f9e76 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -447,7 +447,7 @@ def test_create_sub_element_api_error(responses, mock_elements_worker): ) responses.add( responses.POST, - "http://testserver/api/v1/elements/create/", + "http://testserver/api/v1/elements/create/?slim_output=True", status=500, ) @@ -464,15 +464,16 @@ def test_create_sub_element_api_error(responses, mock_elements_worker): (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We retry 5 times the API call - ("POST", "http://testserver/api/v1/elements/create/"), - ("POST", "http://testserver/api/v1/elements/create/"), - ("POST", "http://testserver/api/v1/elements/create/"), - ("POST", "http://testserver/api/v1/elements/create/"), - ("POST", "http://testserver/api/v1/elements/create/"), + ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"), + ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"), + ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"), + ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"), + ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"), ] -def test_create_sub_element(responses, mock_elements_worker): +@pytest.mark.parametrize("slim_output", [True, False]) +def test_create_sub_element(responses, mock_elements_worker, slim_output): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", @@ -480,25 +481,34 @@ def test_create_sub_element(responses, mock_elements_worker): "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, } ) + child_elt = { + "id": "12345678-1234-1234-1234-123456789123", + "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, + "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, + } responses.add( responses.POST, - "http://testserver/api/v1/elements/create/", + f"http://testserver/api/v1/elements/create/?slim_output={slim_output}", status=200, - json={"id": "12345678-1234-1234-1234-123456789123"}, + json=child_elt, ) - sub_element_id = mock_elements_worker.create_sub_element( + element_creation_response = mock_elements_worker.create_sub_element( element=elt, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], + slim_output=slim_output, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ - ("POST", "http://testserver/api/v1/elements/create/"), + ( + "POST", + f"http://testserver/api/v1/elements/create/?slim_output={slim_output}", + ), ] assert json.loads(responses.calls[-1].request.body) == { "type": "something", @@ -510,7 +520,10 @@ def test_create_sub_element(responses, mock_elements_worker): "worker_version": "12341234-1234-1234-1234-123412341234", "confidence": None, } - assert sub_element_id == "12345678-1234-1234-1234-123456789123" + if slim_output: + assert element_creation_response == "12345678-1234-1234-1234-123456789123" + else: + assert Element(element_creation_response) == Element(child_elt) def test_create_sub_element_confidence(responses, mock_elements_worker): @@ -523,7 +536,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker): ) responses.add( responses.POST, - "http://testserver/api/v1/elements/create/", + "http://testserver/api/v1/elements/create/?slim_output=True", status=200, json={"id": "12345678-1234-1234-1234-123456789123"}, ) @@ -540,7 +553,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker): assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ - ("POST", "http://testserver/api/v1/elements/create/"), + ("POST", "http://testserver/api/v1/elements/create/?slim_output=True"), ] assert json.loads(responses.calls[-1].request.body) == { "type": "something",