From 6778870303595557182e9da974c43915ae36a279 Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Mon, 26 Apr 2021 07:21:25 +0000 Subject: [PATCH] Handle recursive and element_type filters in cached list_transcriptions --- arkindex_worker/worker/transcription.py | 41 ++++-- tests/conftest.py | 59 +++++++- .../test_transcriptions.py | 127 +++++++++--------- 3 files changed, 148 insertions(+), 79 deletions(-) diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py index 2dd27a60..d89c1425 100644 --- a/arkindex_worker/worker/transcription.py +++ b/arkindex_worker/worker/transcription.py @@ -248,24 +248,39 @@ class TranscriptionMixin(object): ), "worker_version should be of type str" query_params["worker_version"] = worker_version - if self.use_cache and recursive is None: - # Checking that we only received query_params handled by the cache - assert set(query_params.keys()) <= { - "worker_version", - }, "When using the local cache, you can only filter by 'worker_version'" - - transcriptions = CachedTranscription.select().where( - CachedTranscription.element_id == element.id - ) + if self.use_cache: + if not recursive: + # In this case we don't have to return anything, it's easier to use an + # impossible condition (False) rather than filtering by type for nothing + if element_type and element_type != element.type: + return CachedTranscription.select().where(False) + transcriptions = CachedTranscription.select().where( + CachedTranscription.element_id == element.id + ) + else: + base_case = ( + CachedElement.select() + .where(CachedElement.id == element.id) + .cte("base", recursive=True) + ) + recursive = CachedElement.select().join( + base_case, on=(CachedElement.parent_id == base_case.c.id) + ) + cte = base_case.union_all(recursive) + transcriptions = ( + CachedTranscription.select() + .join(cte, on=(CachedTranscription.element_id == cte.c.id)) + .with_cte(cte) + ) + + if element_type: + transcriptions = transcriptions.where(cte.c.type == element_type) + if worker_version: transcriptions = transcriptions.where( CachedTranscription.worker_version_id == worker_version ) else: - if self.use_cache: - logger.warning( - "'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter." - ) transcriptions = self.api_client.paginate( "ListTranscriptions", id=element.id, **query_params ) diff --git a/tests/conftest.py b/tests/conftest.py index e8581194..840784ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -269,22 +269,71 @@ def mock_cached_elements(): def mock_cached_transcriptions(): """Insert few transcriptions in local cache, on a shared element""" CachedElement.create( - id=UUID("12341234-1234-1234-1234-123412341234"), + id=UUID("11111111-1111-1111-1111-111111111111"), + type="page", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedElement.create( + id=UUID("22222222-2222-2222-2222-222222222222"), + type="something_else", + parent_id=UUID("11111111-1111-1111-1111-111111111111"), + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedElement.create( + id=UUID("33333333-3333-3333-3333-333333333333"), type="page", + parent_id=UUID("11111111-1111-1111-1111-111111111111"), + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedElement.create( + id=UUID("44444444-4444-4444-4444-444444444444"), + type="something_else", + parent_id=UUID("22222222-2222-2222-2222-222222222222"), + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedElement.create( + id=UUID("55555555-5555-5555-5555-555555555555"), + type="something_else", + parent_id=UUID("44444444-4444-4444-4444-444444444444"), polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) CachedTranscription.create( id=UUID("11111111-1111-1111-1111-111111111111"), - element_id=UUID("12341234-1234-1234-1234-123412341234"), - text="Hello!", + element_id=UUID("11111111-1111-1111-1111-111111111111"), + text="This", confidence=0.42, worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) CachedTranscription.create( id=UUID("22222222-2222-2222-2222-222222222222"), - element_id=UUID("12341234-1234-1234-1234-123412341234"), - text="How are you?", + element_id=UUID("22222222-2222-2222-2222-222222222222"), + text="is", + confidence=0.42, + worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), + ) + CachedTranscription.create( + id=UUID("33333333-3333-3333-3333-333333333333"), + element_id=UUID("33333333-3333-3333-3333-333333333333"), + text="a", + confidence=0.42, + worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), + ) + CachedTranscription.create( + id=UUID("44444444-4444-4444-4444-444444444444"), + element_id=UUID("44444444-4444-4444-4444-444444444444"), + text="good", + confidence=0.42, + worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), + ) + CachedTranscription.create( + id=UUID("55555555-5555-5555-5555-555555555555"), + element_id=UUID("55555555-5555-5555-5555-555555555555"), + text="test", confidence=0.42, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 72c9a6ab..586b2d2d 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -1262,79 +1262,84 @@ def test_list_transcriptions(responses, mock_elements_worker): ] -def test_list_transcriptions_with_cache_unhandled_param( - responses, mock_elements_worker_with_cache -): - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - - with pytest.raises(AssertionError) as e: - mock_elements_worker_with_cache.list_transcriptions( - element=elt, element_type="page" - ) - assert ( - str(e.value) - == "When using the local cache, you can only filter by 'worker_version'" - ) - - -def test_list_transcriptions_with_cache_skip_recursive( - responses, mock_elements_worker_with_cache -): - # When the local cache is activated and the user defines the recursive filter, we should fallback to the API - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - trans = [ - { - "id": "0000", - "text": "hey", - "confidence": 0.42, - "worker_version_id": "56785678-5678-5678-5678-567856785678", - "element": None, - }, - ] - responses.add( - responses.GET, - "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True", - status=200, - json={ - "count": 3, - "next": None, - "results": trans, - }, - ) - - for idx, transcription in enumerate( - mock_elements_worker_with_cache.list_transcriptions(element=elt, recursive=True) - ): - assert transcription == trans[idx] - - assert len(responses.calls) == 3 - assert [call.request.url for call in responses.calls] == [ - "http://testserver/api/v1/user/", - "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", - "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True", - ] - - @pytest.mark.parametrize( "filters, expected_ids", ( - # Filter on element should give all elements inserted + # Filter on element should give first transcription + ( + { + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + }, + ("11111111-1111-1111-1111-111111111111",), + ), + # Filter on element and element_type should give first transcription + ( + { + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + "element_type": "page", + }, + ("11111111-1111-1111-1111-111111111111",), + ), + # Filter on element and worker_version should give first transcription ( { - "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + ("11111111-1111-1111-1111-111111111111",), + ), + # Filter recursively on element should give all transcriptions inserted + ( + { + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + "recursive": True, }, ( "11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222", + "33333333-3333-3333-3333-333333333333", + "44444444-4444-4444-4444-444444444444", + "55555555-5555-5555-5555-555555555555", ), ), - # Filter on element and worker version should give first element + # Filter recursively on element and worker_version should give four transcriptions ( { - "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), - "worker_version": "56785678-5678-5678-5678-567856785678", + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + "worker_version": "90129012-9012-9012-9012-901290129012", + "recursive": True, }, - ("11111111-1111-1111-1111-111111111111",), + ( + "22222222-2222-2222-2222-222222222222", + "33333333-3333-3333-3333-333333333333", + "44444444-4444-4444-4444-444444444444", + "55555555-5555-5555-5555-555555555555", + ), + ), + # Filter recursively on element and element_type should give three transcriptions + ( + { + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + "element_type": "something_else", + "recursive": True, + }, + ( + "22222222-2222-2222-2222-222222222222", + "44444444-4444-4444-4444-444444444444", + "55555555-5555-5555-5555-555555555555", + ), ), ), ) @@ -1345,8 +1350,8 @@ def test_list_transcriptions_with_cache( filters, expected_ids, ): - # Check we have 2 elements already present in database - assert CachedTranscription.select().count() == 2 + # Check we have 5 elements already present in database + assert CachedTranscription.select().count() == 5 # Query database through cache transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters) -- GitLab