-
Eva Bardou authored
test_transcriptions.py 27.20 KiB
# -*- coding: utf-8 -*-
import json
import os
from pathlib import Path
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
from arkindex_worker.utils import convert_str_uuid_to_hex
# Directory "data/cache" located under the parent of this file's directory;
# the cache tests below look for "db.sqlite" inside it.
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
# Three well-formed transcription payloads (polygon, score, text) passed as the
# `transcriptions` argument in the create_element_transcriptions tests.
TRANSCRIPTIONS_SAMPLE = [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
"score": 0.5,
"text": "The",
},
{
"polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
"score": 0.75,
"text": "first",
},
{
"polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
"score": 0.9,
"text": "line",
},
]
def test_create_transcription_wrong_element(mock_elements_worker):
    """create_transcription must reject a null or non-Element ``element``."""
    expected_message = "element shouldn't be null and should be of type Element"

    # Same two invalid values as before, checked one after the other.
    for invalid_element in (None, "not element type"):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_transcription(
                element=invalid_element,
                text="i am a line",
                score=0.42,
            )
        assert str(exc_info.value) == expected_message
def test_create_transcription_wrong_text(mock_elements_worker):
    """create_transcription must reject a null or non-str ``text``."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_message = "text shouldn't be null and should be of type str"

    for invalid_text in (None, 1234):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_transcription(
                element=elt,
                text=invalid_text,
                score=0.42,
            )
        assert str(exc_info.value) == expected_message
def test_create_transcription_wrong_score(mock_elements_worker):
    """create_transcription must reject null, non-float and out-of-range scores."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_message = (
        "score shouldn't be null and should be a float in [0..1] range"
    )

    # Null, wrong type, integer zero, and a float above the allowed range.
    for invalid_score in (None, "wrong score", 0, 2.00):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_transcription(
                element=elt,
                text="i am a line",
                score=invalid_score,
            )
        assert str(exc_info.value) == expected_message
def test_create_transcription_api_error(responses, mock_elements_worker):
    """A 500 answer from the transcription endpoint surfaces as ErrorResponse."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    transcription_url = f"http://testserver/api/v1/element/{elt.id}/transcription/"
    responses.add(responses.POST, transcription_url, status=500)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_transcription(
            element=elt,
            text="i am a line",
            score=0.42,
        )

    # The first two calls (user, worker version) are presumably issued while
    # the worker fixture is set up; the third one is the failing POST.
    assert len(responses.calls) == 3
    requested_urls = [call.request.url for call in responses.calls]
    assert requested_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        transcription_url,
    ]
def test_create_transcription(responses, mock_elements_worker_with_cache):
    """A created transcription is POSTed to the API and stored in the SQLite cache."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    transcription_url = f"http://testserver/api/v1/element/{elt.id}/transcription/"
    api_answer = {
        "id": "56785678-5678-5678-5678-567856785678",
        "text": "i am a line",
        "score": 0.42,
        "confidence": 0.42,
        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }
    responses.add(responses.POST, transcription_url, status=200, json=api_answer)

    mock_elements_worker_with_cache.create_transcription(
        element=elt,
        text="i am a line",
        score=0.42,
    )

    assert len(responses.calls) == 3
    requested_urls = [call.request.url for call in responses.calls]
    assert requested_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        transcription_url,
    ]

    # Payload sent by the worker for the transcription creation request.
    sent_payload = json.loads(responses.calls[2].request.body)
    assert sent_payload == {
        "text": "i am a line",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "score": 0.42,
    }

    # Check that created transcription was properly stored in SQLite cache
    assert os.path.isfile(f"{CACHE_DIR}/db.sqlite")

    cursor = mock_elements_worker_with_cache.cache.cursor
    cached_rows = cursor.execute("SELECT * FROM transcriptions").fetchall()
    stored_transcriptions = [CachedTranscription(**dict(row)) for row in cached_rows]

    expected_transcription = CachedTranscription(
        id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
        element_id=convert_str_uuid_to_hex(elt.id),
        text="i am a line",
        confidence=0.42,
        worker_version_id=convert_str_uuid_to_hex(
            "12341234-1234-1234-1234-123412341234"
        ),
    )
    assert stored_transcriptions == [expected_transcription]
