From a11872b5d99ad5bb60447e70e518596626441e62 Mon Sep 17 00:00:00 2001 From: Martin Maarand <maarand@teklia.com> Date: Thu, 18 Nov 2021 16:32:49 +0000 Subject: [PATCH] Resolve "Support confidence in create_transcription_entity helper" --- arkindex_worker/cache.py | 1 + arkindex_worker/worker/entity.py | 9 ++- tests/test_cache.py | 2 +- tests/test_elements_worker/test_entities.py | 74 +++++++++++++++++++++ 4 files changed, 84 insertions(+), 2 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index b184460d..b1be02e2 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -166,6 +166,7 @@ class CachedTranscriptionEntity(Model): offset = IntegerField(constraints=[Check("offset >= 0")]) length = IntegerField(constraints=[Check("length > 0")]) worker_version_id = UUIDField() + confidence = FloatField(null=True) class Meta: primary_key = CompositeKey("transcription", "entity") diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 1cf56165..449e8ce3 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -82,7 +82,9 @@ class EntityMixin(object): return entity["id"] - def create_transcription_entity(self, transcription, entity, offset, length): + def create_transcription_entity( + self, transcription, entity, offset, length, confidence=None + ): """ Create a link between an existing entity and an existing transcription through API """ @@ -98,6 +100,9 @@ class EntityMixin(object): assert ( length is not None and isinstance(length, int) and length > 0 ), "length shouldn't be null and should be a strictly positive integer" + assert ( + confidence is None or isinstance(confidence, float) and 0 <= confidence <= 1 + ), "confidence should be null or a float in [0..1] range" if self.is_read_only: logger.warning( "Cannot create transcription entity as this worker is in read-only mode" @@ -112,6 +117,7 @@ class EntityMixin(object): "length": length, "offset": offset, "worker_version_id": self.worker_version_id, + "confidence": confidence, }, ) # TODO: Report transcription entity creation @@ -125,6 +131,7 @@ class EntityMixin(object): offset=offset, length=length, worker_version_id=self.worker_version_id, + confidence=confidence, ) except IntegrityError as e: logger.warning( diff --git a/tests/test_cache.py b/tests/test_cache.py index 03a68144..fcc7e8e7 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -58,7 +58,7 @@ def test_create_tables(tmp_path): CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) -CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) +CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" actual_schema = "\n".join( diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index e4f43b6f..075f8e5b 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -451,6 +451,7 @@ def test_create_transcription_entity(responses, mock_elements_worker): "offset": 5, "length": 10, "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": None, } @@ -508,6 +509,7 @@ def test_create_transcription_entity_with_cache( "offset": 5, "length": 10, "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": None, } # Check that created transcription entity was properly stored in SQLite cache @@ -520,3 +522,75 @@ def test_create_transcription_entity_with_cache( worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) ] + + +def test_create_transcription_entity_with_confidence_with_cache( + responses, mock_elements_worker_with_cache +): + CachedElement.create( + id=UUID("12341234-1234-1234-1234-123412341234"), + type="page", + ) + CachedTranscription.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + element=UUID("12341234-1234-1234-1234-123412341234"), + text="Hello, it's me.", + confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + CachedEntity.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + type="person", + name="Bob Bob", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=200, + json={ + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "confidence": 0.77, + }, + ) + + mock_elements_worker_with_cache.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + confidence=0.77, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ), + ] + assert json.loads(responses.calls[-1].request.body) == { + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": 0.77, + } + + # Check that created transcription entity was properly stored in SQLite cache + assert list(CachedTranscriptionEntity.select()) == [ + CachedTranscriptionEntity( + transcription=UUID("11111111-1111-1111-1111-111111111111"), + entity=UUID("11111111-1111-1111-1111-111111111111"), + offset=5, + length=10, + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + confidence=0.77, + ) + ] -- GitLab