# -*- coding: utf-8 -*-
"""Tests for the transcription helpers exposed by the elements worker fixtures.

Covers ``create_transcription``, ``create_element_transcriptions`` and
``list_transcriptions``: argument validation, API error handling and, for the
cache-enabled worker fixture, the rows written to the SQLite cache.
"""
import json
import os
from pathlib import Path

import pytest
from apistar.exceptions import ErrorResponse

from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
from arkindex_worker.utils import convert_str_uuid_to_hex

CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
TRANSCRIPTIONS_SAMPLE = [
    {
        "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
        "score": 0.5,
        "text": "The",
    },
    {
        "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
        "score": 0.75,
        "text": "first",
    },
    {
        "polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
        "score": 0.9,
        "text": "line",
    },
]

# The first two requests of every call list asserted in this module.
BASE_API_CALLS = [
    "http://testserver/api/v1/user/",
    "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]

# Assertion messages expected from the worker's argument validation.
ELEMENT_ERROR = "element shouldn't be null and should be of type Element"
TEXT_ERROR = "text shouldn't be null and should be of type str"
SCORE_ERROR = "score shouldn't be null and should be a float in [0..1] range"
SUB_ELEMENT_TYPE_ERROR = "sub_element_type shouldn't be null and should be of type str"
TRANSCRIPTIONS_ERROR = "transcriptions shouldn't be null and should be of type list"
POLYGON_TYPE_ERROR = "polygon shouldn't be null and should be of type list"
POLYGON_LENGTH_ERROR = "polygon should have at least three points"
POLYGON_ITEMS_ERROR = "polygon points should be lists of two items"
POLYGON_NUMBERS_ERROR = "polygon points should be lists of two numbers"

# Valid first entry used in front of every invalid transcription below.
_VALID_TRANSCRIPTION = {
    "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
    "score": 0.75,
    "text": "The",
}
_LINE_POLYGON = [[100, 150], [700, 150], [700, 200], [100, 200]]


def _second_item_case(case_id, invalid_item, reason):
    """Build a parametrize case whose second transcription is invalid.

    :param case_id: Readable pytest id for the case.
    :param invalid_item: The faulty transcription placed at index 1.
    :param reason: Validation message expected after the index prefix.
    :returns: A ``pytest.param`` holding the transcriptions list and the full
        expected assertion message.
    """
    return pytest.param(
        # Copy the valid entry so that cases never share a mutable dict
        [dict(_VALID_TRANSCRIPTION), invalid_item],
        f"Transcription at index 1 in transcriptions: {reason}",
        id=case_id,
    )


WRONG_TRANSCRIPTIONS_CASES = [
    # The transcriptions argument itself is invalid
    pytest.param(None, TRANSCRIPTIONS_ERROR, id="transcriptions-null"),
    pytest.param(1234, TRANSCRIPTIONS_ERROR, id="transcriptions-not-a-list"),
    # Invalid text
    _second_item_case(
        "text-missing",
        {"polygon": _LINE_POLYGON, "score": 0.5},
        TEXT_ERROR,
    ),
    _second_item_case(
        "text-null",
        {"polygon": _LINE_POLYGON, "score": 0.5, "text": None},
        TEXT_ERROR,
    ),
    _second_item_case(
        "text-not-a-str",
        {"polygon": _LINE_POLYGON, "score": 0.5, "text": 1234},
        TEXT_ERROR,
    ),
    # Invalid score
    _second_item_case(
        "score-missing",
        {"polygon": _LINE_POLYGON, "text": "word"},
        SCORE_ERROR,
    ),
    _second_item_case(
        "score-null",
        {"polygon": _LINE_POLYGON, "score": None, "text": "word"},
        SCORE_ERROR,
    ),
    _second_item_case(
        "score-not-a-number",
        {"polygon": _LINE_POLYGON, "score": "a wrong score", "text": "word"},
        SCORE_ERROR,
    ),
    _second_item_case(
        "score-int",
        {"polygon": _LINE_POLYGON, "score": 0, "text": "word"},
        SCORE_ERROR,
    ),
    _second_item_case(
        "score-out-of-range",
        {"polygon": _LINE_POLYGON, "score": 2.00, "text": "word"},
        SCORE_ERROR,
    ),
    # Invalid polygon
    _second_item_case(
        "polygon-missing",
        {"score": 0.5, "text": "word"},
        POLYGON_TYPE_ERROR,
    ),
    _second_item_case(
        "polygon-null",
        {"polygon": None, "score": 0.5, "text": "word"},
        POLYGON_TYPE_ERROR,
    ),
    _second_item_case(
        "polygon-not-a-list",
        {"polygon": "not a polygon", "score": 0.5, "text": "word"},
        POLYGON_TYPE_ERROR,
    ),
    _second_item_case(
        "polygon-two-points",
        {"polygon": [[1, 1], [2, 2]], "score": 0.5, "text": "word"},
        POLYGON_LENGTH_ERROR,
    ),
    _second_item_case(
        "polygon-three-coords",
        {
            "polygon": [[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]],
            "score": 0.5,
            "text": "word",
        },
        POLYGON_ITEMS_ERROR,
    ),
    _second_item_case(
        "polygon-one-coord",
        {"polygon": [[1], [2], [2], [1]], "score": 0.5, "text": "word"},
        POLYGON_ITEMS_ERROR,
    ),
    _second_item_case(
        "polygon-non-numeric-coord",
        {
            "polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]],
            "score": 0.5,
            "text": "word",
        },
        POLYGON_NUMBERS_ERROR,
    ),
]


@pytest.mark.parametrize("invalid_element", [None, "not element type"])
def test_create_transcription_wrong_element(mock_elements_worker, invalid_element):
    """create_transcription rejects a null or non-Element ``element``."""
    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_transcription(
            element=invalid_element,
            text="i am a line",
            score=0.42,
        )
    assert str(e.value) == ELEMENT_ERROR


@pytest.mark.parametrize("invalid_text", [None, 1234])
def test_create_transcription_wrong_text(mock_elements_worker, invalid_text):
    """create_transcription rejects a null or non-str ``text``."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_transcription(
            element=elt,
            text=invalid_text,
            score=0.42,
        )
    assert str(e.value) == TEXT_ERROR


@pytest.mark.parametrize("invalid_score", [None, "wrong score", 0, 2.00])
def test_create_transcription_wrong_score(mock_elements_worker, invalid_score):
    """create_transcription rejects null, non-float and out-of-range scores."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_transcription(
            element=elt,
            text="i am a line",
            score=invalid_score,
        )
    assert str(e.value) == SCORE_ERROR


def test_create_transcription_api_error(responses, mock_elements_worker):
    """A 500 from the transcription endpoint surfaces as an ErrorResponse."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
        status=500,
    )

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_transcription(
            element=elt,
            text="i am a line",
            score=0.42,
        )

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
    ]


def test_create_transcription(responses, mock_elements_worker_with_cache):
    """A created transcription is posted to the API and stored in the cache."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
        status=200,
        json={
            "id": "56785678-5678-5678-5678-567856785678",
            "text": "i am a line",
            "score": 0.42,
            "confidence": 0.42,
            "worker_version_id": "12341234-1234-1234-1234-123412341234",
        },
    )

    mock_elements_worker_with_cache.create_transcription(
        element=elt,
        text="i am a line",
        score=0.42,
    )

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
    ]
    assert json.loads(responses.calls[2].request.body) == {
        "text": "i am a line",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "score": 0.42,
    }

    # Check that created transcription was properly stored in SQLite cache
    assert os.path.isfile(CACHE_DIR / "db.sqlite")

    rows = mock_elements_worker_with_cache.cache.cursor.execute(
        "SELECT * FROM transcriptions"
    ).fetchall()
    assert [CachedTranscription(**dict(row)) for row in rows] == [
        CachedTranscription(
            id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
            element_id=convert_str_uuid_to_hex(elt.id),
            text="i am a line",
            confidence=0.42,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        )
    ]


@pytest.mark.parametrize("invalid_element", [None, "not element type"])
def test_create_element_transcriptions_wrong_element(
    mock_elements_worker, invalid_element
):
    """create_element_transcriptions rejects a null or non-Element ``element``."""
    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_element_transcriptions(
            element=invalid_element,
            sub_element_type="page",
            transcriptions=TRANSCRIPTIONS_SAMPLE,
        )
    assert str(e.value) == ELEMENT_ERROR


@pytest.mark.parametrize("invalid_type", [None, 1234])
def test_create_element_transcriptions_wrong_sub_element_type(
    mock_elements_worker, invalid_type
):
    """create_element_transcriptions rejects a null or non-str sub element type."""
    elt = Element({"zone": None})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_element_transcriptions(
            element=elt,
            sub_element_type=invalid_type,
            transcriptions=TRANSCRIPTIONS_SAMPLE,
        )
    assert str(e.value) == SUB_ELEMENT_TYPE_ERROR


@pytest.mark.parametrize("transcriptions, expected_error", WRONG_TRANSCRIPTIONS_CASES)
def test_create_element_transcriptions_wrong_transcriptions(
    mock_elements_worker, transcriptions, expected_error
):
    """create_element_transcriptions rejects malformed ``transcriptions``.

    Each case is either an invalid container (null, not a list) or a list
    whose second entry has an invalid text, score or polygon; the expected
    message then names index 1.
    """
    elt = Element({"zone": None})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_element_transcriptions(
            element=elt,
            sub_element_type="page",
            transcriptions=transcriptions,
        )
    assert str(e.value) == expected_error


def test_create_element_transcriptions_api_error(responses, mock_elements_worker):
    """A 500 from the bulk transcriptions endpoint surfaces as an ErrorResponse."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
        status=500,
    )

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_element_transcriptions(
            element=elt,
            sub_element_type="page",
            transcriptions=TRANSCRIPTIONS_SAMPLE,
        )

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
    ]


def test_create_element_transcriptions(
    mocker, responses, mock_elements_worker_with_cache
):
    """Bulk creation returns the API annotations and fills the SQLite cache."""
    # uuid.uuid4 is patched so the ids of the cached transcriptions are known
    mocker.patch(
        "uuid.uuid4",
        side_effect=[
            "56785678-5678-5678-5678-567856785678",
            "67896789-6789-6789-6789-678967896789",
            "78907890-7890-7890-7890-789078907890",
        ],
    )
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    api_annotations = [
        {"id": "11111111-1111-1111-1111-111111111111", "created": True},
        {"id": "22222222-2222-2222-2222-222222222222", "created": False},
        {"id": "11111111-1111-1111-1111-111111111111", "created": True},
    ]
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
        status=200,
        json=api_annotations,
    )

    annotations = mock_elements_worker_with_cache.create_element_transcriptions(
        element=elt,
        sub_element_type="page",
        transcriptions=TRANSCRIPTIONS_SAMPLE,
    )

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
    ]
    assert json.loads(responses.calls[2].request.body) == {
        "element_type": "page",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "transcriptions": TRANSCRIPTIONS_SAMPLE,
        "return_elements": True,
    }
    assert annotations == [
        {"id": "11111111-1111-1111-1111-111111111111", "created": True},
        {"id": "22222222-2222-2222-2222-222222222222", "created": False},
        {"id": "11111111-1111-1111-1111-111111111111", "created": True},
    ]

    # Check that created transcriptions and elements were properly stored in SQLite cache
    assert os.path.isfile(CACHE_DIR / "db.sqlite")

    rows = mock_elements_worker_with_cache.cache.cursor.execute(
        "SELECT * FROM elements"
    ).fetchall()
    assert [CachedElement(**dict(row)) for row in rows] == [
        CachedElement(
            id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
            parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
            type="page",
            polygon=json.dumps([[100, 150], [700, 150], [700, 200], [100, 200]]),
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        )
    ]

    rows = mock_elements_worker_with_cache.cache.cursor.execute(
        "SELECT * FROM transcriptions"
    ).fetchall()
    assert [CachedTranscription(**dict(row)) for row in rows] == [
        CachedTranscription(
            id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
            element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
            text="The",
            confidence=0.5,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        ),
        CachedTranscription(
            id=convert_str_uuid_to_hex("67896789-6789-6789-6789-678967896789"),
            element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
            text="first",
            confidence=0.75,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        ),
        CachedTranscription(
            id=convert_str_uuid_to_hex("78907890-7890-7890-7890-789078907890"),
            element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
            text="line",
            confidence=0.9,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        ),
    ]


@pytest.mark.parametrize("invalid_element", [None, "not element type"])
def test_list_transcriptions_wrong_element(mock_elements_worker, invalid_element):
    """list_transcriptions rejects a null or non-Element ``element``."""
    with pytest.raises(AssertionError) as e:
        mock_elements_worker.list_transcriptions(element=invalid_element)
    assert str(e.value) == ELEMENT_ERROR


def test_list_transcriptions_wrong_element_type(mock_elements_worker):
    """list_transcriptions rejects a non-str ``element_type``."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.list_transcriptions(
            element=elt,
            element_type=1234,
        )
    assert str(e.value) == "element_type should be of type str"


def test_list_transcriptions_wrong_recursive(mock_elements_worker):
    """list_transcriptions rejects a non-bool ``recursive``."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.list_transcriptions(
            element=elt,
            recursive="not bool",
        )
    assert str(e.value) == "recursive should be of type bool"


def test_list_transcriptions_wrong_worker_version(mock_elements_worker):
    """list_transcriptions rejects a non-str ``worker_version``."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.list_transcriptions(
            element=elt,
            worker_version=1234,
        )
    assert str(e.value) == "worker_version should be of type str"


def test_list_transcriptions_api_error(responses, mock_elements_worker):
    """A failing listing endpoint stops pagination after the retried calls."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    listing_url = "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/"
    responses.add(
        responses.GET,
        listing_url,
        status=500,
    )

    with pytest.raises(
        Exception, match="Stopping pagination as data will be incomplete"
    ):
        next(mock_elements_worker.list_transcriptions(element=elt))

    assert len(responses.calls) == 7
    # We do 5 retries
    assert [call.request.url for call in responses.calls] == (
        BASE_API_CALLS + [listing_url] * 5
    )


def test_list_transcriptions(responses, mock_elements_worker):
    """list_transcriptions yields exactly the results returned by the API."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    trans = [
        {
            "id": "0000",
            "text": "hey",
            "confidence": 0.42,
            "worker_version_id": "56785678-5678-5678-5678-567856785678",
            "element": None,
        },
        {
            "id": "1111",
            "text": "it's",
            "confidence": 0.42,
            "worker_version_id": "56785678-5678-5678-5678-567856785678",
            "element": None,
        },
        {
            "id": "2222",
            "text": "me",
            "confidence": 0.42,
            "worker_version_id": "56785678-5678-5678-5678-567856785678",
            "element": None,
        },
    ]
    responses.add(
        responses.GET,
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
        status=200,
        json={
            "count": 3,
            "next": None,
            "results": trans,
        },
    )

    # Compare the whole sequence: asserting inside a loop over the generator
    # would pass without checking anything if it yielded too few items.
    assert list(mock_elements_worker.list_transcriptions(element=elt)) == trans

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
    ]