From dcead0963c7f9a52fbef8cc7a321bfb06c0f6b3b Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Mon, 29 Mar 2021 19:06:05 +0000 Subject: [PATCH] Retrieve transcriptions from local cache in list_transcriptions --- arkindex_worker/cache.py | 4 + arkindex_worker/worker.py | 24 +++- tests/conftest.py | 47 +++++++ tests/test_elements_worker/test_elements.py | 117 +++++++++--------- .../test_transcriptions.py | 100 +++++++++++++++ 5 files changed, 230 insertions(+), 62 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index cec327ab..1322e2bc 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import json +import logging from peewee import ( BooleanField, @@ -13,6 +14,8 @@ from peewee import ( UUIDField, ) +logger = logging.getLogger(__name__) + db = SqliteDatabase(None) @@ -65,6 +68,7 @@ def init_cache_db(path): }, ) db.connect() + logger.info(f"Connected to cache on {path}") def create_tables(): diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index cc66f345..1a4282c9 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -935,9 +935,27 @@ class ElementsWorker(BaseWorker): ), "worker_version should be of type str" query_params["worker_version"] = worker_version - transcriptions = self.api_client.paginate( - "ListTranscriptions", id=element.id, **query_params - ) + if self.use_cache and recursive is None: + # Checking that we only received query_params handled by the cache + assert set(query_params.keys()) <= { + "worker_version", + }, "When using the local cache, you can only filter by 'worker_version'" + + transcriptions = CachedTranscription.select().where( + CachedTranscription.element_id == element.id + ) + if worker_version: + transcriptions = transcriptions.where( + CachedTranscription.worker_version_id == worker_version + ) + else: + if self.use_cache: + logger.warning( + "'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter." + ) + transcriptions = self.api_client.paginate( + "ListTranscriptions", id=element.id, **query_params + ) return transcriptions diff --git a/tests/conftest.py b/tests/conftest.py index 1d6c208e..ca6cd6f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,11 +5,13 @@ import os import sys import time from pathlib import Path +from uuid import UUID import pytest import yaml from arkindex.mock import MockApiClient +from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.worker import ElementsWorker @@ -228,3 +230,48 @@ def fake_gitlab_helper_factory(): ) return run + + +@pytest.fixture +def mock_cached_elements(): + """Insert few elements in local cache""" + CachedElement.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + parent_id="12341234-1234-1234-1234-123412341234", + type="something", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedElement.create( + id=UUID("22222222-2222-2222-2222-222222222222"), + parent_id=UUID("12341234-1234-1234-1234-123412341234"), + type="page", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + assert CachedElement.select().count() == 2 + + +@pytest.fixture +def mock_cached_transcriptions(): + """Insert few transcriptions in local cache, on a shared element""" + CachedElement.create( + id=UUID("12341234-1234-1234-1234-123412341234"), + type="page", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedTranscription.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + element_id=UUID("12341234-1234-1234-1234-123412341234"), + text="Hello!", + confidence=0.42, + worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), + ) + CachedTranscription.create( + id=UUID("22222222-2222-2222-2222-222222222222"), + element_id=UUID("12341234-1234-1234-1234-123412341234"), + text="How are you?", + confidence=0.42, + worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), + ) diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index a978b970..ed782475 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -12,23 +12,6 @@ from arkindex_worker.cache import CachedElement from arkindex_worker.models import Element from arkindex_worker.worker import ElementsWorker -ELEMENTS_TO_INSERT = [ - CachedElement( - id="11111111-1111-1111-1111-111111111111", - parent_id="12341234-1234-1234-1234-123412341234", - type="something", - polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", - worker_version_id="56785678-5678-5678-5678-567856785678", - ), - CachedElement( - id="22222222-2222-2222-2222-222222222222", - parent_id="12341234-1234-1234-1234-123412341234", - type="something", - polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", - worker_version_id="56785678-5678-5678-5678-567856785678", - ), -] - def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker): _, path = tempfile.mkstemp() @@ -987,51 +970,67 @@ def test_list_element_children_with_cache_unhandled_param( ) -def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache): - elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) - - for idx, child in enumerate( - mock_elements_worker_with_cache.list_element_children(element=elt) - ): - assert child == [] - - # Initialize SQLite cache with some elements - CachedElement.insert_many(ELEMENTS_TO_INSERT) - - expected_children = ELEMENTS_TO_INSERT - - for idx, child in enumerate( - mock_elements_worker_with_cache.list_element_children(element=elt) - ): - assert child == expected_children[idx] - - expected_children = [ELEMENTS_TO_INSERT[1]] - - for idx, child in enumerate( - mock_elements_worker_with_cache.list_element_children(element=elt, type="page") - ): - assert child == expected_children[idx] - - expected_children = ELEMENTS_TO_INSERT[:2] - - for idx, child in enumerate( - mock_elements_worker_with_cache.list_element_children( - element=elt, worker_version="56785678-5678-5678-5678-567856785678" - ) - ): - assert child == expected_children[idx] +@pytest.mark.parametrize( + "filters, expected_ids", + ( + # Filter on element should give all elements inserted + ( + { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + }, + ( + "11111111-1111-1111-1111-111111111111", + "22222222-2222-2222-2222-222222222222", + ), + ), + # Filter on element and page should give the second element + ( + { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "type": "page", + }, + ("22222222-2222-2222-2222-222222222222",), + ), + # Filter on element and worker version should give all elements + ( + { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + ( + "11111111-1111-1111-1111-111111111111", + "22222222-2222-2222-2222-222222222222", + ), + ), + # Filter on element, type something and worker version should give first + ( + { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "type": "something", + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + ("11111111-1111-1111-1111-111111111111",), + ), + ), +) +def test_list_element_children_with_cache( + responses, + mock_elements_worker_with_cache, + mock_cached_elements, + filters, + expected_ids, +): - expected_children = [ELEMENTS_TO_INSERT[0]] + # Check we have 2 elements already present in database + assert CachedElement.select().count() == 2 - for idx, child in enumerate( - mock_elements_worker_with_cache.list_element_children( - element=elt, - type="something", - worker_version="56785678-5678-5678-5678-567856785678", - ) - ): - assert child == expected_children[idx] + # Query database through cache + elements = mock_elements_worker_with_cache.list_element_children(**filters) + assert elements.count() == len(expected_ids) + for child, expected_id in zip(elements.order_by("id"), expected_ids): + assert child.id == UUID(expected_id) + # Check the worker never hits the API for elements assert len(responses.calls) == 2 assert [call.request.url for call in responses.calls] == [ "http://testserver/api/v1/user/", diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 8caf3963..20042876 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -1154,3 +1154,103 @@ def test_list_transcriptions(responses, mock_elements_worker): "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/", ] + + +def test_list_transcriptions_with_cache_unhandled_param( + responses, mock_elements_worker_with_cache +): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker_with_cache.list_transcriptions( + element=elt, element_type="page" + ) + assert ( + str(e.value) + == "When using the local cache, you can only filter by 'worker_version'" + ) + + +def test_list_transcriptions_with_cache_skip_recursive( + responses, mock_elements_worker_with_cache +): + # When the local cache is activated and the user defines the recursive filter, we should fallback to the API + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + trans = [ + { + "id": "0000", + "text": "hey", + "confidence": 0.42, + "worker_version_id": "56785678-5678-5678-5678-567856785678", + "element": None, + }, + ] + responses.add( + responses.GET, + "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True", + status=200, + json={ + "count": 3, + "next": None, + "results": trans, + }, + ) + + for idx, transcription in enumerate( + mock_elements_worker_with_cache.list_transcriptions(element=elt, recursive=True) + ): + assert transcription == trans[idx] + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True", + ] + + +@pytest.mark.parametrize( + "filters, expected_ids", + ( + # Filter on element should give all elements inserted + ( + { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + }, + ( + "11111111-1111-1111-1111-111111111111", + "22222222-2222-2222-2222-222222222222", + ), + ), + # Filter on element and worker version should give first element + ( + { + "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "worker_version": "56785678-5678-5678-5678-567856785678", + }, + ("11111111-1111-1111-1111-111111111111",), + ), + ), +) +def test_list_transcriptions_with_cache( + responses, + mock_elements_worker_with_cache, + mock_cached_transcriptions, + filters, + expected_ids, +): + # Check we have 2 elements already present in database + assert CachedTranscription.select().count() == 2 + + # Query database through cache + transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters) + assert transcriptions.count() == len(expected_ids) + for transcription, expected_id in zip(transcriptions.order_by("id"), expected_ids): + assert transcription.id == UUID(expected_id) + + # Check the worker never hits the API for elements + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + ] -- GitLab