diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index 82f53c281bbfa0525ef75356e8bbe278831ddd84..df22528255701e35f5d0f02b1dcad53e2d2e6e36 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -60,6 +60,7 @@ class ElementMixin(object): name: str, polygon: List[List[Union[int, float]]], confidence: Optional[float] = None, + slim_output: bool = True, ) -> str: """ Create a child element on the given element through the API. @@ -95,6 +96,7 @@ class ElementMixin(object): assert confidence is None or ( isinstance(confidence, float) and 0 <= confidence <= 1 ), "confidence should be None or a float in [0..1] range" + assert isinstance(slim_output, bool), "slim_output should be of type bool" if self.is_read_only: logger.warning("Cannot create element as this worker is in read-only mode") @@ -115,7 +117,7 @@ class ElementMixin(object): ) self.report.add_element(element.id, type) - return sub_element["id"] + return sub_element["id"] if slim_output else sub_element def create_elements( self, diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index be2d56dd12e214c20e5b4cba8d16d8e9ad6b12b4..e3ad26f775ab3128db757ad2d2cb2d99045e4e8f 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -472,7 +472,8 @@ def test_create_sub_element_api_error(responses, mock_elements_worker): ] -def test_create_sub_element(responses, mock_elements_worker): +@pytest.mark.parametrize("slim_output", [True, False]) +def test_create_sub_element(responses, mock_elements_worker, slim_output): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", @@ -480,18 +481,24 @@ def test_create_sub_element(responses, mock_elements_worker): "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, } ) + child_elt = { + "id": "12345678-1234-1234-1234-123456789123", + "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, + "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, + } responses.add( responses.POST, "http://testserver/api/v1/elements/create/", status=200, - json={"id": "12345678-1234-1234-1234-123456789123"}, + json=child_elt, ) - sub_element_id = mock_elements_worker.create_sub_element( + element_creation_response = mock_elements_worker.create_sub_element( element=elt, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], + slim_output=slim_output, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 @@ -510,7 +517,10 @@ def test_create_sub_element(responses, mock_elements_worker): "worker_version": "12341234-1234-1234-1234-123412341234", "confidence": None, } - assert sub_element_id == "12345678-1234-1234-1234-123456789123" + if slim_output: + assert element_creation_response == "12345678-1234-1234-1234-123456789123" + else: + assert Element(element_creation_response) == Element(child_elt) def test_create_sub_element_confidence(responses, mock_elements_worker):