diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 4ea4eee87ff82282393d32a659390b91d0e7b33b..7ea525ead5eef5e24c0355bd8a6d88b905671286 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -4,6 +4,7 @@ import json import pytest from apistar.exceptions import ErrorResponse +from arkindex_worker.cache import CachedElement from arkindex_worker.models import Element @@ -400,6 +401,46 @@ def test_create_classification(responses, mock_elements_worker): ] == {"a_class": 1} +def test_create_classification_with_cached_element(responses, mock_elements_worker): + mock_elements_worker.classes = { + "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} + } + elt = CachedElement(id="12341234-1234-1234-1234-123412341234") + + responses.add( + responses.POST, + "http://testserver/api/v1/classifications/", + status=200, + ) + + mock_elements_worker.create_classification( + element=elt, + ml_class="a_class", + confidence=0.42, + high_confidence=True, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/classifications/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "element": "12341234-1234-1234-1234-123412341234", + "ml_class": "0000", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, + "high_confidence": True, + } + + # Classification has been created and reported + assert mock_elements_worker.report.report_data["elements"][elt.id][ + "classifications" + ] == {"a_class": 1} + + def test_create_classification_duplicate(responses, mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 0c790b6ac5a99b8138a35f8152f783a22f5af390..018bd9cbde155445c240c6fbef30b846eff56c7f 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -1131,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element should give all elements inserted ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), }, ( "11111111-1111-1111-1111-111111111111", @@ -1141,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element and page should give the second element ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "type": "page", }, ("22222222-2222-2222-2222-222222222222",), @@ -1149,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element and worker version should give all elements ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": "56785678-5678-5678-5678-567856785678", }, ( @@ -1160,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param( # Filter on element, type something and worker version should give first ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "type": "something", "worker_version": "56785678-5678-5678-5678-567856785678", }, diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index de464ac8ab840888639aa031ea3739ef3b71fb93..72c9a6ab6828e4750ff39ecf03c11448f8066004 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker): ] -def test_create_transcription(responses, mock_elements_worker_with_cache): +def test_create_transcription(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcription/", + status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "i am a line", + "score": 0.42, + "confidence": 0.42, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, + ) + + mock_elements_worker.create_transcription( + element=elt, + text="i am a line", + score=0.42, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + f"http://testserver/api/v1/element/{elt.id}/transcription/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "text": "i am a line", + "worker_version": "12341234-1234-1234-1234-123412341234", + "score": 0.42, + } + + +def test_create_transcription_with_cache(responses, mock_elements_worker_with_cache): elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( @@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker ] -def test_create_element_transcriptions(responses, mock_elements_worker_with_cache): +def test_create_element_transcriptions(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", + status=200, + json=[ + { + "id": "56785678-5678-5678-5678-567856785678", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + { + "id": "67896789-6789-6789-6789-678967896789", + "element_id": "22222222-2222-2222-2222-222222222222", + "created": False, + }, + { + "id": "78907890-7890-7890-7890-789078907890", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + ], + ) + + annotations = mock_elements_worker.create_element_transcriptions( + element=elt, + sub_element_type="page", + transcriptions=TRANSCRIPTIONS_SAMPLE, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "element_type": "page", + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": TRANSCRIPTIONS_SAMPLE, + "return_elements": True, + } + assert annotations == [ + { + "id": "56785678-5678-5678-5678-567856785678", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + { + "id": "67896789-6789-6789-6789-678967896789", + "element_id": "22222222-2222-2222-2222-222222222222", + "created": False, + }, + { + "id": "78907890-7890-7890-7890-789078907890", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + ] + + +def test_create_element_transcriptions_with_cache( + responses, mock_elements_worker_with_cache +): elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( @@ -1221,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive( # Filter on element should give all elements inserted ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), }, ( "11111111-1111-1111-1111-111111111111", @@ -1231,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive( # Filter on element and worker version should give first element ( { - "element": Element({"id": "12341234-1234-1234-1234-123412341234"}), + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": "56785678-5678-5678-5678-567856785678", }, ("11111111-1111-1111-1111-111111111111",),