diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index 2d54563f04de1973a07dafa9ee17844fd8ac3e8c..a64002d5cbc70a30d9cb7154ec93d4208917be66 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -497,3 +497,165 @@ class ElementMixin(object): ) return children + + def list_element_parents( + self, + element: Union[Element, CachedElement], + folder: Optional[bool] = None, + name: Optional[str] = None, + recursive: Optional[bool] = None, + transcription_worker_version: Optional[Union[str, bool]] = None, + transcription_worker_run: Optional[Union[str, bool]] = None, + type: Optional[str] = None, + with_classes: Optional[bool] = None, + with_corpus: Optional[bool] = None, + with_metadata: Optional[bool] = None, + with_has_children: Optional[bool] = None, + with_zone: Optional[bool] = None, + worker_version: Optional[Union[str, bool]] = None, + worker_run: Optional[Union[str, bool]] = None, + ) -> Union[Iterable[dict], Iterable[CachedElement]]: + """ + List parents of an element. + + :param element: Child element to find parents of. + :param folder: Restrict to or exclude elements with folder types. + This parameter is not supported when caching is enabled. + :param name: Restrict to elements whose name contain a substring (case-insensitive). + This parameter is not supported when caching is enabled. + :param recursive: Look for elements recursively (grand-children, etc.) + This parameter is not supported when caching is enabled. + :param transcription_worker_version: Restrict to elements that have a transcription created by a worker version with this UUID. + This parameter is not supported when caching is enabled. + :param transcription_worker_run: Restrict to elements that have a transcription created by a worker run with this UUID. + This parameter is not supported when caching is enabled. + :param type: Restrict to elements with a specific type slug + This parameter is not supported when caching is enabled. + :param with_classes: Include each element's classifications in the response. + This parameter is not supported when caching is enabled. + :param with_corpus: Include each element's corpus in the response. + This parameter is not supported when caching is enabled. + :param with_has_children: Include the ``has_children`` attribute in the response, + indicating if this element has child elements of its own. + This parameter is not supported when caching is enabled. + :param with_metadata: Include each element's metadata in the response. + This parameter is not supported when caching is enabled. + :param with_zone: Include the ``zone`` attribute in the response, + holding the element's image and polygon. + This parameter is not supported when caching is enabled. + :param worker_version: Restrict to elements created by a worker version with this UUID. + :param worker_run: Restrict to elements created by a worker run with this UUID. + :return: An iterable of dicts from the ``ListElementParents`` API endpoint, + or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled. + """ + assert element and isinstance( + element, (Element, CachedElement) + ), "element shouldn't be null and should be an Element or CachedElement" + query_params = {} + if folder is not None: + assert isinstance(folder, bool), "folder should be of type bool" + query_params["folder"] = folder + if name: + assert isinstance(name, str), "name should be of type str" + query_params["name"] = name + if recursive is not None: + assert isinstance(recursive, bool), "recursive should be of type bool" + query_params["recursive"] = recursive + if transcription_worker_version is not None: + assert isinstance( + transcription_worker_version, (str, bool) + ), "transcription_worker_version should be of type str or bool" + if isinstance(transcription_worker_version, bool): + assert ( + transcription_worker_version is False + ), "if of type bool, transcription_worker_version can only be set to False" + query_params["transcription_worker_version"] = transcription_worker_version + if transcription_worker_run is not None: + assert isinstance( + transcription_worker_run, (str, bool) + ), "transcription_worker_run should be of type str or bool" + if isinstance(transcription_worker_run, bool): + assert ( + transcription_worker_run is False + ), "if of type bool, transcription_worker_run can only be set to False" + query_params["transcription_worker_run"] = transcription_worker_run + if type: + assert isinstance(type, str), "type should be of type str" + query_params["type"] = type + if with_classes is not None: + assert isinstance(with_classes, bool), "with_classes should be of type bool" + query_params["with_classes"] = with_classes + if with_corpus is not None: + assert isinstance(with_corpus, bool), "with_corpus should be of type bool" + query_params["with_corpus"] = with_corpus + if with_has_children is not None: + assert isinstance( + with_has_children, bool + ), "with_has_children should be of type bool" + query_params["with_has_children"] = with_has_children + if with_metadata is not None: + assert isinstance( + with_metadata, bool + ), "with_metadata should be of type bool" + query_params["with_metadata"] = with_metadata + if with_zone is not None: + assert isinstance(with_zone, bool), "with_zone should be of type bool" + query_params["with_zone"] = with_zone + if worker_version is not None: + assert isinstance( + worker_version, (str, bool) + ), "worker_version should be of type str or bool" + if isinstance(worker_version, bool): + assert ( + worker_version is False + ), "if of type bool, worker_version can only be set to False" + query_params["worker_version"] = worker_version + if worker_run is not None: + assert isinstance( + worker_run, (str, bool) + ), "worker_run should be of type str or bool" + if isinstance(worker_run, bool): + assert ( + worker_run is False + ), "if of type bool, worker_run can only be set to False" + query_params["worker_run"] = worker_run + + if self.use_cache: + # Checking that we only received query_params handled by the cache + assert set(query_params.keys()) <= { + "type", + "worker_version", + "worker_run", + }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" + + parent_ids = CachedElement.select(CachedElement.parent_id).where( + CachedElement.id == element.id + ) + query = CachedElement.select().where(CachedElement.id.in_(parent_ids)) + if type: + query = query.where(CachedElement.type == type) + if worker_version is not None: + # If worker_version=False, filter by manual worker_version e.g. None + worker_version_id = worker_version or None + if worker_version_id: + query = query.where( + CachedElement.worker_version_id == worker_version_id + ) + else: + query = query.where(CachedElement.worker_version_id.is_null()) + + if worker_run is not None: + # If worker_run=False, filter by manual worker_run e.g. None + worker_run_id = worker_run or None + if worker_run_id: + query = query.where(CachedElement.worker_run_id == worker_run_id) + else: + query = query.where(CachedElement.worker_run_id.is_null()) + + return query + else: + parents = self.api_client.paginate( + "ListElementParents", id=element.id, **query_params + ) + + return parents diff --git a/tests/conftest.py b/tests/conftest.py index 1e338843877b011b2f8f084fd238cd184b8978ec..7b613c7657da8ab19300ea0ef8ef5a2815300abf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -321,9 +321,25 @@ def fake_gitlab_helper_factory(): @pytest.fixture def mock_cached_elements(): """Insert few elements in local cache""" + CachedElement.create( + id=UUID("99999999-9999-9999-9999-999999999999"), + parent_id=None, + type="something", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=None, + worker_run_id=None, + ) + CachedElement.create( + id=UUID("12341234-1234-1234-1234-123412341234"), + parent_id=UUID("99999999-9999-9999-9999-999999999999"), + type="double_page", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), + ) CachedElement.create( id=UUID("11111111-1111-1111-1111-111111111111"), - parent_id="12341234-1234-1234-1234-123412341234", + parent_id=UUID("12341234-1234-1234-1234-123412341234"), type="something", polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), @@ -344,7 +360,7 @@ def mock_cached_elements(): worker_version_id=None, worker_run_id=None, ) - assert CachedElement.select().count() == 3 + assert CachedElement.select().count() == 5 @pytest.fixture diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 12b438722866ab7932aa1fae855cf6cac7722d79..273e41ba54a041abb435218086cd17529f58a12b 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -1848,8 +1848,8 @@ def test_list_element_children_with_cache( filters, expected_ids, ): - # Check we have 2 elements already present in database - assert CachedElement.select().count() == 3 + # Check we have 5 elements already present in database + assert CachedElement.select().count() == 5 # Query database through cache elements = mock_elements_worker_with_cache.list_element_children(**filters) @@ -1862,3 +1862,451 @@ def test_list_element_children_with_cache( assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + + +def test_list_element_parents_wrong_element(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents(element=None) + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents(element="not element type") + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) + + +def test_list_element_parents_wrong_folder(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + folder="not bool", + ) + assert str(e.value) == "folder should be of type bool" + + +def test_list_element_parents_wrong_name(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + name=1234, + ) + assert str(e.value) == "name should be of type str" + + +def test_list_element_parents_wrong_recursive(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + recursive="not bool", + ) + assert str(e.value) == "recursive should be of type bool" + + +def test_list_element_parents_wrong_type(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + type=1234, + ) + assert str(e.value) == "type should be of type str" + + +def test_list_element_parents_wrong_with_classes(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + with_classes="not bool", + ) + assert str(e.value) == "with_classes should be of type bool" + + +def test_list_element_parents_wrong_with_corpus(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + with_corpus="not bool", + ) + assert str(e.value) == "with_corpus should be of type bool" + + +def test_list_element_parents_wrong_with_has_children(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + with_has_children="not bool", + ) + assert str(e.value) == "with_has_children should be of type bool" + + +def test_list_element_parents_wrong_with_zone(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + with_zone="not bool", + ) + assert str(e.value) == "with_zone should be of type bool" + + +def test_list_element_parents_wrong_with_metadata(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + with_metadata="not bool", + ) + assert str(e.value) == "with_metadata should be of type bool" + + +@pytest.mark.parametrize( + "param, value", + ( + ("worker_version", 1234), + ("worker_run", 1234), + ("transcription_worker_version", 1234), + ("transcription_worker_run", 1234), + ), +) +def test_list_element_parents_wrong_worker_version(mock_elements_worker, param, value): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + **{param: value}, + ) + assert str(e.value) == f"{param} should be of type str or bool" + + +@pytest.mark.parametrize( + "param", + ( + ("worker_version"), + ("worker_run"), + ("transcription_worker_version"), + ("transcription_worker_run"), + ), +) +def test_list_element_parents_wrong_bool_worker_version(mock_elements_worker, param): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_parents( + element=elt, + **{param: True}, + ) + assert str(e.value) == f"if of type bool, {param} can only be set to False" + + +def test_list_element_parents_api_error(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + status=500, + ) + + with pytest.raises( + Exception, match="Stopping pagination as data will be incomplete" + ): + next(mock_elements_worker.list_element_parents(element=elt)) + + assert len(responses.calls) == len(BASE_API_CALLS) + 5 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + # We do 5 retries + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + ), + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + ), + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + ), + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + ), + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + ), + ] + + +def test_list_element_parents(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + expected_parents = [ + { + "id": "0000", + "type": "page", + "name": "Test", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + "worker_run_id": None, + }, + { + "id": "1111", + "type": "page", + "name": "Test 2", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + "worker_run_id": None, + }, + { + "id": "2222", + "type": "page", + "name": "Test 3", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + "worker_run_id": None, + }, + ] + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + status=200, + json={ + "count": 3, + "next": None, + "results": expected_parents, + }, + ) + + for idx, parent in enumerate( + mock_elements_worker.list_element_parents(element=elt) + ): + assert parent == expected_parents[idx] + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", + ), + ] + + +def test_list_element_parents_manual_worker_version(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + expected_parents = [ + { + "id": "0000", + "type": "page", + "name": "Test", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + "worker_run_id": None, + } + ] + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_version=False", + status=200, + json={ + "count": 1, + "next": None, + "results": expected_parents, + }, + ) + + for idx, parent in enumerate( + mock_elements_worker.list_element_parents(element=elt, worker_version=False) + ): + assert parent == expected_parents[idx] + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_version=False", + ), + ] + + +def test_list_element_parents_manual_worker_run(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + expected_parents = [ + { + "id": "0000", + "type": "page", + "name": "Test", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + "worker_run_id": None, + } + ] + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_run=False", + status=200, + json={ + "count": 1, + "next": None, + "results": expected_parents, + }, + ) + + for idx, parent in enumerate( + mock_elements_worker.list_element_parents(element=elt, worker_run=False) + ): + assert parent == expected_parents[idx] + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_run=False", + ), + ] + + +def test_list_element_parents_with_cache_unhandled_param( + mock_elements_worker_with_cache, +): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker_with_cache.list_element_parents( + element=elt, with_corpus=True + ) + assert ( + str(e.value) + == "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" + ) + + +@pytest.mark.parametrize( + "filters, expected_id", + ( + # Filter on element + ( + { + "element": CachedElement(id="11111111-1111-1111-1111-111111111111"), + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element and double_page + ( + { + "element": CachedElement(id="22222222-2222-2222-2222-222222222222"), + "type": "double_page", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element and worker version + ( + { + "element": CachedElement(id="33333333-3333-3333-3333-333333333333"), + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element, type double_page and worker version + ( + { + "element": CachedElement(id="11111111-1111-1111-1111-111111111111"), + "type": "double_page", + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element, manual worker version + ( + { + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_version": False, + }, + "99999999-9999-9999-9999-999999999999", + ), + # Filter on element and worker run + ( + { + "element": CachedElement(id="22222222-2222-2222-2222-222222222222"), + "worker_run": "56785678-5678-5678-5678-567856785678", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element, manual worker run + ( + { + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_run": False, + }, + "99999999-9999-9999-9999-999999999999", + ), + ), +) +def test_list_element_parents_with_cache( + responses, + mock_elements_worker_with_cache, + mock_cached_elements, + filters, + expected_id, +): + # Check we have 5 elements already present in database + assert CachedElement.select().count() == 5 + + # Query database through cache + elements = mock_elements_worker_with_cache.list_element_parents(**filters) + assert elements.count() == 1 + for parent in elements.order_by("id"): + assert parent.id == UUID(expected_id) + + # Check the worker never hits the API for elements + assert len(responses.calls) == len(BASE_API_CALLS) + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS