diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index b184460d87028b9e2b537f99e14049433db98d9f..b1be02e2787f3c06c53465bffbee91d737288ca4 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -166,6 +166,7 @@ class CachedTranscriptionEntity(Model): offset = IntegerField(constraints=[Check("offset >= 0")]) length = IntegerField(constraints=[Check("length > 0")]) worker_version_id = UUIDField() + confidence = FloatField(null=True) class Meta: primary_key = CompositeKey("transcription", "entity") diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 1cf561658d1769ea9ca8aa96db0b376206acbb6d..449e8ce3df433e6f6cf8c40e2e71a4eb63b96e7f 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -82,7 +82,9 @@ class EntityMixin(object): return entity["id"] - def create_transcription_entity(self, transcription, entity, offset, length): + def create_transcription_entity( + self, transcription, entity, offset, length, confidence=None + ): """ Create a link between an existing entity and an existing transcription through API """ @@ -98,6 +100,9 @@ class EntityMixin(object): assert ( length is not None and isinstance(length, int) and length > 0 ), "length shouldn't be null and should be a strictly positive integer" + assert ( + confidence is None or isinstance(confidence, float) and 0 <= confidence <= 1 + ), "confidence should be null or a float in [0..1] range" if self.is_read_only: logger.warning( "Cannot create transcription entity as this worker is in read-only mode" @@ -112,6 +117,7 @@ class EntityMixin(object): "length": length, "offset": offset, "worker_version_id": self.worker_version_id, + "confidence": confidence, }, ) # TODO: Report transcription entity creation @@ -125,6 +131,7 @@ class EntityMixin(object): offset=offset, length=length, worker_version_id=self.worker_version_id, + confidence=confidence, ) except IntegrityError as e: logger.warning( diff --git a/tests/test_cache.py b/tests/test_cache.py index 03a681448fe6514cd7ca8041a9c0a809b6606365..fcc7e8e7406ec1d7b524e7d00b90efcdc0dc1812 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -58,7 +58,7 @@ def test_create_tables(tmp_path): CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) -CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) +CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" actual_schema = "\n".join( diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index e4f43b6f3fce660c04df235b16711527c4a625fe..075f8e5b7d95ccab43c2d96e56480d1bea264f54 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -451,6 +451,7 @@ def test_create_transcription_entity(responses, mock_elements_worker): "offset": 5, "length": 10, "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": None, } @@ -508,6 +509,7 @@ def test_create_transcription_entity_with_cache( "offset": 5, "length": 10, "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": None, } # Check that created transcription entity was properly stored in SQLite cache @@ -520,3 +522,75 @@ def test_create_transcription_entity_with_cache( worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) ] + + +def test_create_transcription_entity_with_confidence_with_cache( + responses, mock_elements_worker_with_cache +): + CachedElement.create( + id=UUID("12341234-1234-1234-1234-123412341234"), + type="page", + ) + CachedTranscription.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + element=UUID("12341234-1234-1234-1234-123412341234"), + text="Hello, it's me.", + confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + CachedEntity.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + type="person", + name="Bob Bob", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=200, + json={ + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "confidence": 0.77, + }, + ) + + mock_elements_worker_with_cache.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + confidence=0.77, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ), + ] + assert json.loads(responses.calls[-1].request.body) == { + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": 0.77, + } + + # Check that created transcription entity was properly stored in SQLite cache + assert list(CachedTranscriptionEntity.select()) == [ + CachedTranscriptionEntity( + transcription=UUID("11111111-1111-1111-1111-111111111111"), + entity=UUID("11111111-1111-1111-1111-111111111111"), + offset=5, + length=10, + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + confidence=0.77, + ) + ]