Skip to content
Snippets Groups Projects
Commit 7b771d54 authored by Eva Bardou
Browse files

Test cache insertion in transcriptions functions

parent 3e649e28
No related branches found
No related tags found
1 merge request: !69 Store Transcriptions in local cache
......@@ -700,7 +700,6 @@ class ElementsWorker(BaseWorker):
transcription = transcriptions[index]
element_id_hex = convert_str_uuid_to_hex(annotation["id"])
if annotation["created"] and annotation["id"] not in created_ids:
# TODO: Retrieve real element_name through API
elements_to_insert.append(
CachedElement(
id=element_id_hex,
......@@ -715,7 +714,7 @@ class ElementsWorker(BaseWorker):
transcriptions_to_insert.append(
CachedTranscription(
# TODO: Retrieve real transcription_id through API
id=convert_str_uuid_to_hex(uuid.uuid4()),
id=convert_str_uuid_to_hex(str(uuid.uuid4())),
element_id=element_id_hex,
text=transcription["text"],
confidence=transcription["score"],
......
# -*- coding: utf-8 -*-
import json
import os
from pathlib import Path
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
from arkindex_worker.utils import convert_str_uuid_to_hex
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
TRANSCRIPTIONS_SAMPLE = [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
......@@ -130,15 +135,22 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
]
def test_create_transcription(responses, mock_elements_worker):
def test_create_transcription(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcription/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"text": "i am a line",
"score": 0.42,
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
},
)
mock_elements_worker.create_transcription(
mock_elements_worker_with_cache.create_transcription(
element=elt,
text="i am a line",
score=0.42,
......@@ -157,6 +169,25 @@ def test_create_transcription(responses, mock_elements_worker):
"score": 0.42,
}
# Check that created transcriptions and elements were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM transcriptions"
).fetchall()
assert [CachedTranscription(**dict(row)) for row in rows] == [
CachedTranscription(
id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
element_id=convert_str_uuid_to_hex(elt.id),
text="i am a line",
confidence=0.42,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
)
]
def test_create_element_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
......@@ -551,20 +582,30 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
]
def test_create_element_transcriptions(responses, mock_elements_worker):
def test_create_element_transcriptions(
mocker, responses, mock_elements_worker_with_cache
):
mocker.patch(
"uuid.uuid4",
side_effect=[
"56785678-5678-5678-5678-567856785678",
"67896789-6789-6789-6789-678967896789",
"78907890-7890-7890-7890-789078907890",
],
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
status=200,
json=[
{"id": "word1_1_1", "created": False},
{"id": "word1_1_2", "created": False},
{"id": "word1_1_3", "created": False},
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
{"id": "22222222-2222-2222-2222-222222222222", "created": False},
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
],
)
annotations = mock_elements_worker.create_element_transcriptions(
annotations = mock_elements_worker_with_cache.create_element_transcriptions(
element=elt,
sub_element_type="page",
transcriptions=TRANSCRIPTIONS_SAMPLE,
......@@ -584,9 +625,60 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
"return_elements": True,
}
assert annotations == [
{"id": "word1_1_1", "created": False},
{"id": "word1_1_2", "created": False},
{"id": "word1_1_3", "created": False},
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
{"id": "22222222-2222-2222-2222-222222222222", "created": False},
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
]
# Check that created transcriptions and elements were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM elements"
).fetchall()
assert [CachedElement(**dict(row)) for row in rows] == [
CachedElement(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="page",
polygon=json.dumps([[100, 150], [700, 150], [700, 200], [100, 200]]),
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
)
]
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM transcriptions"
).fetchall()
assert [CachedTranscription(**dict(row)) for row in rows] == [
CachedTranscription(
id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
text="The",
confidence=0.5,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
),
CachedTranscription(
id=convert_str_uuid_to_hex("67896789-6789-6789-6789-678967896789"),
element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
text="first",
confidence=0.75,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
),
CachedTranscription(
id=convert_str_uuid_to_hex("78907890-7890-7890-7890-789078907890"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
text="line",
confidence=0.9,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
),
]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment