Skip to content
Snippets Groups Projects
Commit 6f16aeb9 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

New entity types handling

parent 62fca186
No related branches found
Tags 0.3.2
1 merge request: !304 New entity types handling
Pipeline #80094 passed
......@@ -19,7 +19,7 @@ from arkindex_worker.reporting import Reporter
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.classification import ClassificationMixin
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.entity import EntityMixin # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
......@@ -92,6 +92,10 @@ class ElementsWorker(
self.classes = {}
self.entity_types = {}
"""Known and available entity types in processed corpus
"""
self._worker_version_cache = {}
def list_elements(self) -> Union[Iterable[CachedElement], List[str]]:
......
......@@ -3,8 +3,7 @@
ElementsWorker methods for entities.
"""
from enum import Enum
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union
from peewee import IntegrityError
......@@ -13,26 +12,54 @@ from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscripti
from arkindex_worker.models import Element, Transcription
class EntityType(Enum):
class MissingEntityType(Exception):
"""
Type of an entity.
Raised when the specified entity type was not found in the corpus and
the worker cannot create it.
"""
Person = "person"
Location = "location"
Subject = "subject"
Organization = "organization"
Misc = "misc"
Number = "number"
Date = "date"
class EntityMixin(object):
def check_required_entity_types(
    self, entity_types: List[str], create_missing: bool = True
):
    """Checks that every entity type needed is available in the corpus.
    Missing ones may be created automatically if needed.

    Known types are read from ``self.entity_types`` (entity type name mapped
    to entity type ID). When that mapping is empty, it is first loaded with
    ``list_corpus_entity_types``. Each created type is stored in the mapping.

    :param entity_types: Entity type names to search.
    :param create_missing: Whether the missing types should be created. Defaults to True.
    :raises AssertionError: When ``entity_types`` is a single string, or holds a value
       that is not a non-empty string.
    :raises MissingEntityType: When an entity type is missing and ``create_missing`` is False.
    """
    # A bare string would be iterated character by character and one entity
    # type per character would be created, so it is rejected explicitly.
    # Everything is validated before the first creation to avoid leaving the
    # corpus half-updated because of a bad name further down the list.
    invalid_message = "entity_types should be a list of non-empty str"
    assert not isinstance(entity_types, str), invalid_message
    required_types = list(entity_types)
    assert all(
        name and isinstance(name, str) for name in required_types
    ), invalid_message

    if not self.entity_types:
        # Nothing is known yet: load the entity types of the corpus
        self.list_corpus_entity_types()

    for entity_type in required_types:
        # Do nothing if type already exists
        if entity_type in self.entity_types:
            continue

        # Do not create missing if not requested
        if not create_missing:
            raise MissingEntityType(
                f"Entity type `{entity_type}` was not in the corpus."
            )

        # Create type if non-existent and remember its ID
        self.entity_types[entity_type] = self.request(
            "CreateEntityType",
            body={
                "name": entity_type,
                "corpus": self.corpus_id,
            },
        )["id"]
        logger.info(f"Created a new entity type with name `{entity_type}`.")
def create_entity(
self,
element: Union[Element, CachedElement],
name: str,
type: EntityType,
type: str,
metas=dict(),
validated=None,
):
......@@ -52,8 +79,8 @@ class EntityMixin(object):
name, str
), "name shouldn't be null and should be of type str"
assert type and isinstance(
type, EntityType
), "type shouldn't be null and should be of type EntityType"
type, str
), "type shouldn't be null and should be of type str"
if metas:
assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None:
......@@ -62,18 +89,26 @@ class EntityMixin(object):
logger.warning("Cannot create entity as this worker is in read-only mode")
return
# Retrieve entity_type ID
if not self.entity_types:
# Load entity_types of corpus
self.list_corpus_entity_types()
entity_type_id = self.entity_types.get(type)
assert entity_type_id, f"Entity type `{type}` not found in the corpus."
entity = self.request(
"CreateEntity",
body={
"name": name,
"type": type.value,
"type_id": entity_type_id,
"metas": metas,
"validated": validated,
"corpus": self.corpus_id,
"worker_run_id": self.worker_run_id,
},
)
self.report.add_entity(element.id, entity["id"], type.value, name)
self.report.add_entity(element.id, entity["id"], type, name)
if self.use_cache:
# Store entity in local cache
......@@ -81,7 +116,7 @@ class EntityMixin(object):
to_insert = [
{
"id": entity["id"],
"type": type.value,
"type": type,
"name": name,
"validated": validated if validated is not None else False,
"metas": metas,
......@@ -225,3 +260,19 @@ class EntityMixin(object):
return self.api_client.paginate(
"ListCorpusEntities", id=self.corpus_id, **query_params
)
def list_corpus_entity_types(
    self,
):
    """
    Fetches the entity types of the corpus and stores them in
    ``self.entity_types``, as a mapping of type name to type ID.
    """
    listed_types = self.api_client.paginate(
        "ListCorpusEntityTypes", id=self.corpus_id
    )

    # Build the whole mapping first, so the attribute is only replaced
    # once every page of results has been read.
    ids_by_name = {}
    for listed_type in listed_types:
        ids_by_name[listed_type["name"]] = listed_type["id"]
    self.entity_types = ids_by_name

    logger.info(
        f"Loaded {len(self.entity_types)} entity types in corpus ({self.corpus_id})."
    )
......@@ -3,7 +3,7 @@
::: arkindex_worker.worker.entity
options:
members:
- EntityType
- MissingEntityType
options:
show_category_heading: no
::: arkindex_worker.worker.entity.EntityMixin
......
......@@ -4,6 +4,7 @@ from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from responses import matchers
from arkindex_worker.cache import (
CachedElement,
......@@ -12,7 +13,7 @@ from arkindex_worker.cache import (
CachedTranscriptionEntity,
)
from arkindex_worker.models import Element, Transcription
from arkindex_worker.worker import EntityType
from arkindex_worker.worker.entity import MissingEntityType
from arkindex_worker.worker.transcription import TextOrientation
from . import BASE_API_CALLS
......@@ -23,7 +24,7 @@ def test_create_entity_wrong_element(mock_elements_worker):
mock_elements_worker.create_entity(
element=None,
name="Bob Bob",
type=EntityType.Person,
type="person",
)
assert (
str(e.value)
......@@ -34,7 +35,7 @@ def test_create_entity_wrong_element(mock_elements_worker):
mock_elements_worker.create_entity(
element="not element type",
name="Bob Bob",
type=EntityType.Person,
type="person",
)
assert (
str(e.value)
......@@ -49,7 +50,7 @@ def test_create_entity_wrong_name(mock_elements_worker):
mock_elements_worker.create_entity(
element=elt,
name=None,
type=EntityType.Person,
type="person",
)
assert str(e.value) == "name shouldn't be null and should be of type str"
......@@ -57,7 +58,7 @@ def test_create_entity_wrong_name(mock_elements_worker):
mock_elements_worker.create_entity(
element=elt,
name=1234,
type=EntityType.Person,
type="person",
)
assert str(e.value) == "name shouldn't be null and should be of type str"
......@@ -71,7 +72,7 @@ def test_create_entity_wrong_type(mock_elements_worker):
name="Bob Bob",
type=None,
)
assert str(e.value) == "type shouldn't be null and should be of type EntityType"
assert str(e.value) == "type shouldn't be null and should be of type str"
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
......@@ -79,15 +80,7 @@ def test_create_entity_wrong_type(mock_elements_worker):
name="Bob Bob",
type=1234,
)
assert str(e.value) == "type shouldn't be null and should be of type EntityType"
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type="not_an_entity_type",
)
assert str(e.value) == "type shouldn't be null and should be of type EntityType"
assert str(e.value) == "type shouldn't be null and should be of type str"
def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker):
......@@ -99,7 +92,7 @@ def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker):
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
type="person",
metas="wrong metas",
)
assert str(e.value) == "metas should be of type dict"
......@@ -112,7 +105,7 @@ def test_create_entity_wrong_metas(mock_elements_worker):
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
type="person",
metas="wrong metas",
)
assert str(e.value) == "metas should be of type dict"
......@@ -125,13 +118,15 @@ def test_create_entity_wrong_validated(mock_elements_worker):
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
type="person",
validated="wrong validated",
)
assert str(e.value) == "validated should be of type bool"
def test_create_entity_api_error(responses, mock_elements_worker):
# Set one entity type
mock_elements_worker.entity_types = {"person": "person-entity-type-id"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -143,7 +138,7 @@ def test_create_entity_api_error(responses, mock_elements_worker):
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
type="person",
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
......@@ -160,6 +155,9 @@ def test_create_entity_api_error(responses, mock_elements_worker):
def test_create_entity(responses, mock_elements_worker):
# Set one entity type
mock_elements_worker.entity_types = {"person": "person-entity-type-id"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -171,7 +169,7 @@ def test_create_entity(responses, mock_elements_worker):
entity_id = mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
type="person",
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
......@@ -182,7 +180,7 @@ def test_create_entity(responses, mock_elements_worker):
]
assert json.loads(responses.calls[-1].request.body) == {
"name": "Bob Bob",
"type": "person",
"type_id": "person-entity-type-id",
"metas": {},
"validated": None,
"corpus": "11111111-1111-1111-1111-111111111111",
......@@ -191,7 +189,49 @@ def test_create_entity(responses, mock_elements_worker):
assert entity_id == "12345678-1234-1234-1234-123456789123"
def test_create_entity_missing_type(responses, mock_elements_worker):
    """
    An entity cannot be created with a type that is absent from the corpus.
    """
    entity_types_url = "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/entity-types/"

    # The corpus only exposes the `person` entity type
    person_type = {"id": "person-entity-type-id", "name": "person", "color": "00d1b2"}
    responses.add(
        responses.GET,
        entity_types_url,
        status=200,
        json={"count": 1, "next": None, "results": [person_type]},
    )

    element = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_error = "Entity type `new-entity` not found in the corpus."
    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.create_entity(
            element=element,
            name="Bob Bob",
            type="new-entity",
        )

    # Only the listing of entity types was requested on top of the base calls
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert performed_calls == BASE_API_CALLS + [("GET", entity_types_url)]
def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
# Set one entity type
mock_elements_worker_with_cache.entity_types = {"person": "person-entity-type-id"}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
......@@ -203,7 +243,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
entity_id = mock_elements_worker_with_cache.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
type="person",
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
......@@ -215,7 +255,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
assert json.loads(responses.calls[-1].request.body) == {
"name": "Bob Bob",
"type": "person",
"type_id": "person-entity-type-id",
"metas": {},
"validated": None,
"corpus": "11111111-1111-1111-1111-111111111111",
......@@ -784,3 +824,72 @@ def test_list_corpus_entities_wrong_parent(mock_elements_worker, wrong_parent):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_corpus_entities(parent=wrong_parent)
assert str(e.value) == "parent should be of type Element"
def test_check_required_entity_types(responses, mock_elements_worker):
    """
    A required entity type that the worker does not know is created and stored.
    """
    corpus_id = "11111111-1111-1111-1111-111111111111"
    creation_url = "http://testserver/api/v1/entity/types/"

    # The worker already knows `person`; `new-entity` is the missing one
    mock_elements_worker.entity_types = {"person": "person-entity-type-id"}

    # Expected request creating the missing entity type
    creation_payload = {"name": "new-entity", "corpus": corpus_id}
    responses.add(
        responses.POST,
        creation_url,
        status=200,
        match=[matchers.json_params_matcher(creation_payload)],
        json={
            "id": "new-entity-id",
            "corpus": corpus_id,
            "name": "new-entity",
            "color": "ffd1b3",
        },
    )

    mock_elements_worker.check_required_entity_types(
        entity_types=["person", "new-entity"],
    )

    # The created type has been added next to the one already known
    assert mock_elements_worker.entity_types == {
        "person": "person-entity-type-id",
        "new-entity": "new-entity-id",
    }

    # A single request was added to the base calls: the type creation
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert performed_calls == BASE_API_CALLS + [("POST", creation_url)]
def test_check_required_entity_types_no_creation_allowed(
    responses, mock_elements_worker
):
    """
    With `create_missing=False`, an unknown entity type raises and no request is made.
    """
    # The worker only knows the `person` entity type
    mock_elements_worker.entity_types = {"person": "person-entity-type-id"}

    expected_error = "Entity type `new-entity` was not in the corpus."
    with pytest.raises(MissingEntityType, match=expected_error):
        mock_elements_worker.check_required_entity_types(
            entity_types=["person", "new-entity"], create_missing=False
        )

    # No request was made beyond the base calls
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert len(responses.calls) == len(BASE_API_CALLS)
    assert performed_calls == BASE_API_CALLS
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment