diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index 68aeeb56afa034bb86b9b5790f5143ab1473789f..db0d8905d19aa129a199214db032665298b22c9c 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -281,12 +281,16 @@ class ElementMixin(object): folder: Optional[bool] = None, name: Optional[str] = None, recursive: Optional[bool] = None, + transcription_worker_version: Optional[Union[str, bool]] = None, + transcription_worker_run: Optional[Union[str, bool]] = None, type: Optional[str] = None, with_classes: Optional[bool] = None, with_corpus: Optional[bool] = None, + with_metadata: Optional[bool] = None, with_has_children: Optional[bool] = None, with_zone: Optional[bool] = None, worker_version: Optional[Union[str, bool]] = None, + worker_run: Optional[Union[str, bool]] = None, ) -> Union[Iterable[dict], Iterable[CachedElement]]: """ List children of an element. @@ -298,6 +302,10 @@ class ElementMixin(object): This parameter is not supported when caching is enabled. :param recursive: Look for elements recursively (grand-children, etc.) This parameter is not supported when caching is enabled. + :param transcription_worker_version: Restrict to elements that have a transcription created by a worker version with this UUID. + This parameter is not supported when caching is enabled. + :param transcription_worker_run: Restrict to elements that have a transcription created by a worker run with this UUID. + This parameter is not supported when caching is enabled. :param type: Restrict to elements with a specific type slug This parameter is not supported when caching is enabled. :param with_classes: Include each element's classifications in the response. @@ -307,10 +315,13 @@ class ElementMixin(object): :param with_has_children: Include the ``has_children`` attribute in the response, indicating if this element has child elements of its own. This parameter is not supported when caching is enabled. + :param with_metadata: Include each element's metadata in the response. + This parameter is not supported when caching is enabled. :param with_zone: Include the ``zone`` attribute in the response, holding the element's image and polygon. This parameter is not supported when caching is enabled. :param worker_version: Restrict to elements created by a worker version with this UUID. + :param worker_run: Restrict to elements created by a worker run with this UUID. :return: An iterable of dicts from the ``ListElementChildren`` API endpoint, or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled. """ @@ -327,6 +338,24 @@ class ElementMixin(object): if recursive is not None: assert isinstance(recursive, bool), "recursive should be of type bool" query_params["recursive"] = recursive + if transcription_worker_version is not None: + assert isinstance( + transcription_worker_version, (str, bool) + ), "transcription_worker_version should be of type str or bool" + if isinstance(transcription_worker_version, bool): + assert ( + transcription_worker_version is False + ), "if of type bool, transcription_worker_version can only be set to False" + query_params["transcription_worker_version"] = transcription_worker_version + if transcription_worker_run is not None: + assert isinstance( + transcription_worker_run, (str, bool) + ), "transcription_worker_run should be of type str or bool" + if isinstance(transcription_worker_run, bool): + assert ( + transcription_worker_run is False + ), "if of type bool, transcription_worker_run can only be set to False" + query_params["transcription_worker_run"] = transcription_worker_run if type: assert isinstance(type, str), "type should be of type str" query_params["type"] = type @@ -341,6 +370,11 @@ class ElementMixin(object): with_has_children, bool ), "with_has_children should be of type bool" query_params["with_has_children"] = with_has_children + if with_metadata is not None: + assert isinstance( + with_metadata, bool + ), "with_metadata should be of type bool" + query_params["with_metadata"] = with_metadata if with_zone is not None: assert isinstance(with_zone, bool), "with_zone should be of type bool" query_params["with_zone"] = with_zone @@ -353,13 +387,23 @@ class ElementMixin(object): worker_version is False ), "if of type bool, worker_version can only be set to False" query_params["worker_version"] = worker_version + if worker_run is not None: + assert isinstance( + worker_run, (str, bool) + ), "worker_run should be of type str or bool" + if isinstance(worker_run, bool): + assert ( + worker_run is False + ), "if of type bool, worker_run can only be set to False" + query_params["worker_run"] = worker_run if self.use_cache: # Checking that we only received query_params handled by the cache assert set(query_params.keys()) <= { "type", "worker_version", - }, "When using the local cache, you can only filter by 'type' and/or 'worker_version'" + "worker_run", + }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" query = CachedElement.select().where(CachedElement.parent_id == element.id) if type: @@ -370,6 +414,10 @@ class ElementMixin(object): query = query.where( CachedElement.worker_version_id == worker_version_id ) + if worker_run is not None: + # If worker_run=False, filter by manual worker_run e.g. None + worker_run_id = worker_run if worker_run else None + query = query.where(CachedElement.worker_run_id == worker_run_id) return query else: diff --git a/tests/conftest.py b/tests/conftest.py index 142e0d1e08d4b76bcb54c2da1e1a99ed975b6680..e515aeaa60fe338784565b8f28808b315b3a1a14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -334,6 +334,7 @@ def mock_cached_elements(): type="page", polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), ) CachedElement.create( id=UUID("33333333-3333-3333-3333-333333333333"), @@ -341,6 +342,7 @@ def mock_cached_elements(): type="paragraph", polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", worker_version_id=None, + worker_run_id=None, ) assert CachedElement.select().count() == 3 diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 76325c96dd96572cfc18cda59db4d07fe9ec5a43..1d088f769b9544f1a78c6fdc62fcfe8dacfd6d6e 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -1314,26 +1314,55 @@ def test_list_element_children_wrong_with_zone(mock_elements_worker): assert str(e.value) == "with_zone should be of type bool" -def test_list_element_children_wrong_worker_version(mock_elements_worker): +def test_list_element_children_wrong_with_metadata(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: mock_elements_worker.list_element_children( element=elt, - worker_version=1234, + with_metadata="not bool", ) - assert str(e.value) == "worker_version should be of type str or bool" + assert str(e.value) == "with_metadata should be of type bool" -def test_list_element_children_wrong_bool_worker_version(mock_elements_worker): +@pytest.mark.parametrize( + "param, value", + ( + ("worker_version", 1234), + ("worker_run", 1234), + ("transcription_worker_version", 1234), + ("transcription_worker_run", 1234), + ), +) +def test_list_element_children_wrong_worker_version(mock_elements_worker, param, value): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_children( + element=elt, + **{param: value}, + ) + assert str(e.value) == f"{param} should be of type str or bool" + + +@pytest.mark.parametrize( + "param", + ( + ("worker_version"), + ("worker_run"), + ("transcription_worker_version"), + ("transcription_worker_run"), + ), +) +def test_list_element_children_wrong_bool_worker_version(mock_elements_worker, param): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: mock_elements_worker.list_element_children( element=elt, - worker_version=True, + **{param: True}, ) - assert str(e.value) == "if of type bool, worker_version can only be set to False" + assert str(e.value) == f"if of type bool, {param} can only be set to False" def test_list_element_children_api_error(responses, mock_elements_worker): @@ -1487,6 +1516,49 @@ def test_list_element_children_manual_worker_version(responses, mock_elements_wo ] +def test_list_element_children_manual_worker_run(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + expected_children = [ + { + "id": "0000", + "type": "page", + "name": "Test", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + "worker_run_id": None, + } + ] + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_run=False", + status=200, + json={ + "count": 1, + "next": None, + "results": expected_children, + }, + ) + + for idx, child in enumerate( + mock_elements_worker.list_element_children(element=elt, worker_run=False) + ): + assert child == expected_children[idx] + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_run=False", + ), + ] + + def test_list_element_children_with_cache_unhandled_param( mock_elements_worker_with_cache, ): @@ -1498,7 +1570,7 @@ def test_list_element_children_with_cache_unhandled_param( ) assert ( str(e.value) - == "When using the local cache, you can only filter by 'type' and/or 'worker_version'" + == "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" ) @@ -1552,6 +1624,25 @@ def test_list_element_children_with_cache_unhandled_param( }, ("33333333-3333-3333-3333-333333333333",), ), + # Filter on element and worker run should give second + ( + { + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_run": "56785678-5678-5678-5678-567856785678", + }, + ("22222222-2222-2222-2222-222222222222",), + ), + # Filter on element, manual worker run should give first and third + ( + { + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_run": False, + }, + ( + "11111111-1111-1111-1111-111111111111", + "33333333-3333-3333-3333-333333333333", + ), + ), ), ) def test_list_element_children_with_cache(