Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (4)
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,requests,setuptools,sh,tenacity,yaml
0.2.3-rc1
0.2.3-rc2
......@@ -124,6 +124,7 @@ class CachedTranscription(Model):
element = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField()
confidence = FloatField()
orientation = CharField(max_length=50)
worker_version_id = UUIDField()
class Meta:
......@@ -165,6 +166,7 @@ class CachedTranscriptionEntity(Model):
offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField()
confidence = FloatField(null=True)
class Meta:
primary_key = CompositeKey("transcription", "entity")
......
......@@ -82,7 +82,9 @@ class EntityMixin(object):
return entity["id"]
def create_transcription_entity(self, transcription, entity, offset, length):
def create_transcription_entity(
self, transcription, entity, offset, length, confidence=None
):
"""
Create a link between an existing entity and an existing transcription through API
"""
......@@ -98,21 +100,28 @@ class EntityMixin(object):
assert (
length is not None and isinstance(length, int) and length > 0
), "length shouldn't be null and should be a strictly positive integer"
assert (
confidence is None or isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence should be null or a float in [0..1] range"
if self.is_read_only:
logger.warning(
"Cannot create transcription entity as this worker is in read-only mode"
)
return
body = {
"entity": entity,
"length": length,
"offset": offset,
"worker_version_id": self.worker_version_id,
}
if confidence is not None:
body["confidence"] = confidence
transcription_ent = self.request(
"CreateTranscriptionEntity",
id=transcription,
body={
"entity": entity,
"length": length,
"offset": offset,
"worker_version_id": self.worker_version_id,
},
body=body,
)
# TODO: Report transcription entity creation
......@@ -125,6 +134,7 @@ class EntityMixin(object):
offset=offset,
length=length,
worker_version_id=self.worker_version_id,
confidence=confidence,
)
except IntegrityError as e:
logger.warning(
......
# -*- coding: utf-8 -*-
from enum import Enum
from peewee import IntegrityError
from arkindex_worker import logger
......@@ -7,8 +9,17 @@ from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
class TextOrientation(Enum):
    """
    Orientation that can be attached to a transcription.

    The value of each member is the string placed in the ``orientation`` field
    of the transcription payloads built from it.
    """

    # Horizontal orientations
    HorizontalLeftToRight = "horizontal-lr"
    HorizontalRightToLeft = "horizontal-rl"
    # Vertical orientations
    VerticalRightToLeft = "vertical-rl"
    VerticalLeftToRight = "vertical-lr"
class TranscriptionMixin(object):
def create_transcription(self, element, text, score):
def create_transcription(
self, element, text, score, orientation=TextOrientation.HorizontalLeftToRight
):
"""
Create a transcription on the given element through the API.
"""
......@@ -18,7 +29,9 @@ class TranscriptionMixin(object):
assert text and isinstance(
text, str
), "text shouldn't be null and should be of type str"
assert orientation and isinstance(
orientation, TextOrientation
), "orientation shouldn't be null and should be of type TextOrientation"
assert (
isinstance(score, float) and 0 <= score <= 1
), "score shouldn't be null and should be a float in [0..1] range"
......@@ -36,6 +49,7 @@ class TranscriptionMixin(object):
"text": text,
"worker_version": self.worker_version_id,
"score": score,
"orientation": orientation.value,
},
)
......@@ -50,6 +64,7 @@ class TranscriptionMixin(object):
"element_id": element.id,
"text": created["text"],
"confidence": created["confidence"],
"orientation": created["orientation"],
"worker_version_id": self.worker_version_id,
}
]
......@@ -70,7 +85,10 @@ class TranscriptionMixin(object):
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
# Create shallow copies of every transcription to avoid mutating the original payload
transcriptions_payload = list(map(dict, transcriptions))
for (index, transcription) in enumerate(transcriptions_payload):
element_id = transcription.get("element_id")
assert element_id and isinstance(
element_id, str
......@@ -86,11 +104,20 @@ class TranscriptionMixin(object):
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
orientation = transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
)
assert orientation and isinstance(
orientation, TextOrientation
), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
if orientation:
transcription["orientation"] = orientation.value
created_trs = self.request(
"CreateTranscriptions",
body={
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
"transcriptions": transcriptions_payload,
},
)["transcriptions"]
......@@ -106,6 +133,7 @@ class TranscriptionMixin(object):
"element_id": created_tr["element_id"],
"text": created_tr["text"],
"confidence": created_tr["confidence"],
"orientation": created_tr["orientation"],
"worker_version_id": self.worker_version_id,
}
for created_tr in created_trs
......@@ -132,7 +160,10 @@ class TranscriptionMixin(object):
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
# Create shallow copies of every transcription to avoid mutating the original payload
transcriptions_payload = list(map(dict, transcriptions))
for (index, transcription) in enumerate(transcriptions_payload):
text = transcription.get("text")
assert text and isinstance(
text, str
......@@ -143,6 +174,15 @@ class TranscriptionMixin(object):
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
orientation = transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
)
assert orientation and isinstance(
orientation, TextOrientation
), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
if orientation:
transcription["orientation"] = orientation.value
polygon = transcription.get("polygon")
assert polygon and isinstance(
polygon, list
......@@ -168,7 +208,7 @@ class TranscriptionMixin(object):
body={
"element_type": sub_element_type,
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
"transcriptions": transcriptions_payload,
"return_elements": True,
},
)
......@@ -216,6 +256,9 @@ class TranscriptionMixin(object):
"element_id": annotation["element_id"],
"text": transcription["text"],
"confidence": transcription["score"],
"orientation": transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
).value,
"worker_version_id": self.worker_version_id,
}
)
......
......@@ -15,6 +15,7 @@ from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
......@@ -381,6 +382,7 @@ def mock_cached_transcriptions():
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
......@@ -388,6 +390,7 @@ def mock_cached_transcriptions():
element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="is",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
......@@ -395,6 +398,7 @@ def mock_cached_transcriptions():
element_id=UUID("33333333-3333-3333-3333-333333333333"),
text="a",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
......@@ -402,6 +406,7 @@ def mock_cached_transcriptions():
element_id=UUID("44444444-4444-4444-4444-444444444444"),
text="good",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
......@@ -409,6 +414,7 @@ def mock_cached_transcriptions():
element_id=UUID("55555555-5555-5555-5555-555555555555"),
text="test",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
......@@ -461,6 +467,7 @@ def mock_databases(tmpdir):
element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello!",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
......@@ -477,6 +484,7 @@ def mock_databases(tmpdir):
element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello again neighbor !",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
......
......@@ -58,8 +58,8 @@ def test_create_tables(tmp_path):
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
[
......
......@@ -13,6 +13,7 @@ from arkindex_worker.cache import (
)
from arkindex_worker.models import Element
from arkindex_worker.worker import EntityType
from arkindex_worker.worker.transcription import TextOrientation
from . import BASE_API_CALLS
......@@ -417,7 +418,7 @@ def test_create_transcription_entity_api_error(responses, mock_elements_worker):
]
def test_create_transcription_entity(responses, mock_elements_worker):
def test_create_transcription_entity_no_confidence(responses, mock_elements_worker):
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
......@@ -453,6 +454,83 @@ def test_create_transcription_entity(responses, mock_elements_worker):
}
def test_create_transcription_entity_with_confidence(responses, mock_elements_worker):
    """
    A float confidence passed to create_transcription_entity is sent in the request body
    """
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"
    link = {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "confidence": 0.33,
    }
    responses.add(responses.POST, endpoint, status=200, json=link)

    mock_elements_worker.create_transcription_entity(
        transcription="11111111-1111-1111-1111-111111111111", **link
    )

    # A single request was added to the base calls, and it hit the mocked endpoint
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    requests_made = [(c.request.method, c.request.url) for c in responses.calls]
    assert requests_made == BASE_API_CALLS + [("POST", endpoint)]

    # The payload carries the link attributes plus the worker version
    sent = json.loads(responses.calls[-1].request.body)
    assert sent == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "worker_version_id": "12341234-1234-1234-1234-123412341234",
        "confidence": 0.33,
    }
def test_create_transcription_entity_confidence_none(responses, mock_elements_worker):
    """
    An explicit confidence=None is not included in the request body
    """
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"
    entity_id = "11111111-1111-1111-1111-111111111111"
    responses.add(
        responses.POST,
        endpoint,
        status=200,
        json={"entity": entity_id, "offset": 5, "length": 10, "confidence": None},
    )

    mock_elements_worker.create_transcription_entity(
        transcription="11111111-1111-1111-1111-111111111111",
        entity=entity_id,
        offset=5,
        length=10,
        confidence=None,
    )

    # A single request was added to the base calls, and it hit the mocked endpoint
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    requests_made = [(c.request.method, c.request.url) for c in responses.calls]
    assert requests_made == BASE_API_CALLS + [("POST", endpoint)]

    # No "confidence" key is expected in the payload
    sent = json.loads(responses.calls[-1].request.body)
    assert sent == {
        "entity": entity_id,
        "offset": 5,
        "length": 10,
        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }
def test_create_transcription_entity_with_cache(
responses, mock_elements_worker_with_cache
):
......@@ -465,6 +543,75 @@ def test_create_transcription_entity_with_cache(
element=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello, it's me.",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
CachedEntity.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
type="person",
name="Bob Bob",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
status=200,
json={
"entity": "11111111-1111-1111-1111-111111111111",
"offset": 5,
"length": 10,
},
)
mock_elements_worker_with_cache.create_transcription_entity(
transcription="11111111-1111-1111-1111-111111111111",
entity="11111111-1111-1111-1111-111111111111",
offset=5,
length=10,
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
]
assert json.loads(responses.calls[-1].request.body) == {
"entity": "11111111-1111-1111-1111-111111111111",
"offset": 5,
"length": 10,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
}
# Check that created transcription entity was properly stored in SQLite cache
assert list(CachedTranscriptionEntity.select()) == [
CachedTranscriptionEntity(
transcription=UUID("11111111-1111-1111-1111-111111111111"),
entity=UUID("11111111-1111-1111-1111-111111111111"),
offset=5,
length=10,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_transcription_entity_with_confidence_with_cache(
responses, mock_elements_worker_with_cache
):
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello, it's me.",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
CachedEntity.create(
......@@ -482,6 +629,7 @@ def test_create_transcription_entity_with_cache(
"entity": "11111111-1111-1111-1111-111111111111",
"offset": 5,
"length": 10,
"confidence": 0.77,
},
)
......@@ -490,6 +638,7 @@ def test_create_transcription_entity_with_cache(
entity="11111111-1111-1111-1111-111111111111",
offset=5,
length=10,
confidence=0.77,
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
......@@ -506,6 +655,7 @@ def test_create_transcription_entity_with_cache(
"offset": 5,
"length": 10,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"confidence": 0.77,
}
# Check that created transcription entity was properly stored in SQLite cache
......@@ -516,5 +666,6 @@ def test_create_transcription_entity_with_cache(
offset=5,
length=10,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
confidence=0.77,
)
]
......@@ -4,9 +4,11 @@ from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from playhouse.shortcuts import model_to_dict
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
from arkindex_worker.worker.transcription import TextOrientation
from . import BASE_API_CALLS
......@@ -117,6 +119,74 @@ def test_create_transcription_wrong_score(mock_elements_worker):
)
def test_create_transcription_default_orientation(responses, mock_elements_worker):
    """
    create_transcription sends "horizontal-lr" when no orientation is given
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    api_reply = {
        "id": "56785678-5678-5678-5678-567856785678",
        "text": "Animula vagula blandula",
        "confidence": 0.42,
        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
        status=200,
        json=api_reply,
    )

    # No orientation argument: the default value must end up in the payload
    mock_elements_worker.create_transcription(
        element=elt, text="Animula vagula blandula", score=0.42
    )

    sent = json.loads(responses.calls[-1].request.body)
    expected = {
        "text": "Animula vagula blandula",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "score": 0.42,
        "orientation": "horizontal-lr",
    }
    assert sent == expected
def test_create_transcription_orientation(responses, mock_elements_worker):
    """
    create_transcription sends the value of the TextOrientation member it was given
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    api_reply = {
        "id": "56785678-5678-5678-5678-567856785678",
        "text": "Animula vagula blandula",
        "confidence": 0.42,
        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
        status=200,
        json=api_reply,
    )

    mock_elements_worker.create_transcription(
        element=elt,
        text="Animula vagula blandula",
        orientation=TextOrientation.VerticalLeftToRight,
        score=0.42,
    )

    # The enum member is expected to be serialized as its string value
    sent = json.loads(responses.calls[-1].request.body)
    expected = {
        "text": "Animula vagula blandula",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "score": 0.42,
        "orientation": "vertical-lr",
    }
    assert sent == expected
def test_create_transcription_wrong_orientation(mock_elements_worker):
    """
    create_transcription rejects an orientation that is not a TextOrientation member
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    # A plain string is given where a TextOrientation member is required
    invalid_call = {
        "element": elt,
        "text": "Animula vagula blandula",
        "score": 0.26,
        "orientation": "eliptical",
    }

    with pytest.raises(AssertionError) as e:
        mock_elements_worker.create_transcription(**invalid_call)

    message = str(e.value)
    assert (
        message == "orientation shouldn't be null and should be of type TextOrientation"
    )
def test_create_transcription_api_error(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
......@@ -177,6 +247,7 @@ def test_create_transcription(responses, mock_elements_worker):
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"score": 0.42,
"orientation": "horizontal-lr",
}
......@@ -192,6 +263,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
"text": "i am a line",
"score": 0.42,
"confidence": 0.42,
"orientation": "horizontal-lr",
"worker_version_id": "12341234-1234-1234-1234-123412341234",
},
)
......@@ -212,6 +284,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
assert json.loads(responses.calls[-1].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"orientation": "horizontal-lr",
"score": 0.42,
}
......@@ -222,11 +295,63 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
element_id=UUID(elt.id),
text="i am a line",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_transcription_orientation_with_cache(
    responses, mock_elements_worker_with_cache
):
    """
    The orientation returned by the API for a created transcription is stored in the SQLite cache
    """
    elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
    api_reply = {
        "id": "56785678-5678-5678-5678-567856785678",
        "text": "Animula vagula blandula",
        "confidence": 0.42,
        "orientation": "vertical-lr",
        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcription/",
        status=200,
        json=api_reply,
    )

    mock_elements_worker_with_cache.create_transcription(
        element=elt,
        text="Animula vagula blandula",
        orientation=TextOrientation.VerticalLeftToRight,
        score=0.42,
    )

    sent = json.loads(responses.calls[-1].request.body)
    assert sent == {
        "text": "Animula vagula blandula",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "orientation": "vertical-lr",
        "score": 0.42,
    }

    # Check that the text orientation was properly stored in SQLite cache
    cached_rows = [model_to_dict(row) for row in CachedTranscription.select()]
    expected_element = {
        "id": UUID("12341234-1234-1234-1234-123412341234"),
        "parent_id": None,
        "type": "thing",
        "image": None,
        "polygon": None,
        "rotation_angle": 0,
        "mirrored": False,
        "initial": False,
        "worker_version_id": None,
    }
    assert cached_rows == [
        {
            "id": UUID("56785678-5678-5678-5678-567856785678"),
            "element": expected_element,
            "text": "Animula vagula blandula",
            "confidence": 0.42,
            "orientation": TextOrientation.VerticalLeftToRight.value,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        }
    ]
def test_create_transcriptions_wrong_transcriptions(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
......@@ -457,6 +582,27 @@ def test_create_transcriptions_wrong_transcriptions(mock_elements_worker):
== "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": 0.28,
"orientation": "wobble",
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
)
def test_create_transcriptions_api_error(responses, mock_elements_worker):
responses.add(
......@@ -519,12 +665,14 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
"id": "00000000-0000-0000-0000-000000000000",
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"orientation": "horizontal-lr",
"confidence": 0.75,
},
{
"id": "11111111-1111-1111-1111-111111111111",
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"orientation": "horizontal-lr",
"confidence": 0.42,
},
],
......@@ -544,7 +692,20 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
assert json.loads(responses.calls[-1].request.body) == {
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": trans,
"transcriptions": [
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": 0.42,
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
],
}
# Check that created transcriptions were properly stored in SQLite cache
......@@ -554,6 +715,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="The",
confidence=0.75,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedTranscription(
......@@ -561,11 +723,117 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="word",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
def test_create_transcriptions_orientation(responses, mock_elements_worker_with_cache):
    """
    create_transcriptions sends each orientation as its string value and stores
    the orientations returned by the API in the SQLite cache
    """
    CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="thing")

    def cached_thing():
        # Expected serialization of the parent element created above
        return {
            "id": UUID("11111111-1111-1111-1111-111111111111"),
            "parent_id": None,
            "type": "thing",
            "image": None,
            "polygon": None,
            "rotation_angle": 0,
            "mirrored": False,
            "initial": False,
            "worker_version_id": None,
        }

    trans = [
        {
            "element_id": "11111111-1111-1111-1111-111111111111",
            "text": "Animula vagula blandula",
            "score": 0.12,
            "orientation": TextOrientation.HorizontalRightToLeft,
        },
        {
            "element_id": "11111111-1111-1111-1111-111111111111",
            "text": "Hospes comesque corporis",
            "score": 0.21,
            "orientation": TextOrientation.VerticalLeftToRight,
        },
    ]
    api_reply = {
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "transcriptions": [
            {
                "id": "00000000-0000-0000-0000-000000000000",
                "element_id": "11111111-1111-1111-1111-111111111111",
                "text": "Animula vagula blandula",
                "orientation": "horizontal-rl",
                "confidence": 0.12,
            },
            {
                "id": "11111111-1111-1111-1111-111111111111",
                "element_id": "11111111-1111-1111-1111-111111111111",
                "text": "Hospes comesque corporis",
                "orientation": "vertical-lr",
                "confidence": 0.21,
            },
        ],
    }
    responses.add(
        responses.POST,
        "http://testserver/api/v1/transcription/bulk/",
        status=200,
        json=api_reply,
    )

    mock_elements_worker_with_cache.create_transcriptions(transcriptions=trans)

    # Enum members are expected to be replaced by their string values in the payload
    sent = json.loads(responses.calls[-1].request.body)
    assert sent == {
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "transcriptions": [
            {
                "element_id": "11111111-1111-1111-1111-111111111111",
                "text": "Animula vagula blandula",
                "score": 0.12,
                "orientation": TextOrientation.HorizontalRightToLeft.value,
            },
            {
                "element_id": "11111111-1111-1111-1111-111111111111",
                "text": "Hospes comesque corporis",
                "score": 0.21,
                "orientation": TextOrientation.VerticalLeftToRight.value,
            },
        ],
    }

    # Check that oriented transcriptions were properly stored in SQLite cache
    cached_rows = [model_to_dict(row) for row in CachedTranscription.select()]
    assert cached_rows == [
        {
            "id": UUID("00000000-0000-0000-0000-000000000000"),
            "element": cached_thing(),
            "text": "Animula vagula blandula",
            "confidence": 0.12,
            "orientation": TextOrientation.HorizontalRightToLeft.value,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        },
        {
            "id": UUID("11111111-1111-1111-1111-111111111111"),
            "element": cached_thing(),
            "text": "Hospes comesque corporis",
            "confidence": 0.21,
            "orientation": TextOrientation.VerticalLeftToRight.value,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        },
    ]
def test_create_element_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_element_transcriptions(
......@@ -941,6 +1209,29 @@ def test_create_element_transcriptions_wrong_transcriptions(mock_elements_worker
== "Transcription at index 1 in transcriptions: polygon points should be lists of two numbers"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_element_transcriptions(
element=elt,
sub_element_type="page",
transcriptions=[
{
"polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
"score": 0.75,
"text": "The",
},
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
"score": 0.35,
"text": "word",
"orientation": "uptown",
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
)
def test_create_element_transcriptions_api_error(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
......@@ -1011,7 +1302,26 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
"transcriptions": [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
"score": 0.5,
"text": "The",
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
{
"polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
"score": 0.75,
"text": "first",
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
{
"polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
"score": 0.9,
"text": "line",
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
],
"return_elements": True,
}
assert annotations == [
......@@ -1077,7 +1387,26 @@ def test_create_element_transcriptions_with_cache(
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
"transcriptions": [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
"score": 0.5,
"text": "The",
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
{
"polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
"score": 0.75,
"text": "first",
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
{
"polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
"score": 0.9,
"text": "line",
"orientation": TextOrientation.HorizontalLeftToRight.value,
},
],
"return_elements": True,
}
assert annotations == [
......@@ -1121,6 +1450,7 @@ def test_create_element_transcriptions_with_cache(
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="The",
confidence=0.5,
orientation=TextOrientation.HorizontalLeftToRight.value,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedTranscription(
......@@ -1128,6 +1458,7 @@ def test_create_element_transcriptions_with_cache(
element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="first",
confidence=0.75,
orientation=TextOrientation.HorizontalLeftToRight.value,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedTranscription(
......@@ -1135,11 +1466,168 @@ def test_create_element_transcriptions_with_cache(
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="line",
confidence=0.9,
orientation=TextOrientation.HorizontalLeftToRight.value,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
def test_create_transcriptions_orientation_with_cache(
    responses, mock_elements_worker_with_cache
):
    """
    Orientations given to create_element_transcriptions (or the default one when
    omitted) are sent as string values and stored in the SQLite cache.

    Note: despite its name, this test calls create_element_transcriptions.
    """
    elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")

    def cached_page(element_id, polygon):
        # Expected serialization of a sub-element cached under the parent element
        return {
            "id": UUID(element_id),
            "parent_id": UUID(elt.id),
            "type": "page",
            "image": None,
            "polygon": polygon,
            "rotation_angle": 0,
            "mirrored": False,
            "initial": False,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        }

    # Mocked API answer, also the value the helper is expected to return
    api_reply = [
        {
            "id": "56785678-5678-5678-5678-567856785678",
            "element_id": "11111111-1111-1111-1111-111111111111",
            "created": True,
        },
        {
            "id": "67896789-6789-6789-6789-678967896789",
            "element_id": "22222222-2222-2222-2222-222222222222",
            "created": False,
        },
        {
            "id": "78907890-7890-7890-7890-789078907890",
            "element_id": "11111111-1111-1111-1111-111111111111",
            "created": True,
        },
    ]
    responses.add(
        responses.POST,
        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
        status=200,
        json=api_reply,
    )

    # The first transcription has no orientation on purpose
    oriented_transcriptions = [
        {
            "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
            "score": 0.5,
            "text": "Animula vagula blandula",
        },
        {
            "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
            "score": 0.75,
            "text": "Hospes comesque corporis",
            "orientation": TextOrientation.VerticalLeftToRight,
        },
        {
            "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
            "score": 0.9,
            "text": "Quae nunc abibis in loca",
            "orientation": TextOrientation.HorizontalRightToLeft,
        },
    ]

    annotations = mock_elements_worker_with_cache.create_element_transcriptions(
        element=elt,
        sub_element_type="page",
        transcriptions=oriented_transcriptions,
    )

    sent = json.loads(responses.calls[-1].request.body)
    assert sent == {
        "element_type": "page",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "transcriptions": [
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 0.5,
                "text": "Animula vagula blandula",
                "orientation": TextOrientation.HorizontalLeftToRight.value,
            },
            {
                "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
                "score": 0.75,
                "text": "Hospes comesque corporis",
                "orientation": TextOrientation.VerticalLeftToRight.value,
            },
            {
                "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
                "score": 0.9,
                "text": "Quae nunc abibis in loca",
                "orientation": TextOrientation.HorizontalRightToLeft.value,
            },
        ],
        "return_elements": True,
    }
    assert annotations == api_reply

    # Check that the text orientation was properly stored in SQLite cache
    cached_rows = [model_to_dict(row) for row in CachedTranscription.select()]
    assert cached_rows == [
        {
            "id": UUID("56785678-5678-5678-5678-567856785678"),
            "element": cached_page(
                "11111111-1111-1111-1111-111111111111",
                [[100, 150], [700, 150], [700, 200], [100, 200]],
            ),
            "text": "Animula vagula blandula",
            "confidence": 0.5,
            "orientation": TextOrientation.HorizontalLeftToRight.value,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        },
        {
            "id": UUID("67896789-6789-6789-6789-678967896789"),
            "element": cached_page(
                "22222222-2222-2222-2222-222222222222",
                [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
            ),
            "text": "Hospes comesque corporis",
            "confidence": 0.75,
            "orientation": TextOrientation.VerticalLeftToRight.value,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        },
        {
            "id": UUID("78907890-7890-7890-7890-789078907890"),
            "element": cached_page(
                "11111111-1111-1111-1111-111111111111",
                [[100, 150], [700, 150], [700, 200], [100, 200]],
            ),
            "text": "Quae nunc abibis in loca",
            "confidence": 0.9,
            "orientation": TextOrientation.HorizontalRightToLeft.value,
            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
        },
    ]
def test_list_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element=None)
......