From f0a0a7ac4a7e192e152532b84f0fd60e00ef6bbf Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Mon, 26 Apr 2021 15:20:47 +0000 Subject: [PATCH] Implement create_transcription_entity + Support local cache --- arkindex_worker/cache.py | 17 ++ arkindex_worker/worker/entity.py | 52 ++++- tests/test_cache.py | 1 + tests/test_elements_worker/test_entities.py | 236 +++++++++++++++++++- 4 files changed, 304 insertions(+), 2 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 4505f394..92be4913 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -6,6 +6,8 @@ import sqlite3 from peewee import ( BooleanField, CharField, + Check, + CompositeKey, Field, FloatField, ForeignKeyField, @@ -138,6 +140,20 @@ class CachedEntity(Model): table_name = "entities" +class CachedTranscriptionEntity(Model): + transcription = ForeignKeyField( + CachedTranscription, backref="transcription_entities" + ) + entity = ForeignKeyField(CachedEntity, backref="transcription_entities") + offset = IntegerField(constraints=[Check("offset >= 0")]) + length = IntegerField(constraints=[Check("length > 0")]) + + class Meta: + primary_key = CompositeKey("transcription", "entity") + database = db + table_name = "transcription_entities" + + # Add all the managed models in that list # It's used here, but also in unit tests MODELS = [ @@ -146,6 +162,7 @@ MODELS = [ CachedTranscription, CachedClassification, CachedEntity, + CachedTranscriptionEntity, ] diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 1689ca03..3f3876e1 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -5,7 +5,7 @@ from enum import Enum from peewee import IntegrityError from arkindex_worker import logger -from arkindex_worker.cache import CachedElement, CachedEntity +from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity from arkindex_worker.models import Element @@ -81,3 +81,53 @@ class EntityMixin(object): logger.warning(f"Couldn't save created entity in local cache: {e}") return entity["id"] + + def create_transcription_entity(self, transcription, entity, offset, length): + """ + Create a link between an existing entity and an existing transcription through API + """ + assert transcription and isinstance( + transcription, str + ), "transcription shouldn't be null and should be of type str" + assert entity and isinstance( + entity, str + ), "entity shouldn't be null and should be of type str" + assert ( + offset and isinstance(offset, int) and offset >= 0 + ), "offset shouldn't be null and should be a positive integer" + assert ( + length and isinstance(length, int) and length > 0 + ), "length shouldn't be null and should be a strictly positive integer" + if self.is_read_only: + logger.warning( + "Cannot create transcription entity as this worker is in read-only mode" + ) + return + + self.request( + "CreateTranscriptionEntity", + id=transcription, + body={ + "entity": entity, + "length": length, + "offset": offset, + }, + ) + # TODO: Report transcription entity creation + + if self.use_cache: + # Store transcription entity in local cache + try: + to_insert = [ + { + "transcription": transcription, + "entity": entity, + "offset": offset, + "length": length, + } + ] + CachedTranscriptionEntity.insert_many(to_insert).execute() + except IntegrityError as e: + logger.warning( + f"Couldn't save created transcription entity in local cache: {e}" + ) diff --git a/tests/test_cache.py b/tests/test_cache.py index bc26c000..ce4ad7f0 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -58,6 +58,7 @@ def test_create_tables(tmp_path): CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) +CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" actual_schema = "\n".join( diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index b03edb0f..d32d3cfb 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -5,7 +5,12 @@ from uuid import UUID import pytest from apistar.exceptions import ErrorResponse -from arkindex_worker.cache import CachedElement, CachedEntity +from arkindex_worker.cache import ( + CachedElement, + CachedEntity, + CachedTranscription, + CachedTranscriptionEntity, +) from arkindex_worker.models import Element from arkindex_worker.worker import EntityType @@ -258,3 +263,232 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache): worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) ] + + +def test_create_transcription_entity_wrong_transcription(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription=None, + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + ) + assert str(e.value) == "transcription shouldn't be null and should be of type str" + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription=1234, + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + ) + assert str(e.value) == "transcription shouldn't be null and should be of type str" + + +def test_create_transcription_entity_wrong_entity(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity=None, + offset=5, + length=10, + ) + assert str(e.value) == "entity shouldn't be null and should be of type str" + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity=1234, + offset=5, + length=10, + ) + assert str(e.value) == "entity shouldn't be null and should be of type str" + + +def test_create_transcription_entity_wrong_offset(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=None, + length=10, + ) + assert str(e.value) == "offset shouldn't be null and should be a positive integer" + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset="not an int", + length=10, + ) + assert str(e.value) == "offset shouldn't be null and should be a positive integer" + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=-1, + length=10, + ) + assert str(e.value) == "offset shouldn't be null and should be a positive integer" + + +def test_create_transcription_entity_wrong_length(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=None, + ) + assert ( + str(e.value) + == "length shouldn't be null and should be a strictly positive integer" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length="not an int", + ) + assert ( + str(e.value) + == "length shouldn't be null and should be a strictly positive integer" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=0, + ) + assert ( + str(e.value) + == "length shouldn't be null and should be a strictly positive integer" + ) + + +def test_create_transcription_entity_api_error(responses, mock_elements_worker): + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=500, + ) + + with pytest.raises(ErrorResponse): + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + ) + + assert len(responses.calls) == 7 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + # We retry 5 times the API call + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ] + + +def test_create_transcription_entity(responses, mock_elements_worker): + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=200, + json={ + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + }, + ) + + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ] + assert json.loads(responses.calls[2].request.body) == { + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + } + + +def test_create_transcription_entity_with_cache( + responses, mock_elements_worker_with_cache +): + CachedElement.create( + id=UUID("12341234-1234-1234-1234-123412341234"), + type="page", + ) + CachedTranscription.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + element=UUID("12341234-1234-1234-1234-123412341234"), + text="Hello, it's me.", + confidence=0.42, + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + CachedEntity.create( + id=UUID("11111111-1111-1111-1111-111111111111"), + type="person", + name="Bob Bob", + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + ) + + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=200, + json={ + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + }, + ) + + mock_elements_worker_with_cache.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ] + assert json.loads(responses.calls[2].request.body) == { + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + } + + # Check that created transcription entity was properly stored in SQLite cache + assert list(CachedTranscriptionEntity.select()) == [ + CachedTranscriptionEntity( + transcription=UUID("11111111-1111-1111-1111-111111111111"), + entity=UUID("11111111-1111-1111-1111-111111111111"), + offset=5, + length=10, + ) + ] -- GitLab