diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index 985481ad0b7a6d1a7dddd630c9bcdf8802bf257a..a64002d5cbc70a30d9cb7154ec93d4208917be66 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -514,27 +514,39 @@ class ElementMixin(object): with_zone: Optional[bool] = None, worker_version: Optional[Union[str, bool]] = None, worker_run: Optional[Union[str, bool]] = None, - ) -> Iterable[dict]: + ) -> Union[Iterable[dict], Iterable[CachedElement]]: """ - List parents of an element through the API. + List parents of an element. :param element: Child element to find parents of. :param folder: Restrict to or exclude elements with folder types. + This parameter is not supported when caching is enabled. :param name: Restrict to elements whose name contain a substring (case-insensitive). - :param recursive: Look for elements recursively (grand-parents, etc.) + This parameter is not supported when caching is enabled. + :param recursive: Look for elements recursively (grand-parents, etc.) + This parameter is not supported when caching is enabled. :param transcription_worker_version: Restrict to elements that have a transcription created by a worker version with this UUID. + This parameter is not supported when caching is enabled. :param transcription_worker_run: Restrict to elements that have a transcription created by a worker run with this UUID. + This parameter is not supported when caching is enabled. :param type: Restrict to elements with a specific type slug + This parameter is supported when caching is enabled. :param with_classes: Include each element's classifications in the response. + This parameter is not supported when caching is enabled. :param with_corpus: Include each element's corpus in the response. + This parameter is not supported when caching is enabled. 
:param with_has_children: Include the ``has_children`` attribute in the response, indicating if this element has child elements of its own. + This parameter is not supported when caching is enabled. :param with_metadata: Include each element's metadata in the response. + This parameter is not supported when caching is enabled. :param with_zone: Include the ``zone`` attribute in the response, holding the element's image and polygon. + This parameter is not supported when caching is enabled. :param worker_version: Restrict to elements created by a worker version with this UUID. :param worker_run: Restrict to elements created by a worker run with this UUID. - :return: An iterable of dicts from the ``ListElementChildren`` API endpoint. + :return: An iterable of dicts from the ``ListElementParents`` API endpoint, + or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled. """ assert element and isinstance( element, (Element, CachedElement) @@ -608,6 +620,42 @@ class ElementMixin(object): ), "if of type bool, worker_run can only be set to False" query_params["worker_run"] = worker_run - return self.api_client.paginate( - "ListElementParents", id=element.id, **query_params - ) + if self.use_cache: + # Checking that we only received query_params handled by the cache + assert set(query_params.keys()) <= { + "type", + "worker_version", + "worker_run", + }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" + + parent_ids = CachedElement.select(CachedElement.parent_id).where( + CachedElement.id == element.id + ) + query = CachedElement.select().where(CachedElement.id.in_(parent_ids)) + if type: + query = query.where(CachedElement.type == type) + if worker_version is not None: + # If worker_version=False, filter by manual worker_version i.e. 
None + worker_version_id = worker_version or None + if worker_version_id: + query = query.where( + CachedElement.worker_version_id == worker_version_id + ) + else: + query = query.where(CachedElement.worker_version_id.is_null()) + + if worker_run is not None: + # If worker_run=False, filter by manual worker_run i.e. None + worker_run_id = worker_run or None + if worker_run_id: + query = query.where(CachedElement.worker_run_id == worker_run_id) + else: + query = query.where(CachedElement.worker_run_id.is_null()) + + return query + else: + parents = self.api_client.paginate( + "ListElementParents", id=element.id, **query_params + ) + + return parents diff --git a/tests/conftest.py b/tests/conftest.py index 1e338843877b011b2f8f084fd238cd184b8978ec..7b613c7657da8ab19300ea0ef8ef5a2815300abf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -321,9 +321,25 @@ def fake_gitlab_helper_factory(): @pytest.fixture def mock_cached_elements(): """Insert few elements in local cache""" + CachedElement.create( + id=UUID("99999999-9999-9999-9999-999999999999"), + parent_id=None, + type="something", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=None, + worker_run_id=None, + ) + CachedElement.create( + id=UUID("12341234-1234-1234-1234-123412341234"), + parent_id=UUID("99999999-9999-9999-9999-999999999999"), + type="double_page", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), + ) CachedElement.create( id=UUID("11111111-1111-1111-1111-111111111111"), - parent_id="12341234-1234-1234-1234-123412341234", + parent_id=UUID("12341234-1234-1234-1234-123412341234"), type="something", polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), @@ -344,7 +360,7 @@ def mock_cached_elements(): worker_version_id=None, worker_run_id=None, ) - assert CachedElement.select().count() == 3 + assert 
CachedElement.select().count() == 5 @pytest.fixture diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 1b7f96d7aba6eae0b62b18070982bd1e8c82597e..1d7b7a284359ac27f9bc70e45024c4a2ce21dad1 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -2211,3 +2211,99 @@ def test_list_element_parents_manual_worker_run(responses, mock_elements_worker) "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_run=False", ), ] + + +def test_list_element_parents_with_cache_unhandled_param( + mock_elements_worker_with_cache, +): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker_with_cache.list_element_parents( + element=elt, with_corpus=True + ) + assert ( + str(e.value) + == "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" + ) + + +@pytest.mark.parametrize( + "filters, expected_id", + ( + # Filter on element + ( + { + "element": CachedElement(id="11111111-1111-1111-1111-111111111111"), + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element and double_page + ( + { + "element": CachedElement(id="22222222-2222-2222-2222-222222222222"), + "type": "double_page", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element and worker version + ( + { + "element": CachedElement(id="33333333-3333-3333-3333-333333333333"), + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element, type double_page and worker version + ( + { + "element": CachedElement(id="11111111-1111-1111-1111-111111111111"), + "type": "double_page", + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element, manual worker version + ( + { + "element": 
CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_version": False, + }, + "99999999-9999-9999-9999-999999999999", + ), + # Filter on element and worker run + ( + { + "element": CachedElement(id="22222222-2222-2222-2222-222222222222"), + "worker_run": "56785678-5678-5678-5678-567856785678", + }, + "12341234-1234-1234-1234-123412341234", + ), + # Filter on element, manual worker run + ( + { + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_run": False, + }, + "99999999-9999-9999-9999-999999999999", + ), + ), +) +def test_list_element_parents_with_cache( + responses, + mock_elements_worker_with_cache, + mock_cached_elements, + filters, + expected_id, +): + # Query database through cache + elements = mock_elements_worker_with_cache.list_element_parents(**filters) + assert elements.count() == 1 + for parent in elements.order_by("id"): + assert parent.id == UUID(expected_id) + + # Check the worker never hits the API for elements + assert len(responses.calls) == len(BASE_API_CALLS) + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS