From 02620644e53e118434dfd5c8bf00e68b978392e5 Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Wed, 24 Mar 2021 14:55:44 +0100 Subject: [PATCH] Test cache insertion in transcriptions functions --- arkindex_worker/worker.py | 3 +- .../test_transcriptions.py | 112 ++++++++++++++++-- 2 files changed, 103 insertions(+), 12 deletions(-) diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index e386f40f..d9de81bd 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -700,7 +700,6 @@ class ElementsWorker(BaseWorker): transcription = transcriptions[index] element_id_hex = convert_str_uuid_to_hex(annotation["id"]) if annotation["created"] and annotation["id"] not in created_ids: - # TODO: Retrieve real element_name through API elements_to_insert.append( CachedElement( id=element_id_hex, @@ -715,7 +714,7 @@ class ElementsWorker(BaseWorker): transcriptions_to_insert.append( CachedTranscription( # TODO: Retrieve real transcription_id through API - id=convert_str_uuid_to_hex(uuid.uuid4()), + id=convert_str_uuid_to_hex(str(uuid.uuid4())), element_id=element_id_hex, text=transcription["text"], confidence=transcription["score"], diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 3d603fcb..60162fe2 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- import json +import os +from pathlib import Path import pytest from apistar.exceptions import ErrorResponse +from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.models import Element +from arkindex_worker.utils import convert_str_uuid_to_hex +CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache" TRANSCRIPTIONS_SAMPLE = [ { "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], @@ -130,15 +135,22 @@ def test_create_transcription_api_error(responses, mock_elements_worker): ] -def test_create_transcription(responses, mock_elements_worker): +def test_create_transcription(responses, mock_elements_worker_with_cache): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, f"http://testserver/api/v1/element/{elt.id}/transcription/", status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "i am a line", + "score": 0.42, + "confidence": 0.42, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, ) - mock_elements_worker.create_transcription( + mock_elements_worker_with_cache.create_transcription( element=elt, text="i am a line", score=0.42, @@ -157,6 +169,25 @@ def test_create_transcription(responses, mock_elements_worker): "score": 0.42, } + # Check that created transcriptions and elements were properly stored in SQLite cache + cache_path = f"{CACHE_DIR}/db.sqlite" + assert os.path.isfile(cache_path) + + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM transcriptions" + ).fetchall() + assert [CachedTranscription(**dict(row)) for row in rows] == [ + CachedTranscription( + id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"), + element_id=convert_str_uuid_to_hex(elt.id), + text="i am a line", + confidence=0.42, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ) + ] + def test_create_element_transcriptions_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: @@ -551,20 +582,30 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker ] -def test_create_element_transcriptions(responses, mock_elements_worker): +def test_create_element_transcriptions( + mocker, responses, mock_elements_worker_with_cache +): + mocker.patch( + "uuid.uuid4", + side_effect=[ + "56785678-5678-5678-5678-567856785678", + "67896789-6789-6789-6789-678967896789", + "78907890-7890-7890-7890-789078907890", + ], + ) elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", status=200, json=[ - {"id": "word1_1_1", "created": False}, - {"id": "word1_1_2", "created": False}, - {"id": "word1_1_3", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, + {"id": "22222222-2222-2222-2222-222222222222", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, ], ) - annotations = mock_elements_worker.create_element_transcriptions( + annotations = mock_elements_worker_with_cache.create_element_transcriptions( element=elt, sub_element_type="page", transcriptions=TRANSCRIPTIONS_SAMPLE, @@ -584,9 +625,60 @@ def test_create_element_transcriptions(responses, mock_elements_worker): "return_elements": True, } assert annotations == [ - {"id": "word1_1_1", "created": False}, - {"id": "word1_1_2", "created": False}, - {"id": "word1_1_3", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, + {"id": "22222222-2222-2222-2222-222222222222", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, + ] + + # Check that created transcriptions and elements were properly stored in SQLite cache + cache_path = f"{CACHE_DIR}/db.sqlite" + assert os.path.isfile(cache_path) + + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM elements" + ).fetchall() + assert [CachedElement(**dict(row)) for row in rows] == [ + CachedElement( + id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + type="page", + polygon=json.dumps([[100, 150], [700, 150], [700, 200], [100, 200]]), + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ) + ] + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM transcriptions" + ).fetchall() + assert [CachedTranscription(**dict(row)) for row in rows] == [ + CachedTranscription( + id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="The", + confidence=0.5, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), + CachedTranscription( + id=convert_str_uuid_to_hex("67896789-6789-6789-6789-678967896789"), + element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"), + text="first", + confidence=0.75, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), + CachedTranscription( + id=convert_str_uuid_to_hex("78907890-7890-7890-7890-789078907890"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="line", + confidence=0.9, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), ] -- GitLab