From 6f16aeb9507e9bb79acc58c0253deb98a3bbdd4d Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Thu, 16 Feb 2023 09:09:11 +0000 Subject: [PATCH] New entity types handling --- arkindex_worker/worker/__init__.py | 6 +- arkindex_worker/worker/entity.py | 87 ++++++++--- docs/ref/api/entity.md | 2 +- tests/test_elements_worker/test_entities.py | 155 +++++++++++++++++--- 4 files changed, 207 insertions(+), 43 deletions(-) diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py index af509a2b..61a9f7af 100644 --- a/arkindex_worker/worker/__init__.py +++ b/arkindex_worker/worker/__init__.py @@ -19,7 +19,7 @@ from arkindex_worker.reporting import Reporter from arkindex_worker.worker.base import BaseWorker from arkindex_worker.worker.classification import ClassificationMixin from arkindex_worker.worker.element import ElementMixin -from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401 +from arkindex_worker.worker.entity import EntityMixin # noqa: F401 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401 from arkindex_worker.worker.transcription import TranscriptionMixin from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401 @@ -92,6 +92,10 @@ class ElementsWorker( self.classes = {} + self.entity_types = {} + """Known and available entity types in processed corpus + """ + self._worker_version_cache = {} def list_elements(self) -> Union[Iterable[CachedElement], List[str]]: diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 5be7824e..91dd77c6 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -3,8 +3,7 @@ ElementsWorker methods for entities. """ -from enum import Enum -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from peewee import IntegrityError @@ -13,26 +12,54 @@ from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscripti from arkindex_worker.models import Element, Transcription -class EntityType(Enum): +class MissingEntityType(Exception): """ - Type of an entity. + Raised when the specified entity type was not found in the corpus and + the worker cannot create it. """ - Person = "person" - Location = "location" - Subject = "subject" - Organization = "organization" - Misc = "misc" - Number = "number" - Date = "date" - class EntityMixin(object): + def check_required_entity_types( + self, entity_types: List[str], create_missing: bool = True + ): + """Checks that every entity type needed is available in the corpus. + Missing ones may be created automatically if needed. + + :param entity_types: Entity type names to search. + :param create_missing: Whether the missing types should be created. Defaults to True. + :raises MissingEntityType: When an entity type is missing and cannot create. + """ + # Retrieve entity_type ID + if not self.entity_types: + # Load entity_types of corpus + self.list_corpus_entity_types() + for entity_type in entity_types: + # Do nothing if type already exists + if entity_type in self.entity_types: + continue + + # Do not create missing if not requested + if not create_missing: + raise MissingEntityType( + f"Entity type `{entity_type}` was not in the corpus." + ) + + # Create type if non-existent + self.entity_types[entity_type] = self.request( + "CreateEntityType", + body={ + "name": entity_type, + "corpus": self.corpus_id, + }, + )["id"] + logger.info(f"Created a new entity type with name `{entity_type}`.") + def create_entity( self, element: Union[Element, CachedElement], name: str, - type: EntityType, + type: str, metas=dict(), validated=None, ): @@ -52,8 +79,8 @@ class EntityMixin(object): name, str ), "name shouldn't be null and should be of type str" assert type and isinstance( - type, EntityType - ), "type shouldn't be null and should be of type EntityType" + type, str + ), "type shouldn't be null and should be of type str" if metas: assert isinstance(metas, dict), "metas should be of type dict" if validated is not None: @@ -62,18 +89,26 @@ class EntityMixin(object): logger.warning("Cannot create entity as this worker is in read-only mode") return + # Retrieve entity_type ID + if not self.entity_types: + # Load entity_types of corpus + self.list_corpus_entity_types() + + entity_type_id = self.entity_types.get(type) + assert entity_type_id, f"Entity type `{type}` not found in the corpus." + entity = self.request( "CreateEntity", body={ "name": name, - "type": type.value, + "type_id": entity_type_id, "metas": metas, "validated": validated, "corpus": self.corpus_id, "worker_run_id": self.worker_run_id, }, ) - self.report.add_entity(element.id, entity["id"], type.value, name) + self.report.add_entity(element.id, entity["id"], type, name) if self.use_cache: # Store entity in local cache @@ -81,7 +116,7 @@ class EntityMixin(object): to_insert = [ { "id": entity["id"], - "type": type.value, + "type": type, "name": name, "validated": validated if validated is not None else False, "metas": metas, @@ -225,3 +260,19 @@ class EntityMixin(object): return self.api_client.paginate( "ListCorpusEntities", id=self.corpus_id, **query_params ) + + def list_corpus_entity_types( + self, + ): + """ + Loads available entity types in corpus. + """ + self.entity_types = { + entity_type["name"]: entity_type["id"] + for entity_type in self.api_client.paginate( + "ListCorpusEntityTypes", id=self.corpus_id + ) + } + logger.info( + f"Loaded {len(self.entity_types)} entity types in corpus ({self.corpus_id})." + ) diff --git a/docs/ref/api/entity.md b/docs/ref/api/entity.md index c7ebbe85..ca6a9d06 100644 --- a/docs/ref/api/entity.md +++ b/docs/ref/api/entity.md @@ -3,7 +3,7 @@ ::: arkindex_worker.worker.entity options: members: - - EntityType + - MissingEntityType options: show_category_heading: no ::: arkindex_worker.worker.entity.EntityMixin diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 4d9e4459..4d1773a7 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -4,6 +4,7 @@ from uuid import UUID import pytest from apistar.exceptions import ErrorResponse +from responses import matchers from arkindex_worker.cache import ( CachedElement, @@ -12,7 +13,7 @@ from arkindex_worker.cache import ( CachedTranscriptionEntity, ) from arkindex_worker.models import Element, Transcription -from arkindex_worker.worker import EntityType +from arkindex_worker.worker.entity import MissingEntityType from arkindex_worker.worker.transcription import TextOrientation from . import BASE_API_CALLS @@ -23,7 +24,7 @@ def test_create_entity_wrong_element(mock_elements_worker): mock_elements_worker.create_entity( element=None, name="Bob Bob", - type=EntityType.Person, + type="person", ) assert ( str(e.value) @@ -34,7 +35,7 @@ def test_create_entity_wrong_element(mock_elements_worker): mock_elements_worker.create_entity( element="not element type", name="Bob Bob", - type=EntityType.Person, + type="person", ) assert ( str(e.value) @@ -49,7 +50,7 @@ def test_create_entity_wrong_name(mock_elements_worker): mock_elements_worker.create_entity( element=elt, name=None, - type=EntityType.Person, + type="person", ) assert str(e.value) == "name shouldn't be null and should be of type str" @@ -57,7 +58,7 @@ def test_create_entity_wrong_name(mock_elements_worker): mock_elements_worker.create_entity( element=elt, name=1234, - type=EntityType.Person, + type="person", ) assert str(e.value) == "name shouldn't be null and should be of type str" @@ -71,7 +72,7 @@ def test_create_entity_wrong_type(mock_elements_worker): name="Bob Bob", type=None, ) - assert str(e.value) == "type shouldn't be null and should be of type EntityType" + assert str(e.value) == "type shouldn't be null and should be of type str" with pytest.raises(AssertionError) as e: mock_elements_worker.create_entity( @@ -79,15 +80,7 @@ def test_create_entity_wrong_type(mock_elements_worker): name="Bob Bob", type=1234, ) - assert str(e.value) == "type shouldn't be null and should be of type EntityType" - - with pytest.raises(AssertionError) as e: - mock_elements_worker.create_entity( - element=elt, - name="Bob Bob", - type="not_an_entity_type", - ) - assert str(e.value) == "type shouldn't be null and should be of type EntityType" + assert str(e.value) == "type shouldn't be null and should be of type str" def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker): @@ -99,7 +92,7 @@ def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker): mock_elements_worker.create_entity( element=elt, name="Bob Bob", - type=EntityType.Person, + type="person", metas="wrong metas", ) assert str(e.value) == "metas should be of type dict" @@ -112,7 +105,7 @@ def test_create_entity_wrong_metas(mock_elements_worker): mock_elements_worker.create_entity( element=elt, name="Bob Bob", - type=EntityType.Person, + type="person", metas="wrong metas", ) assert str(e.value) == "metas should be of type dict" @@ -125,13 +118,15 @@ def test_create_entity_wrong_validated(mock_elements_worker): mock_elements_worker.create_entity( element=elt, name="Bob Bob", - type=EntityType.Person, + type="person", validated="wrong validated", ) assert str(e.value) == "validated should be of type bool" def test_create_entity_api_error(responses, mock_elements_worker): + # Set one entity type + mock_elements_worker.entity_types = {"person": "person-entity-type-id"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, @@ -143,7 +138,7 @@ def test_create_entity_api_error(responses, mock_elements_worker): mock_elements_worker.create_entity( element=elt, name="Bob Bob", - type=EntityType.Person, + type="person", ) assert len(responses.calls) == len(BASE_API_CALLS) + 5 @@ -160,6 +155,9 @@ def test_create_entity_api_error(responses, mock_elements_worker): def test_create_entity(responses, mock_elements_worker): + # Set one entity type + mock_elements_worker.entity_types = {"person": "person-entity-type-id"} + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, @@ -171,7 +169,7 @@ def test_create_entity(responses, mock_elements_worker): entity_id = mock_elements_worker.create_entity( element=elt, name="Bob Bob", - type=EntityType.Person, + type="person", ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 @@ -182,7 +180,7 @@ def test_create_entity(responses, mock_elements_worker): ] assert json.loads(responses.calls[-1].request.body) == { "name": "Bob Bob", - "type": "person", + "type_id": "person-entity-type-id", "metas": {}, "validated": None, "corpus": "11111111-1111-1111-1111-111111111111", @@ -191,7 +189,49 @@ def test_create_entity(responses, mock_elements_worker): assert entity_id == "12345678-1234-1234-1234-123456789123" +def test_create_entity_missing_type(responses, mock_elements_worker): + """ + Create entity with an unknown type will fail. + """ + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + # Call to list entity types + responses.add( + responses.GET, + "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/entity-types/", + status=200, + json={ + "count": 1, + "next": None, + "results": [ + {"id": "person-entity-type-id", "name": "person", "color": "00d1b2"} + ], + }, + ) + + with pytest.raises( + AssertionError, match="Entity type `new-entity` not found in the corpus." + ): + mock_elements_worker.create_entity( + element=elt, + name="Bob Bob", + type="new-entity", + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/entity-types/", + ), + ] + + def test_create_entity_with_cache(responses, mock_elements_worker_with_cache): + # Set one entity type + mock_elements_worker_with_cache.entity_types = {"person": "person-entity-type-id"} elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( responses.POST, @@ -203,7 +243,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache): entity_id = mock_elements_worker_with_cache.create_entity( element=elt, name="Bob Bob", - type=EntityType.Person, + type="person", ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 @@ -215,7 +255,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache): assert json.loads(responses.calls[-1].request.body) == { "name": "Bob Bob", - "type": "person", + "type_id": "person-entity-type-id", "metas": {}, "validated": None, "corpus": "11111111-1111-1111-1111-111111111111", @@ -784,3 +824,72 @@ def test_list_corpus_entities_wrong_parent(mock_elements_worker, wrong_parent): with pytest.raises(AssertionError) as e: mock_elements_worker.list_corpus_entities(parent=wrong_parent) assert str(e.value) == "parent should be of type Element" + + +def test_check_required_entity_types(responses, mock_elements_worker): + # Set one entity type + mock_elements_worker.entity_types = {"person": "person-entity-type-id"} + + checked_types = ["person", "new-entity"] + + # Call to create new entity type + responses.add( + responses.POST, + "http://testserver/api/v1/entity/types/", + status=200, + match=[ + matchers.json_params_matcher( + { + "name": "new-entity", + "corpus": "11111111-1111-1111-1111-111111111111", + } + ) + ], + json={ + "id": "new-entity-id", + "corpus": "11111111-1111-1111-1111-111111111111", + "name": "new-entity", + "color": "ffd1b3", + }, + ) + + mock_elements_worker.check_required_entity_types( + entity_types=checked_types, + ) + + # Make sure the entity_types entry has been updated + assert mock_elements_worker.entity_types == { + "person": "person-entity-type-id", + "new-entity": "new-entity-id", + } + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/entity/types/", + ), + ] + + +def test_check_required_entity_types_no_creation_allowed( + responses, mock_elements_worker +): + # Set one entity type + mock_elements_worker.entity_types = {"person": "person-entity-type-id"} + + checked_types = ["person", "new-entity"] + + with pytest.raises( + MissingEntityType, match="Entity type `new-entity` was not in the corpus." + ): + mock_elements_worker.check_required_entity_types( + entity_types=checked_types, create_missing=False + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS -- GitLab