Skip to content
Snippets Groups Projects
Commit 9244ade3 authored by ml bonhomme :bee:, committed by Erwan Rouchet
Browse files

Text orientation in base worker

parent db93b4ab
No related branches found
No related tags found
1 merge request: !139 Text orientation in base worker
Pipeline #78877 passed
...@@ -8,4 +8,4 @@ line_length = 88 ...@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,requests,setuptools,sh,tenacity,yaml
...@@ -124,6 +124,7 @@ class CachedTranscription(Model): ...@@ -124,6 +124,7 @@ class CachedTranscription(Model):
element = ForeignKeyField(CachedElement, backref="transcriptions") element = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField() text = TextField()
confidence = FloatField() confidence = FloatField()
orientation = CharField(max_length=50)
worker_version_id = UUIDField() worker_version_id = UUIDField()
class Meta: class Meta:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from enum import Enum
from peewee import IntegrityError from peewee import IntegrityError
from arkindex_worker import logger from arkindex_worker import logger
...@@ -7,8 +9,17 @@ from arkindex_worker.cache import CachedElement, CachedTranscription ...@@ -7,8 +9,17 @@ from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element from arkindex_worker.models import Element
class TextOrientation(Enum):
    """
    Orientation of the text of a transcription.

    The value of each member is the string that is sent in the
    ``orientation`` field of the transcription API payloads.
    """

    # Horizontal text, read from left to right; default of create_transcription
    HorizontalLeftToRight = "horizontal-lr"
    # Horizontal text, read from right to left
    HorizontalRightToLeft = "horizontal-rl"
    # Vertical text, read from right to left
    VerticalRightToLeft = "vertical-rl"
    # Vertical text, read from left to right
    VerticalLeftToRight = "vertical-lr"
class TranscriptionMixin(object): class TranscriptionMixin(object):
def create_transcription(self, element, text, score): def create_transcription(
self, element, text, score, orientation=TextOrientation.HorizontalLeftToRight
):
""" """
Create a transcription on the given element through the API. Create a transcription on the given element through the API.
""" """
...@@ -18,7 +29,9 @@ class TranscriptionMixin(object): ...@@ -18,7 +29,9 @@ class TranscriptionMixin(object):
assert text and isinstance( assert text and isinstance(
text, str text, str
), "text shouldn't be null and should be of type str" ), "text shouldn't be null and should be of type str"
assert orientation and isinstance(
orientation, TextOrientation
), "orientation shouldn't be null and should be of type TextOrientation"
assert ( assert (
isinstance(score, float) and 0 <= score <= 1 isinstance(score, float) and 0 <= score <= 1
), "score shouldn't be null and should be a float in [0..1] range" ), "score shouldn't be null and should be a float in [0..1] range"
...@@ -36,6 +49,7 @@ class TranscriptionMixin(object): ...@@ -36,6 +49,7 @@ class TranscriptionMixin(object):
"text": text, "text": text,
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
"score": score, "score": score,
"orientation": orientation.value,
}, },
) )
...@@ -50,6 +64,7 @@ class TranscriptionMixin(object): ...@@ -50,6 +64,7 @@ class TranscriptionMixin(object):
"element_id": element.id, "element_id": element.id,
"text": created["text"], "text": created["text"],
"confidence": created["confidence"], "confidence": created["confidence"],
"orientation": created["orientation"],
"worker_version_id": self.worker_version_id, "worker_version_id": self.worker_version_id,
} }
] ]
...@@ -70,7 +85,10 @@ class TranscriptionMixin(object): ...@@ -70,7 +85,10 @@ class TranscriptionMixin(object):
transcriptions, list transcriptions, list
), "transcriptions shouldn't be null and should be of type list" ), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions): # Create shallow copies of every transcription to avoid mutating the original payload
transcriptions_payload = list(map(dict, transcriptions))
for (index, transcription) in enumerate(transcriptions_payload):
element_id = transcription.get("element_id") element_id = transcription.get("element_id")
assert element_id and isinstance( assert element_id and isinstance(
element_id, str element_id, str
...@@ -86,11 +104,20 @@ class TranscriptionMixin(object): ...@@ -86,11 +104,20 @@ class TranscriptionMixin(object):
score is not None and isinstance(score, float) and 0 <= score <= 1 score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
orientation = transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
)
assert orientation and isinstance(
orientation, TextOrientation
), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
if orientation:
transcription["orientation"] = orientation.value
created_trs = self.request( created_trs = self.request(
"CreateTranscriptions", "CreateTranscriptions",
body={ body={
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
"transcriptions": transcriptions, "transcriptions": transcriptions_payload,
}, },
)["transcriptions"] )["transcriptions"]
...@@ -106,6 +133,7 @@ class TranscriptionMixin(object): ...@@ -106,6 +133,7 @@ class TranscriptionMixin(object):
"element_id": created_tr["element_id"], "element_id": created_tr["element_id"],
"text": created_tr["text"], "text": created_tr["text"],
"confidence": created_tr["confidence"], "confidence": created_tr["confidence"],
"orientation": created_tr["orientation"],
"worker_version_id": self.worker_version_id, "worker_version_id": self.worker_version_id,
} }
for created_tr in created_trs for created_tr in created_trs
...@@ -132,7 +160,10 @@ class TranscriptionMixin(object): ...@@ -132,7 +160,10 @@ class TranscriptionMixin(object):
transcriptions, list transcriptions, list
), "transcriptions shouldn't be null and should be of type list" ), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions): # Create shallow copies of every transcription to avoid mutating the original payload
transcriptions_payload = list(map(dict, transcriptions))
for (index, transcription) in enumerate(transcriptions_payload):
text = transcription.get("text") text = transcription.get("text")
assert text and isinstance( assert text and isinstance(
text, str text, str
...@@ -143,6 +174,15 @@ class TranscriptionMixin(object): ...@@ -143,6 +174,15 @@ class TranscriptionMixin(object):
score is not None and isinstance(score, float) and 0 <= score <= 1 score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
orientation = transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
)
assert orientation and isinstance(
orientation, TextOrientation
), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
if orientation:
transcription["orientation"] = orientation.value
polygon = transcription.get("polygon") polygon = transcription.get("polygon")
assert polygon and isinstance( assert polygon and isinstance(
polygon, list polygon, list
...@@ -168,7 +208,7 @@ class TranscriptionMixin(object): ...@@ -168,7 +208,7 @@ class TranscriptionMixin(object):
body={ body={
"element_type": sub_element_type, "element_type": sub_element_type,
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
"transcriptions": transcriptions, "transcriptions": transcriptions_payload,
"return_elements": True, "return_elements": True,
}, },
) )
...@@ -216,6 +256,9 @@ class TranscriptionMixin(object): ...@@ -216,6 +256,9 @@ class TranscriptionMixin(object):
"element_id": annotation["element_id"], "element_id": annotation["element_id"],
"text": transcription["text"], "text": transcription["text"],
"confidence": transcription["score"], "confidence": transcription["score"],
"orientation": transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
).value,
"worker_version_id": self.worker_version_id, "worker_version_id": self.worker_version_id,
} }
) )
......
...@@ -15,6 +15,7 @@ from arkindex.mock import MockApiClient ...@@ -15,6 +15,7 @@ from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data" FIXTURES_DIR = Path(__file__).resolve().parent / "data"
...@@ -381,6 +382,7 @@ def mock_cached_transcriptions(): ...@@ -381,6 +382,7 @@ def mock_cached_transcriptions():
element_id=UUID("11111111-1111-1111-1111-111111111111"), element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This", text="This",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
) )
CachedTranscription.create( CachedTranscription.create(
...@@ -388,6 +390,7 @@ def mock_cached_transcriptions(): ...@@ -388,6 +390,7 @@ def mock_cached_transcriptions():
element_id=UUID("22222222-2222-2222-2222-222222222222"), element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="is", text="is",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
) )
CachedTranscription.create( CachedTranscription.create(
...@@ -395,6 +398,7 @@ def mock_cached_transcriptions(): ...@@ -395,6 +398,7 @@ def mock_cached_transcriptions():
element_id=UUID("33333333-3333-3333-3333-333333333333"), element_id=UUID("33333333-3333-3333-3333-333333333333"),
text="a", text="a",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
) )
CachedTranscription.create( CachedTranscription.create(
...@@ -402,6 +406,7 @@ def mock_cached_transcriptions(): ...@@ -402,6 +406,7 @@ def mock_cached_transcriptions():
element_id=UUID("44444444-4444-4444-4444-444444444444"), element_id=UUID("44444444-4444-4444-4444-444444444444"),
text="good", text="good",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
) )
CachedTranscription.create( CachedTranscription.create(
...@@ -409,6 +414,7 @@ def mock_cached_transcriptions(): ...@@ -409,6 +414,7 @@ def mock_cached_transcriptions():
element_id=UUID("55555555-5555-5555-5555-555555555555"), element_id=UUID("55555555-5555-5555-5555-555555555555"),
text="test", text="test",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
) )
...@@ -461,6 +467,7 @@ def mock_databases(tmpdir): ...@@ -461,6 +467,7 @@ def mock_databases(tmpdir):
element_id=UUID("42424242-4242-4242-4242-424242424242"), element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello!", text="Hello!",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
) )
...@@ -477,6 +484,7 @@ def mock_databases(tmpdir): ...@@ -477,6 +484,7 @@ def mock_databases(tmpdir):
element_id=UUID("42424242-4242-4242-4242-424242424242"), element_id=UUID("42424242-4242-4242-4242-424242424242"),
text="Hello again neighbor !", text="Hello again neighbor !",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
) )
......
...@@ -59,7 +59,7 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type ...@@ -59,7 +59,7 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join( actual_schema = "\n".join(
[ [
......
...@@ -13,6 +13,7 @@ from arkindex_worker.cache import ( ...@@ -13,6 +13,7 @@ from arkindex_worker.cache import (
) )
from arkindex_worker.models import Element from arkindex_worker.models import Element
from arkindex_worker.worker import EntityType from arkindex_worker.worker import EntityType
from arkindex_worker.worker.transcription import TextOrientation
from . import BASE_API_CALLS from . import BASE_API_CALLS
...@@ -465,6 +466,7 @@ def test_create_transcription_entity_with_cache( ...@@ -465,6 +466,7 @@ def test_create_transcription_entity_with_cache(
element=UUID("12341234-1234-1234-1234-123412341234"), element=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello, it's me.", text="Hello, it's me.",
confidence=0.42, confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
) )
CachedEntity.create( CachedEntity.create(
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment