From d2c8950fd2b6d7abdf68e8262726ca706fee787d Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Fri, 9 Apr 2021 08:42:16 +0000 Subject: [PATCH] Support giving a CachedElement in helper functions using the local cache --- arkindex_worker/worker/classification.py | 5 +- arkindex_worker/worker/element.py | 4 +- arkindex_worker/worker/transcription.py | 4 +- .../test_classifications.py | 51 +++++++- tests/test_elements_worker/test_elements.py | 18 ++- .../test_transcriptions.py | 118 +++++++++++++++++- 6 files changed, 180 insertions(+), 20 deletions(-) diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index 3e979c04..8bc3b7ba 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -4,6 +4,7 @@ import os from apistar.exceptions import ErrorResponse from arkindex_worker import logger +from arkindex_worker.cache import CachedElement from arkindex_worker.models import Element @@ -65,8 +66,8 @@ class ClassificationMixin(object): Create a classification on the given element through API """ assert element and isinstance( - element, Element - ), "element shouldn't be null and should be of type Element" + element, (Element, CachedElement) + ), "element shouldn't be null and should be an Element or CachedElement" assert ml_class and isinstance( ml_class, str ), "ml_class shouldn't be null and should be of type str" diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index ecedb26a..b25e8c51 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -172,8 +172,8 @@ class ElementMixin(object): List children of an element """ assert element and isinstance( - element, Element - ), "element shouldn't be null and should be of type Element" + element, (Element, CachedElement) + ), "element shouldn't be null and should be an Element or CachedElement" query_params = {} if best_class is not None: assert isinstance(best_class, str) or isinstance( diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py index 23f1088e..2dd27a60 100644 --- a/arkindex_worker/worker/transcription.py +++ b/arkindex_worker/worker/transcription.py @@ -233,8 +233,8 @@ class TranscriptionMixin(object): List transcriptions on an element """ assert element and isinstance( - element, Element - ), "element shouldn't be null and should be of type Element" + element, (Element, CachedElement) + ), "element shouldn't be null and should be an Element or CachedElement" query_params = {} if element_type: assert isinstance(element_type, str), "element_type should be of type str" diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 20a67707..7ea525ea 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -4,6 +4,7 @@ import json import pytest from apistar.exceptions import ErrorResponse +from arkindex_worker.cache import CachedElement from arkindex_worker.models import Element @@ -159,7 +160,10 @@ def test_create_classification_wrong_element(mock_elements_worker): confidence=0.42, high_confidence=True, ) - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( @@ -168,7 +172,10 @@ def test_create_classification_wrong_element(mock_elements_worker): confidence=0.42, high_confidence=True, ) - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) def test_create_classification_wrong_ml_class(mock_elements_worker, responses): @@ -394,6 +401,46 @@ def test_create_classification(responses, mock_elements_worker): ] == {"a_class": 1} +def test_create_classification_with_cached_element(responses, mock_elements_worker): + mock_elements_worker.classes = { + "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} + } + elt = CachedElement(id="12341234-1234-1234-1234-123412341234") + + responses.add( + responses.POST, + "http://testserver/api/v1/classifications/", + status=200, + ) + + mock_elements_worker.create_classification( + element=elt, + ml_class="a_class", + confidence=0.42, + high_confidence=True, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/classifications/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "element": "12341234-1234-1234-1234-123412341234", + "ml_class": "0000", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, + "high_confidence": True, + } + + # Classification has been created and reported + assert mock_elements_worker.report.report_data["elements"][elt.id][ + "classifications" + ] == {"a_class": 1} + + def test_create_classification_duplicate(responses, mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index a903a74d..018bd9cb 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -900,11 +900,17 @@ def test_create_elements_integrity_error( def test_list_element_children_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.list_element_children(element=None) - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) with pytest.raises(AssertionError) as e: mock_elements_worker.list_element_children(element="not element type") - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) def test_list_element_children_wrong_best_class(mock_elements_worker): @@ -1125,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element should give all elements inserted ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), }, ( "11111111-1111-1111-1111-111111111111", @@ -1135,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element and page should give the second element ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "type": "page", }, ("22222222-2222-2222-2222-222222222222",), @@ -1143,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element and worker version should give all elements ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": "56785678-5678-5678-5678-567856785678", }, ( @@ -1154,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element, type something and worker version should give first ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "type": "something", "worker_version": "56785678-5678-5678-5678-567856785678", }, diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 20042876..72c9a6ab 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker): ] -def test_create_transcription(responses, mock_elements_worker_with_cache): +def test_create_transcription(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcription/", + status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "i am a line", + "score": 0.42, + "confidence": 0.42, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, + ) + + mock_elements_worker.create_transcription( + element=elt, + text="i am a line", + score=0.42, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + f"http://testserver/api/v1/element/{elt.id}/transcription/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "text": "i am a line", + "worker_version": "12341234-1234-1234-1234-123412341234", + "score": 0.42, + } + + +def test_create_transcription_with_cache(responses, mock_elements_worker_with_cache): elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( @@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker ] -def test_create_element_transcriptions(responses, mock_elements_worker_with_cache): +def test_create_element_transcriptions(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", + status=200, + json=[ + { + "id": "56785678-5678-5678-5678-567856785678", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + { + "id": "67896789-6789-6789-6789-678967896789", + "element_id": "22222222-2222-2222-2222-222222222222", + "created": False, + }, + { + "id": "78907890-7890-7890-7890-789078907890", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + ], + ) + + annotations = mock_elements_worker.create_element_transcriptions( + element=elt, + sub_element_type="page", + transcriptions=TRANSCRIPTIONS_SAMPLE, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "element_type": "page", + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": TRANSCRIPTIONS_SAMPLE, + "return_elements": True, + } + assert annotations == [ + { + "id": "56785678-5678-5678-5678-567856785678", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + { + "id": "67896789-6789-6789-6789-678967896789", + "element_id": "22222222-2222-2222-2222-222222222222", + "created": False, + }, + { + "id": "78907890-7890-7890-7890-789078907890", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + ] + + +def test_create_element_transcriptions_with_cache( + responses, mock_elements_worker_with_cache +): elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( @@ -1041,11 +1141,17 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach def test_list_transcriptions_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.list_transcriptions(element=None) - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) with pytest.raises(AssertionError) as e: mock_elements_worker.list_transcriptions(element="not element type") - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) def test_list_transcriptions_wrong_element_type(mock_elements_worker): @@ -1215,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive( # Filter on element should give all elements inserted ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), }, ( "11111111-1111-1111-1111-111111111111", @@ -1225,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive( # Filter on element and worker version should give first element ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": "56785678-5678-5678-5678-567856785678", }, ("11111111-1111-1111-1111-111111111111",), -- GitLab