def test_create_element_transcriptions_wrong_element(mock_elements_worker):
    """create_element_transcriptions must reject a null or non-Element ``element``."""
    expected_message = "element shouldn't be null and should be of type Element"

    for invalid_element in (None, "not element type"):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_element_transcriptions(
                element=invalid_element,
                sub_element_type="page",
                transcriptions=TRANSCRIPTIONS_SAMPLE,
            )
        assert str(exc_info.value) == expected_message
def test_create_element_transcriptions_wrong_sub_element_type(mock_elements_worker):
    """create_element_transcriptions must reject a null or non-str ``sub_element_type``."""
    elt = Element({"zone": None})
    expected_message = (
        "sub_element_type shouldn't be null and should be of type str"
    )

    for invalid_type in (None, 1234):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_element_transcriptions(
                element=elt,
                sub_element_type=invalid_type,
                transcriptions=TRANSCRIPTIONS_SAMPLE,
            )
        assert str(exc_info.value) == expected_message
def test_create_element_transcriptions_wrong_transcriptions(mock_elements_worker):
    """create_element_transcriptions must validate the ``transcriptions`` argument.

    Two families of invalid input are exercised, in this order:

    1. ``transcriptions`` itself is not a list (``None`` or an integer);
    2. ``transcriptions`` is a two-item list whose first entry is valid and
       whose second entry (index 1) has an invalid ``text``, ``score`` or
       ``polygon``. The expected message always names index 1.
    """
    elt = Element({"zone": None})

    # --- 1. The argument is not a list at all -------------------------------
    for not_a_list in (None, 1234):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_element_transcriptions(
                element=elt,
                sub_element_type="page",
                transcriptions=not_a_list,
            )
        assert (
            str(exc_info.value)
            == "transcriptions shouldn't be null and should be of type list"
        )

    # --- 2. The second entry of the list is malformed -----------------------
    text_error = "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str"
    score_error = "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
    polygon_type_error = "Transcription at index 1 in transcriptions: polygon shouldn't be null and should be of type list"
    polygon_length_error = "Transcription at index 1 in transcriptions: polygon should have at least three points"
    point_size_error = "Transcription at index 1 in transcriptions: polygon points should be lists of two items"
    point_type_error = "Transcription at index 1 in transcriptions: polygon points should be lists of two numbers"

    # (invalid second entry, expected assertion message)
    invalid_cases = [
        # text: missing, null, wrong type
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 0.5,
            },
            text_error,
        ),
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 0.5,
                "text": None,
            },
            text_error,
        ),
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 0.5,
                "text": 1234,
            },
            text_error,
        ),
        # score: missing, null, wrong type, integer zero, above range
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "text": "word",
            },
            score_error,
        ),
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": None,
                "text": "word",
            },
            score_error,
        ),
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": "a wrong score",
                "text": "word",
            },
            score_error,
        ),
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 0,
                "text": "word",
            },
            score_error,
        ),
        (
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 2.00,
                "text": "word",
            },
            score_error,
        ),
        # polygon: missing, null, wrong type
        ({"score": 0.5, "text": "word"}, polygon_type_error),
        ({"polygon": None, "score": 0.5, "text": "word"}, polygon_type_error),
        (
            {"polygon": "not a polygon", "score": 0.5, "text": "word"},
            polygon_type_error,
        ),
        # polygon: too few points
        (
            {"polygon": [[1, 1], [2, 2]], "score": 0.5, "text": "word"},
            polygon_length_error,
        ),
        # polygon: points with three items, then with a single item
        (
            {
                "polygon": [[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]],
                "score": 0.5,
                "text": "word",
            },
            point_size_error,
        ),
        (
            {"polygon": [[1], [2], [2], [1]], "score": 0.5, "text": "word"},
            point_size_error,
        ),
        # polygon: non-numeric coordinate
        (
            {
                "polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]],
                "score": 0.5,
                "text": "word",
            },
            point_type_error,
        ),
    ]

    for invalid_entry, expected_message in invalid_cases:
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.create_element_transcriptions(
                element=elt,
                sub_element_type="page",
                transcriptions=[
                    # A valid first entry, so the failure is reported at index 1.
                    {
                        "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
                        "score": 0.75,
                        "text": "The",
                    },
                    invalid_entry,
                ],
            )
        assert str(exc_info.value) == expected_message
