diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index aa40bbcbecd77f345a488607082747c55d6fa8b5..ddea3371c3b1a24366d33878420731df7ec01248 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -68,6 +68,7 @@ class CachedElement(Model): mirrored = BooleanField(default=False) initial = BooleanField(default=False) worker_version_id = UUIDField(null=True) + confidence = FloatField(null=True) class Meta: database = db diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index a89528baa02055493d3ff885dcaaa815e4a62c5b..2abced62c599fb30d2e8f0b9776cab99804b5d18 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import uuid +from typing import Dict, List, Optional, Union from peewee import IntegrityError @@ -40,10 +41,26 @@ class ElementMixin(object): return True - def create_sub_element(self, element, type, name, polygon): + def create_sub_element( + self, + element: Element, + type: str, + name: str, + polygon: List[List[Union[int, float]]], + confidence: Optional[float] = None, + ) -> str: """ - Create a child element on the given element through API - Return the ID of the created sub element + Create a child element on the given element through the API. + + :param Element element: The parent element. + :param str type: Slug of the element type for this child element. + :param str name: Name of the child element. + :param polygon: Polygon of the child element. + :type polygon: list(list(int or float)) + :param confidence: Optional confidence score, between 0.0 and 1.0. + :type confidence: float or None + :returns: UUID of the created element. + :rtype: str """ assert element and isinstance( element, Element @@ -64,6 +81,10 @@ class ElementMixin(object): assert all( isinstance(coord, (int, float)) for point in polygon for coord in point ), "polygon points should be lists of two numbers" + assert confidence is None or ( + isinstance(confidence, float) and 0 <= confidence <= 1 + ), "confidence should be None or a float in [0..1] range" + if self.is_read_only: logger.warning("Cannot create element as this worker is in read-only mode") return @@ -78,16 +99,43 @@ class ElementMixin(object): "polygon": polygon, "parent": element.id, "worker_version": self.worker_version_id, + "confidence": confidence, }, ) self.report.add_element(element.id, type) return sub_element["id"] - def create_elements(self, parent, elements): + def create_elements( + self, + parent: Union[Element, CachedElement], + elements: List[ + Dict[str, Union[str, List[List[Union[int, float]]], float, None]] + ], + ) -> List[Dict[str, str]]: """ - Create children elements on the given element through API - Return the IDs of created elements + Create child elements on the given element in a single API request. + + :param parent: Parent element for all the new child elements. The parent must have an image and a polygon. + :type parent: Element or CachedElement + :param elements: List of dicts, one per element. Each dict can have the following keys: + + name (str) + Required. Name of the element. + + type (str) + Required. Slug of the element type for this element. + + polygon (list(list(int or float))) + Required. Polygon for this child element. Must have at least three points, with each point + having two non-negative coordinates and being inside of the parent element's image. + + confidence (float or None) + Optional confidence score, between 0.0 and 1.0. + + :type elements: list(dict(str, Any)) + :return: List of dicts, with each dict having a single key, ``id``, holding the UUID of each created element. + :rtype: list(dict(str, str)) """ if isinstance(parent, Element): assert parent.get( @@ -135,6 +183,11 @@ class ElementMixin(object): isinstance(coord, (int, float)) for point in polygon for coord in point ), f"Element at index {index} in elements: polygon points should be lists of two numbers" + confidence = element.get("confidence") + assert confidence is None or ( + isinstance(confidence, float) and 0 <= confidence <= 1 + ), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range" + if self.is_read_only: logger.warning("Cannot create elements as this worker is in read-only mode") return @@ -176,6 +229,7 @@ class ElementMixin(object): "image_id": image_id, "polygon": element["polygon"], "worker_version_id": self.worker_version_id, + "confidence": element.get("confidence"), } for idx, element in enumerate(elements) ] diff --git a/tests/test_cache.py b/tests/test_cache.py index 94a2a3e4bb84e1db42fe6bc8d82475b3d10a711d..8dae516bf85afc8753e45f3ba26db8201d00d15f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -59,7 +59,7 @@ def test_create_tables(tmp_path): create_tables() expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) -CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) +CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index fba5fe56cb0104796381d8d0b3fdc2c1e3a6865c..7221531206ed463f8afeeacb551bf9ee095f3585 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -424,6 +424,19 @@ def test_create_sub_element_wrong_polygon(mock_elements_worker): assert str(e.value) == "polygon points should be lists of two numbers" +@pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")]) +def test_create_sub_element_wrong_confidence(mock_elements_worker, confidence): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_sub_element( + element=Element({"zone": None}), + type="something", + name="blah", + polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]], + confidence=confidence, + ) + assert str(e.value) == "confidence should be None or a float in [0..1] range" + + def test_create_sub_element_api_error(responses, mock_elements_worker): elt = Element( { @@ -495,6 +508,49 @@ def test_create_sub_element(responses, mock_elements_worker): "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "parent": "12341234-1234-1234-1234-123412341234", "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": None, + } + assert sub_element_id == "12345678-1234-1234-1234-123456789123" + + +def test_create_sub_element_confidence(responses, mock_elements_worker): + elt = Element( + { + "id": "12341234-1234-1234-1234-123412341234", + "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, + "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, + } + ) + responses.add( + responses.POST, + "http://testserver/api/v1/elements/create/", + status=200, + json={"id": "12345678-1234-1234-1234-123456789123"}, + ) + + sub_element_id = mock_elements_worker.create_sub_element( + element=elt, + type="something", + name="0", + polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], + confidence=0.42, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ("POST", "http://testserver/api/v1/elements/create/"), + ] + assert json.loads(responses.calls[-1].request.body) == { + "type": "something", + "name": "0", + "image": "22222222-2222-2222-2222-222222222222", + "corpus": "11111111-1111-1111-1111-111111111111", + "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], + "parent": "12341234-1234-1234-1234-123412341234", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, } assert sub_element_id == "12345678-1234-1234-1234-123456789123" @@ -740,6 +796,26 @@ def test_create_elements_wrong_elements_polygon(mock_elements_worker): ) +@pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")]) +def test_create_elements_wrong_elements_confidence(mock_elements_worker, confidence): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_elements( + parent=Element({"zone": {"image": {"id": "image_id"}}}), + elements=[ + { + "name": "a", + "type": "something", + "polygon": [[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]], + "confidence": confidence, + } + ], + ) + assert ( + str(e.value) + == "Element at index 0 in elements: confidence should be None or a float in [0..1] range" + ) + + def test_create_elements_api_error(responses, mock_elements_worker): elt = Element( { @@ -930,6 +1006,80 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path): image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + confidence=None, + ) + ] + + +def test_create_elements_confidence( + responses, mock_elements_worker_with_cache, tmp_path +): + elt = Element( + { + "id": "12341234-1234-1234-1234-123412341234", + "zone": { + "image": { + "id": "c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", + "width": 42, + "height": 42, + "url": "http://aaaa", + } + }, + } + ) + responses.add( + responses.POST, + "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", + status=200, + json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}], + ) + + created_ids = mock_elements_worker_with_cache.create_elements( + parent=elt, + elements=[ + { + "name": "0", + "type": "something", + "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], + "confidence": 0.42, + } + ], + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", + ), + ] + assert json.loads(responses.calls[-1].request.body) == { + "elements": [ + { + "name": "0", + "type": "something", + "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], + "confidence": 0.42, + } + ], + "worker_version": "12341234-1234-1234-1234-123412341234", + } + assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}] + + # Check that created elements were properly stored in SQLite cache + assert (tmp_path / "db.sqlite").is_file() + + assert list(CachedElement.select()) == [ + CachedElement( + id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"), + parent_id=UUID("12341234-1234-1234-1234-123412341234"), + type="something", + image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", + polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], + worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), + confidence=0.42, ) ] diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 048c01688dcb25db602953725c0cd9ad4b1e921e..d15d3d2310eaf855f2d75a232d2387e221facf7e 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -345,6 +345,7 @@ def test_create_transcription_orientation_with_cache( "mirrored": False, "initial": False, "worker_version_id": None, + "confidence": None, }, "text": "Animula vagula blandula", "confidence": 0.42, @@ -809,6 +810,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_ "mirrored": False, "initial": False, "worker_version_id": None, + "confidence": None, }, "text": "Animula vagula blandula", "confidence": 0.12, @@ -827,6 +829,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_ "mirrored": False, "initial": False, "worker_version_id": None, + "confidence": None, }, "text": "Hospes comesque corporis", "confidence": 0.21, @@ -1585,6 +1588,7 @@ def test_create_transcriptions_orientation_with_cache( "mirrored": False, "initial": False, "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + "confidence": None, }, "text": "Animula vagula blandula", "confidence": 0.5, @@ -1603,6 +1607,7 @@ def test_create_transcriptions_orientation_with_cache( "mirrored": False, "initial": False, "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + "confidence": None, }, "text": "Hospes comesque corporis", "confidence": 0.75, @@ -1621,6 +1626,7 @@ def test_create_transcriptions_orientation_with_cache( "mirrored": False, "initial": False, "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + "confidence": None, }, "text": "Quae nunc abibis in loca", "confidence": 0.9,