From b1dd27ef32cad530d8fc1e31e166eaeaae165406 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 1 Mar 2023 10:12:04 +0000 Subject: [PATCH] Bulk endpoint for transcription entities --- arkindex_worker/worker/entity.py | 112 +++++++++- tests/test_elements_worker/test_entities.py | 236 ++++++++++++++++++++ tests/test_reporting.py | 7 +- 3 files changed, 351 insertions(+), 4 deletions(-) diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 91dd77c6..e6316cf8 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -3,7 +3,7 @@ ElementsWorker methods for entities. """ -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, TypedDict, Union from peewee import IntegrityError @@ -11,6 +11,10 @@ from arkindex_worker import logger from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity from arkindex_worker.models import Element, Transcription +Entity = TypedDict( + "Entity", name=str, type_id=str, length=int, offset=int, confidence=Optional[float] +) + class MissingEntityType(Exception): """ @@ -108,7 +112,7 @@ class EntityMixin(object): "worker_run_id": self.worker_run_id, }, ) - self.report.add_entity(element.id, entity["id"], type, name) + self.report.add_entity(element.id, entity["id"], entity_type_id, name) if self.use_cache: # Store entity in local cache @@ -204,6 +208,110 @@ class EntityMixin(object): ) return transcription_ent + def create_transcription_entities( + self, + transcription: Transcription, + entities: List[Entity], + ) -> List[Dict[str, str]]: + """ + Create multiple entities attached to a transcription in a single API request. + + :param transcription: Transcription to create the entity on. + :param entities: List of dicts, one per element. Each dict can have the following keys: + + name (str) + Required. Name of the entity. + + type_id (str) + Required. ID of the EntityType of the entity. + + length (int) + Required. Length of the entity in the transcription's text. + + offset (int) + Required. Starting position of the entity in the transcription's text, as a 0-based index. + + confidence (float or None) + Optional confidence score, between 0.0 and 1.0. + + :return: List of dicts, with each dict having a two keys, `transcription_entity_id` and `entity_id`, holding the UUID of each created object. + """ + assert transcription and isinstance( + transcription, Transcription + ), "transcription shouldn't be null and should be of type Transcription" + + # Needed for MLreport + assert ( + hasattr(transcription, "element") and transcription.element + ), f"No element linked to {transcription}" + + assert entities and isinstance( + entities, list + ), "entities shouldn't be null and should be of type list" + + for index, entity in enumerate(entities): + assert isinstance( + entity, dict + ), f"Entity at index {index} in entities: Should be of type dict" + + name = entity.get("name") + assert name and isinstance( + name, str + ), f"Entity at index {index} in entities: name shouldn't be null and should be of type str" + + type_id = entity.get("type_id") + assert type_id and isinstance( + type_id, str + ), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str" + + offset = entity.get("offset") + assert ( + offset is not None and isinstance(offset, int) and offset >= 0 + ), f"Entity at index {index} in entities: offset shouldn't be null and should be a positive integer" + + length = entity.get("length") + assert ( + length is not None and isinstance(length, int) and length > 0 + ), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer" + + confidence = entity.get("confidence") + assert confidence is None or ( + isinstance(confidence, float) and 0 <= confidence <= 1 + ), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range" + + if self.is_read_only: + logger.warning( + "Cannot create transcription entities in bulk as this worker is in read-only mode" + ) + return + + created_ids = self.request( + "CreateTranscriptionEntities", + id=transcription.id, + body={ + "worker_run_id": self.worker_run_id, + "entities": entities, + }, + ) + + for entity, created_objects in zip(entities, created_ids["entities"]): + # Report entity creation + self.report.add_entity( + transcription.element.id, + created_objects["entity_id"], + entity.get("type_id"), + entity.get("name"), + ) + + # Report transcription entity creation + self.report.add_transcription_entity( + created_objects["entity_id"], + transcription, + created_objects["transcription_entity_id"], + ) + + return created_ids["entities"] + def list_transcription_entities( self, transcription: Transcription, diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 4d1773a7..b989f3b3 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -893,3 +893,239 @@ def test_check_required_entity_types_no_creation_allowed( assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + + +@pytest.mark.parametrize("transcription", (None, "not a transcription", 1)) +def test_create_transcription_entities_wrong_transcription( + mock_elements_worker, transcription +): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entities( + transcription=transcription, + entities=[], + ) + assert ( + str(e.value) + == "transcription shouldn't be null and should be of type Transcription" + ) + + +def test_create_transcription_entities_no_transcription_element(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entities( + transcription=Transcription(id="transcription_id"), + entities=[], + ) + assert str(e.value) == "No element linked to Transcription (transcription_id)" + + +@pytest.mark.parametrize("entities", (None, "not a list of entities", 1)) +def test_create_transcription_entities_wrong_entities(mock_elements_worker, entities): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entities( + transcription=Transcription( + id="transcription_id", element={"id": "element_id"} + ), + entities=entities, + ) + assert str(e.value) == "entities shouldn't be null and should be of type list" + + +def test_create_transcription_entities_wrong_entities_subtype(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entities( + transcription=Transcription( + id="transcription_id", element={"id": "element_id"} + ), + entities=["not a dict"], + ) + assert str(e.value) == "Entity at index 0 in entities: Should be of type dict" + + +@pytest.mark.parametrize( + "entity, error", + ( + ( + { + "name": None, + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": 0, + "length": 1, + "confidence": 0.5, + }, + "Entity at index 0 in entities: name shouldn't be null and should be of type str", + ), + ( + {"name": "A", "type_id": None, "offset": 0, "length": 1, "confidence": 0.5}, + "Entity at index 0 in entities: type_id shouldn't be null and should be of type str", + ), + ( + {"name": "A", "type_id": 0, "offset": 0, "length": 1, "confidence": 0.5}, + "Entity at index 0 in entities: type_id shouldn't be null and should be of type str", + ), + ( + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": None, + "length": 1, + "confidence": 0.5, + }, + "Entity at index 0 in entities: offset shouldn't be null and should be a positive integer", + ), + ( + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": -2, + "length": 1, + "confidence": 0.5, + }, + "Entity at index 0 in entities: offset shouldn't be null and should be a positive integer", + ), + ( + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": 0, + "length": None, + "confidence": 0.5, + }, + "Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer", + ), + ( + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": 0, + "length": 0, + "confidence": 0.5, + }, + "Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer", + ), + ( + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": 0, + "length": 1, + "confidence": "not None or a float", + }, + "Entity at index 0 in entities: confidence should be None or a float in [0..1] range", + ), + ( + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": 0, + "length": 1, + "confidence": 1.3, + }, + "Entity at index 0 in entities: confidence should be None or a float in [0..1] range", + ), + ), +) +def test_create_transcription_entities_wrong_entity( + mock_elements_worker, entity, error +): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription_entities( + transcription=Transcription( + id="transcription_id", element={"id": "element_id"} + ), + entities=[entity], + ) + assert str(e.value) == error + + +def test_create_transcription_entities(responses, mock_elements_worker): + element_id = "element_id" + transcription = Transcription(id="transcription-id", element={"id": element_id}) + # Call to Transcription entities creation in bulk + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/transcription-id/entities/bulk/", + status=201, + match=[ + matchers.json_params_matcher( + { + "worker_run_id": "56785678-5678-5678-5678-567856785678", + "entities": [ + { + "name": "Teklia", + "type_id": "22222222-2222-2222-2222-222222222222", + "offset": 0, + "length": 6, + "confidence": 1.0, + } + ], + } + ) + ], + json={ + "entities": [ + { + "transcription_entity_id": "transc-entity-id", + "entity_id": "entity-id", + } + ] + }, + ) + + # Store entity type/slug correspondence on the worker + mock_elements_worker.entity_types = { + "22222222-2222-2222-2222-222222222222": "organization" + } + created_objects = mock_elements_worker.create_transcription_entities( + transcription=transcription, + entities=[ + { + "name": "Teklia", + "type_id": "22222222-2222-2222-2222-222222222222", + "offset": 0, + "length": 6, + "confidence": 1.0, + } + ], + ) + + assert len(created_objects) == 1 + + assert element_id in mock_elements_worker.report.report_data["elements"] + ml_report = mock_elements_worker.report.report_data["elements"][element_id] + + assert "started" in ml_report + del ml_report["started"] + + # Check reporting + assert ml_report == { + "elements": {}, + "transcriptions": 0, + "classifications": {}, + "entities": [ + { + "id": "entity-id", + "type": "22222222-2222-2222-2222-222222222222", + "name": "Teklia", + } + ], + "transcription_entities": [ + { + "transcription_id": "transcription-id", + "entity_id": "entity-id", + "transcription_entity_id": "transc-entity-id", + } + ], + "metadata": [], + "errors": [], + } + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/transcription/transcription-id/entities/bulk/", + ), + ] diff --git a/tests/test_reporting.py b/tests/test_reporting.py index 1ef13e85..af214f01 100644 --- a/tests/test_reporting.py +++ b/tests/test_reporting.py @@ -167,7 +167,10 @@ def test_add_transcription_count(): def test_add_entity(): reporter = Reporter("worker") reporter.add_entity( - "myelement", "12341234-1234-1234-1234-123412341234", "person", "Bob Bob" + "myelement", + "12341234-1234-1234-1234-123412341234", + "person-entity-type-id", + "Bob Bob", ) assert "myelement" in reporter.report_data["elements"] element_data = reporter.report_data["elements"]["myelement"] @@ -179,7 +182,7 @@ def test_add_entity(): "entities": [ { "id": "12341234-1234-1234-1234-123412341234", - "type": "person", + "type": "person-entity-type-id", "name": "Bob Bob", } ], -- GitLab