def test_create_element_transcriptions_api_error(responses, mock_elements_worker):
    """A 500 answer from the bulk transcriptions endpoint surfaces as ErrorResponse."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    bulk_url = f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"
    responses.add(responses.POST, bulk_url, status=500)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_element_transcriptions(
            element=elt,
            sub_element_type="page",
            transcriptions=TRANSCRIPTIONS_SAMPLE,
        )

    # Two preliminary calls (user, worker version) plus the failing POST.
    assert len(responses.calls) == 3
    requested_urls = [call.request.url for call in responses.calls]
    assert requested_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        bulk_url,
    ]
def test_create_element_transcriptions(
    mocker, responses, mock_elements_worker_with_cache
):
    """Bulk creation returns the API annotations and fills the SQLite cache.

    ``uuid.uuid4`` is patched so the ids of the cached transcriptions are
    predictable. The mocked API answer flags the first and third annotations
    as created (same element id) and the second one as not created; the test
    expects a single cached element and three cached transcriptions.
    """
    mocker.patch(
        "uuid.uuid4",
        side_effect=[
            "56785678-5678-5678-5678-567856785678",
            "67896789-6789-6789-6789-678967896789",
            "78907890-7890-7890-7890-789078907890",
        ],
    )
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    bulk_url = f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"
    responses.add(
        responses.POST,
        bulk_url,
        status=200,
        json=[
            {"id": "11111111-1111-1111-1111-111111111111", "created": True},
            {"id": "22222222-2222-2222-2222-222222222222", "created": False},
            {"id": "11111111-1111-1111-1111-111111111111", "created": True},
        ],
    )

    annotations = mock_elements_worker_with_cache.create_element_transcriptions(
        element=elt,
        sub_element_type="page",
        transcriptions=TRANSCRIPTIONS_SAMPLE,
    )

    assert len(responses.calls) == 3
    requested_urls = [call.request.url for call in responses.calls]
    assert requested_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        bulk_url,
    ]

    sent_payload = json.loads(responses.calls[2].request.body)
    assert sent_payload == {
        "element_type": "page",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "transcriptions": TRANSCRIPTIONS_SAMPLE,
        "return_elements": True,
    }

    assert annotations == [
        {"id": "11111111-1111-1111-1111-111111111111", "created": True},
        {"id": "22222222-2222-2222-2222-222222222222", "created": False},
        {"id": "11111111-1111-1111-1111-111111111111", "created": True},
    ]

    # Check that created transcriptions and elements were properly stored in SQLite cache
    assert os.path.isfile(f"{CACHE_DIR}/db.sqlite")
    cursor = mock_elements_worker_with_cache.cache.cursor

    element_rows = cursor.execute("SELECT * FROM elements").fetchall()
    cached_elements = [CachedElement(**dict(row)) for row in element_rows]
    assert cached_elements == [
        CachedElement(
            id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
            parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
            type="page",
            polygon=json.dumps([[100, 150], [700, 150], [700, 200], [100, 200]]),
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        )
    ]

    transcription_rows = cursor.execute("SELECT * FROM transcriptions").fetchall()
    cached_transcriptions = [
        CachedTranscription(**dict(row)) for row in transcription_rows
    ]
    assert cached_transcriptions == [
        CachedTranscription(
            id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
            element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
            text="The",
            confidence=0.5,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        ),
        CachedTranscription(
            id=convert_str_uuid_to_hex("67896789-6789-6789-6789-678967896789"),
            element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
            text="first",
            confidence=0.75,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        ),
        CachedTranscription(
            id=convert_str_uuid_to_hex("78907890-7890-7890-7890-789078907890"),
            element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
            text="line",
            confidence=0.9,
            worker_version_id=convert_str_uuid_to_hex(
                "12341234-1234-1234-1234-123412341234"
            ),
        ),
    ]
def test_list_transcriptions_wrong_element(mock_elements_worker):
    """list_transcriptions must reject a null or non-Element ``element``."""
    expected_message = "element shouldn't be null and should be of type Element"

    for invalid_element in (None, "not element type"):
        with pytest.raises(AssertionError) as exc_info:
            mock_elements_worker.list_transcriptions(element=invalid_element)
        assert str(exc_info.value) == expected_message
def test_list_transcriptions_wrong_element_type(mock_elements_worker):
    """list_transcriptions must reject a non-str ``element_type``."""
    target = Element({"id": "12341234-1234-1234-1234-123412341234"})
    invalid_arguments = {"element": target, "element_type": 1234}

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.list_transcriptions(**invalid_arguments)

    raised_message = str(exc_info.value)
    assert raised_message == "element_type should be of type str"
def test_list_transcriptions_wrong_recursive(mock_elements_worker):
    """list_transcriptions must reject a non-bool ``recursive``."""
    target = Element({"id": "12341234-1234-1234-1234-123412341234"})
    invalid_arguments = {"element": target, "recursive": "not bool"}

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.list_transcriptions(**invalid_arguments)

    raised_message = str(exc_info.value)
    assert raised_message == "recursive should be of type bool"
def test_list_transcriptions_wrong_worker_version(mock_elements_worker):
    """list_transcriptions must reject a non-str ``worker_version``."""
    target = Element({"id": "12341234-1234-1234-1234-123412341234"})
    invalid_arguments = {"element": target, "worker_version": 1234}

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.list_transcriptions(**invalid_arguments)

    raised_message = str(exc_info.value)
    assert raised_message == "worker_version should be of type str"
def test_list_transcriptions_api_error(responses, mock_elements_worker):
    """A persistent 500 on the listing endpoint aborts pagination with an error."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    listing_url = "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/"
    responses.add(responses.GET, listing_url, status=500)

    # The listing is lazy: the request is only sent when the first item is pulled.
    transcriptions = mock_elements_worker.list_transcriptions(element=elt)
    with pytest.raises(
        Exception, match="Stopping pagination as data will be incomplete"
    ):
        next(transcriptions)

    assert len(responses.calls) == 7
    requested_urls = [call.request.url for call in responses.calls]
    expected_urls = [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
    ]
    # We do 5 retries
    expected_urls.extend([listing_url] * 5)
    assert requested_urls == expected_urls
def test_list_transcriptions(responses, mock_elements_worker):
    """list_transcriptions yields every transcription returned by the API, in order.

    The listing endpoint is mocked with a single page of three results. The
    whole output of the generator is compared with the expected list, so the
    test also fails when the worker yields fewer items than the API returned
    (a per-item loop assertion would pass vacuously on an empty iterator).
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    trans = [
        {
            "id": "0000",
            "text": "hey",
            "confidence": 0.42,
            "worker_version_id": "56785678-5678-5678-5678-567856785678",
            "element": None,
        },
        {
            "id": "1111",
            "text": "it's",
            "confidence": 0.42,
            "worker_version_id": "56785678-5678-5678-5678-567856785678",
            "element": None,
        },
        {
            "id": "2222",
            "text": "me",
            "confidence": 0.42,
            "worker_version_id": "56785678-5678-5678-5678-567856785678",
            "element": None,
        },
    ]
    responses.add(
        responses.GET,
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
        status=200,
        json={
            "count": 3,
            "next": None,
            "results": trans,
        },
    )

    # Exhaust the generator and compare count, order and content at once.
    assert list(mock_elements_worker.list_transcriptions(element=elt)) == trans

    # Two preliminary calls (user, worker version) plus one listing request.
    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
    ]