Skip to content
Snippets Groups Projects
Commit b7b06c42 authored by Eva Bardou's avatar Eva Bardou
Browse files

Handle recursive and element_type filters in cached list_transcriptions

parent cafc198d
No related branches found
No related tags found
1 merge request!94Handle recursive and element_type filters in cached list_transcriptions
Pipeline #78468 passed
......@@ -248,24 +248,34 @@ class TranscriptionMixin(object):
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache and recursive is None:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"worker_version",
}, "When using the local cache, you can only filter by 'worker_version'"
if self.use_cache:
elements_query = CachedElement.select().where(
CachedElement.id == element.id
)
type_attr = CachedElement.type
if recursive:
base_case = elements_query.cte("base", recursive=True)
recursive = CachedElement.select().join(
base_case, on=(CachedElement.parent_id == base_case.c.id)
)
cte = base_case.union_all(recursive)
elements_query = cte.select_from(cte.c.id, cte.c.type)
type_attr = cte.c.type
if element_type:
elements_query = elements_query.where(type_attr == element_type)
elements_ids = [elem.id for elem in elements_query]
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
CachedTranscription.element_id.in_(elements_ids)
)
if worker_version:
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
)
else:
if self.use_cache:
logger.warning(
"'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter."
)
transcriptions = self.api_client.paginate(
"ListTranscriptions", id=element.id, **query_params
)
......
......@@ -269,22 +269,71 @@ def mock_cached_elements():
def mock_cached_transcriptions():
"""Insert few transcriptions in local cache, on a shared element"""
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
id=UUID("11111111-1111-1111-1111-111111111111"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
type="something_else",
parent_id=UUID("11111111-1111-1111-1111-111111111111"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
type="page",
parent_id=UUID("11111111-1111-1111-1111-111111111111"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("44444444-4444-4444-4444-444444444444"),
type="something_else",
parent_id=UUID("22222222-2222-2222-2222-222222222222"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("55555555-5555-5555-5555-555555555555"),
type="something_else",
parent_id=UUID("44444444-4444-4444-4444-444444444444"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello!",
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="How are you?",
element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="is",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
element_id=UUID("33333333-3333-3333-3333-333333333333"),
text="a",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("44444444-4444-4444-4444-444444444444"),
element_id=UUID("44444444-4444-4444-4444-444444444444"),
text="good",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("55555555-5555-5555-5555-555555555555"),
element_id=UUID("55555555-5555-5555-5555-555555555555"),
text="test",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
......
......@@ -1262,80 +1262,52 @@ def test_list_transcriptions(responses, mock_elements_worker):
]
def test_list_transcriptions_with_cache_unhandled_param(
responses, mock_elements_worker_with_cache
):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker_with_cache.list_transcriptions(
element=elt, element_type="page"
)
assert (
str(e.value)
== "When using the local cache, you can only filter by 'worker_version'"
)
def test_list_transcriptions_with_cache_skip_recursive(
responses, mock_elements_worker_with_cache
):
# When the local cache is activated and the user defines the recursive filter, we should fallback to the API
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
trans = [
{
"id": "0000",
"text": "hey",
"confidence": 0.42,
"worker_version_id": "56785678-5678-5678-5678-567856785678",
"element": None,
},
]
responses.add(
responses.GET,
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
status=200,
json={
"count": 3,
"next": None,
"results": trans,
},
)
for idx, transcription in enumerate(
mock_elements_worker_with_cache.list_transcriptions(element=elt, recursive=True)
):
assert transcription == trans[idx]
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give all elements inserted
# Filter on element should give first transcription
(
{
"element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter recursively on element should give all transcriptions inserted
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
"recursive": True,
},
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
),
),
# Filter on element and worker version should give first element
# Filter recursively on element and worker_version should give first transcription
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
"worker_version": "56785678-5678-5678-5678-567856785678",
"recursive": True,
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter recursively on element and element_type should give three transcriptions
(
{
"element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
"element_type": "something_else",
"recursive": True,
},
(
"22222222-2222-2222-2222-222222222222",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
),
),
),
)
def test_list_transcriptions_with_cache(
......@@ -1345,8 +1317,8 @@ def test_list_transcriptions_with_cache(
filters,
expected_ids,
):
# Check we have 2 elements already present in database
assert CachedTranscription.select().count() == 2
# Check we have 5 elements already present in database
assert CachedTranscription.select().count() == 5
# Query database through cache
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment