From 61dd9cb6a64e5b073a15a89c04d5b193997c4480 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Thu, 18 Aug 2022 10:42:05 +0000 Subject: [PATCH] Support worker_version=False in both cache and no cache mode --- arkindex_worker/cache.py | 8 +- arkindex_worker/worker/element.py | 22 ++++-- arkindex_worker/worker/transcription.py | 31 +++++--- tests/conftest.py | 17 +++- tests/test_cache.py | 8 +- tests/test_elements_worker/test_elements.py | 70 ++++++++++++++++- .../test_transcriptions.py | 78 +++++++++++++++++-- 7 files changed, 198 insertions(+), 36 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index e96d4e37..40e756c6 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -157,7 +157,7 @@ class CachedTranscription(Model): text = TextField() confidence = FloatField() orientation = CharField(max_length=50) - worker_version_id = UUIDField() + worker_version_id = UUIDField(null=True) class Meta: database = db @@ -170,7 +170,7 @@ class CachedClassification(Model): class_name = TextField() confidence = FloatField() state = CharField(max_length=10) - worker_version_id = UUIDField() + worker_version_id = UUIDField(null=True) class Meta: database = db @@ -183,7 +183,7 @@ class CachedEntity(Model): name = TextField() validated = BooleanField(default=False) metas = JSONField(null=True) - worker_version_id = UUIDField() + worker_version_id = UUIDField(null=True) class Meta: database = db @@ -197,7 +197,7 @@ class CachedTranscriptionEntity(Model): entity = ForeignKeyField(CachedEntity, backref="transcription_entities") offset = IntegerField(constraints=[Check("offset >= 0")]) length = IntegerField(constraints=[Check("length > 0")]) - worker_version_id = UUIDField() + worker_version_id = UUIDField(null=True) confidence = FloatField(null=True) class Meta: diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index e2342d9c..82f53c28 100644 --- 
a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -261,7 +261,7 @@ class ElementMixin(object): with_corpus: Optional[bool] = None, with_has_children: Optional[bool] = None, with_zone: Optional[bool] = None, - worker_version: Optional[str] = None, + worker_version: Optional[Union[str, bool]] = None, ) -> Union[Iterable[dict], Iterable[CachedElement]]: """ List children of an element. @@ -295,7 +295,7 @@ class ElementMixin(object): This parameter is not supported when caching is enabled. :type with_zone: Optional[bool] :param worker_version: Restrict to elements created by a worker version with this UUID. - :type worker_version: Optional[str] + :type worker_version: Optional[Union[str, bool]] :return: An iterable of dicts from the ``ListElementChildren`` API endpoint, or an iterable of :class:`CachedElement` when caching is enabled. :rtype: Union[Iterable[dict], Iterable[CachedElement]] @@ -330,10 +330,14 @@ class ElementMixin(object): if with_zone is not None: assert isinstance(with_zone, bool), "with_zone should be of type bool" query_params["with_zone"] = with_zone - if worker_version: + if worker_version is not None: assert isinstance( - worker_version, str - ), "worker_version should be of type str" + worker_version, (str, bool) + ), "worker_version should be of type str or bool" + if isinstance(worker_version, bool): + assert ( + worker_version is False + ), "if of type bool, worker_version can only be set to False" query_params["worker_version"] = worker_version if self.use_cache: @@ -346,8 +350,12 @@ class ElementMixin(object): query = CachedElement.select().where(CachedElement.parent_id == element.id) if type: query = query.where(CachedElement.type == type) - if worker_version: - query = query.where(CachedElement.worker_version_id == worker_version) + if worker_version is not None: + # If worker_version=False, filter by manual worker_version i.e. 
None + worker_version_id = worker_version if worker_version else None + query = query.where( + CachedElement.worker_version_id == worker_version_id + ) return query else: diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py index 4afa0a08..42b27f3c 100644 --- a/arkindex_worker/worker/transcription.py +++ b/arkindex_worker/worker/transcription.py @@ -4,6 +4,7 @@ ElementsWorker methods for transcriptions. """ from enum import Enum +from typing import Iterable, Optional, Union from peewee import IntegrityError @@ -348,8 +349,12 @@ class TranscriptionMixin(object): return annotations def list_transcriptions( - self, element, element_type=None, recursive=None, worker_version=None - ): + self, + element: Union[Element, CachedElement], + element_type: Optional[str] = None, + recursive: Optional[bool] = None, + worker_version: Optional[Union[str, bool]] = None, + ) -> Union[Iterable[dict], Iterable[CachedTranscription]]: """ List transcriptions on an element. @@ -359,11 +364,11 @@ class TranscriptionMixin(object): :type element_type: Optional[str] :param recursive: Include transcriptions of any descendant of this element, recursively. :type recursive: Optional[bool] - :param worker_version: Restrict to transcriptions created by a worker version with this UUID. - :type worker_version: Optional[str] + :param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions. + :type worker_version: Optional[Union[str, bool]] :returns: An iterable of dicts representing each transcription, or an iterable of CachedTranscription when cache support is enabled. 
- :rtype: Iterable[dict] or Iterable[CachedTranscription] + :rtype: Union[Iterable[dict], Iterable[CachedTranscription]] """ assert element and isinstance( element, (Element, CachedElement) @@ -375,10 +380,14 @@ class TranscriptionMixin(object): if recursive is not None: assert isinstance(recursive, bool), "recursive should be of type bool" query_params["recursive"] = recursive - if worker_version: + if worker_version is not None: assert isinstance( - worker_version, str - ), "worker_version should be of type str" + worker_version, (str, bool) + ), "worker_version should be of type str or bool" + if isinstance(worker_version, bool): + assert ( + worker_version is False + ), "if of type bool, worker_version can only be set to False" query_params["worker_version"] = worker_version if self.use_cache: @@ -409,9 +418,11 @@ class TranscriptionMixin(object): if element_type: transcriptions = transcriptions.where(cte.c.type == element_type) - if worker_version: + if worker_version is not None: + # If worker_version=False, filter by manual worker_version i.e. 
None + worker_version_id = worker_version if worker_version else None transcriptions = transcriptions.where( - CachedTranscription.worker_version_id == worker_version + CachedTranscription.worker_version_id == worker_version_id ) else: transcriptions = self.api_client.paginate( diff --git a/tests/conftest.py b/tests/conftest.py index 35ac31ae..d4052fd3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -327,7 +327,14 @@ def mock_cached_elements(): polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) - assert CachedElement.select().count() == 2 + CachedElement.create( + id=UUID("33333333-3333-3333-3333-333333333333"), + parent_id=UUID("12341234-1234-1234-1234-123412341234"), + type="paragraph", + polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]", + worker_version_id=None, + ) + assert CachedElement.select().count() == 3 @pytest.fixture @@ -407,6 +414,14 @@ def mock_cached_transcriptions(): orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) + CachedTranscription.create( + id=UUID("66666666-6666-6666-6666-666666666666"), + element_id=UUID("11111111-1111-1111-1111-111111111111"), + text="This is a manual one", + confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, + worker_version_id=None, + ) @pytest.fixture(scope="function") diff --git a/tests/test_cache.py b/tests/test_cache.py index 8dae516b..7dcde0cb 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -58,12 +58,12 @@ def test_create_tables(tmp_path): init_cache_db(db_path) create_tables() - expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) + expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, 
"element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) -CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) +CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) -CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) -CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" +CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), 
FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) +CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" actual_schema = "\n".join( [ diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 72215312..be2d56dd 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -1255,7 +1255,18 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker): element=elt, worker_version=1234, ) - assert str(e.value) == "worker_version should be of type str" + assert str(e.value) == "worker_version should be of type str or bool" + + +def test_list_element_children_wrong_bool_worker_version(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_element_children( + element=elt, + worker_version=True, + ) + assert str(e.value) == "if of type bool, worker_version can only be set to False" def test_list_element_children_api_error(responses, mock_elements_worker): @@ -1363,6 +1374,48 @@ def test_list_element_children(responses, mock_elements_worker): ] +def test_list_element_children_manual_worker_version(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + expected_children = [ + { + "id": "0000", + "type": "page", + "name": "Test", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + } + ] + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False", + status=200, + json={ + 
"count": 1, + "next": None, + "results": expected_children, + }, + ) + + for idx, child in enumerate( + mock_elements_worker.list_element_children(element=elt, worker_version=False) + ): + assert child == expected_children[idx] + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False", + ), + ] + + def test_list_element_children_with_cache_unhandled_param( mock_elements_worker_with_cache, ): @@ -1389,6 +1442,7 @@ def test_list_element_children_with_cache_unhandled_param( ( "11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222", + "33333333-3333-3333-3333-333333333333", ), ), # Filter on element and page should give the second element @@ -1399,7 +1453,7 @@ def test_list_element_children_with_cache_unhandled_param( }, ("22222222-2222-2222-2222-222222222222",), ), - # Filter on element and worker version should give all elements + # Filter on element and worker version should give first two elements ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), @@ -1410,7 +1464,7 @@ def test_list_element_children_with_cache_unhandled_param( "22222222-2222-2222-2222-222222222222", ), ), - # Filter on element, type something and worker version should give first + # Filter on element, type something and worker version should give first ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), @@ -1419,6 +1473,14 @@ def test_list_element_children_with_cache_unhandled_param( }, ("11111111-1111-1111-1111-111111111111",), ), + # Filter on element, manual worker version should give third + ( + { + "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), + "worker_version": False, + }, + ("33333333-3333-3333-3333-333333333333",), + ), ), ) def test_list_element_children_with_cache( @@ -1430,7 
+1492,7 @@ def test_list_element_children_with_cache( ): # Check we have 2 elements already present in database - assert CachedElement.select().count() == 2 + assert CachedElement.select().count() == 3 # Query database through cache elements = mock_elements_worker_with_cache.list_element_children(**filters) diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index d15d3d23..0411db22 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -1682,7 +1682,18 @@ def test_list_transcriptions_wrong_worker_version(mock_elements_worker): element=elt, worker_version=1234, ) - assert str(e.value) == "worker_version should be of type str" + assert str(e.value) == "worker_version should be of type str or bool" + + +def test_list_transcriptions_wrong_bool_worker_version(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.list_transcriptions( + element=elt, + worker_version=True, + ) + assert str(e.value) == "if of type bool, worker_version can only be set to False" def test_list_transcriptions_api_error(responses, mock_elements_worker): @@ -1778,19 +1789,60 @@ def test_list_transcriptions(responses, mock_elements_worker): ] +def test_list_transcriptions_manual_worker_version(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + trans = [ + { + "id": "0000", + "text": "hey", + "confidence": 0.42, + "worker_version_id": None, + "element": None, + } + ] + responses.add( + responses.GET, + "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False", + status=200, + json={ + "count": 1, + "next": None, + "results": trans, + }, + ) + + for idx, transcription in enumerate( + mock_elements_worker.list_transcriptions(element=elt, worker_version=False) + ): + assert 
transcription == trans[idx] + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False", + ), + ] + + @pytest.mark.parametrize( "filters, expected_ids", ( - # Filter on element should give first transcription + # Filter on element should give first and sixth transcription ( { "element": CachedElement( id="11111111-1111-1111-1111-111111111111", type="page" ), }, - ("11111111-1111-1111-1111-111111111111",), + ( + "11111111-1111-1111-1111-111111111111", + "66666666-6666-6666-6666-666666666666", + ), ), - # Filter on element and element_type should give first transcription + # Filter on element and element_type should give first and sixth transcription ( { "element": CachedElement( @@ -1798,7 +1850,10 @@ def test_list_transcriptions(responses, mock_elements_worker): ), "element_type": "page", }, - ("11111111-1111-1111-1111-111111111111",), + ( + "11111111-1111-1111-1111-111111111111", + "66666666-6666-6666-6666-666666666666", + ), ), # Filter on element and worker_version should give first transcription ( @@ -1824,6 +1879,7 @@ def test_list_transcriptions(responses, mock_elements_worker): "33333333-3333-3333-3333-333333333333", "44444444-4444-4444-4444-444444444444", "55555555-5555-5555-5555-555555555555", + "66666666-6666-6666-6666-666666666666", ), ), # Filter recursively on element and worker_version should give four transcriptions @@ -1857,6 +1913,16 @@ def test_list_transcriptions(responses, mock_elements_worker): "55555555-5555-5555-5555-555555555555", ), ), + # Filter on element with manually created transcription should give sixth transcription + ( + { + "element": CachedElement( + id="11111111-1111-1111-1111-111111111111", type="page" + ), + "worker_version": False, + }, + ("66666666-6666-6666-6666-666666666666",), + ), ), ) def 
test_list_transcriptions_with_cache( @@ -1867,7 +1933,7 @@ def test_list_transcriptions_with_cache( expected_ids, ): # Check we have 5 elements already present in database - assert CachedTranscription.select().count() == 5 + assert CachedTranscription.select().count() == 6 # Query database through cache transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters) -- GitLab