Skip to content
Snippets Groups Projects
Commit b1dd27ef authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Bulk endpoint for transcription entities

parent 74f6ae6e
No related branches found
No related tags found
1 merge request!308Bulk endpoint for transcription entities
Pipeline #80122 passed
......@@ -3,7 +3,7 @@
ElementsWorker methods for entities.
"""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, TypedDict, Union
from peewee import IntegrityError
......@@ -11,6 +11,10 @@ from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element, Transcription
# Shape of one entry of the ``entities`` list accepted by
# ``create_transcription_entities``.
# The fields are given as a mapping: the keyword-argument form
# ``TypedDict("Entity", name=str, ...)`` was deprecated in Python 3.11 and
# removed in Python 3.13, where it raises a TypeError at import time.
Entity = TypedDict(
    "Entity",
    {
        "name": str,
        "type_id": str,
        "length": int,
        "offset": int,
        "confidence": Optional[float],
    },
)
class MissingEntityType(Exception):
"""
......@@ -108,7 +112,7 @@ class EntityMixin(object):
"worker_run_id": self.worker_run_id,
},
)
self.report.add_entity(element.id, entity["id"], type, name)
self.report.add_entity(element.id, entity["id"], entity_type_id, name)
if self.use_cache:
# Store entity in local cache
......@@ -204,6 +208,110 @@ class EntityMixin(object):
)
return transcription_ent
def create_transcription_entities(
    self,
    transcription: Transcription,
    entities: List[Entity],
) -> Optional[List[Dict[str, str]]]:
    """
    Create multiple entities attached to a transcription in a single API request.

    :param transcription: Transcription to create the entities on.
       It must expose a non-empty ``element`` attribute, used for reporting.
    :param entities: List of dicts, one per entity. Each dict can have the following keys:

        name (str)
           Required. Name of the entity.

        type_id (str)
           Required. ID of the EntityType of the entity.

        length (int)
           Required. Length of the entity in the transcription's text.
           Must be strictly positive.

        offset (int)
           Required. Starting position of the entity in the transcription's text, as a 0-based index.

        confidence (float or None)
           Optional confidence score, between 0.0 and 1.0.

    :return: List of dicts, with each dict having two keys, `transcription_entity_id` and `entity_id`,
       holding the UUID of each created object.
       ``None`` when the worker is in read-only mode, as nothing is created.
    :raises AssertionError: When ``transcription`` or any item of ``entities`` is invalid.
    """
    assert transcription and isinstance(
        transcription, Transcription
    ), "transcription shouldn't be null and should be of type Transcription"

    # Needed for MLreport
    assert (
        hasattr(transcription, "element") and transcription.element
    ), f"No element linked to {transcription}"

    assert entities and isinstance(
        entities, list
    ), "entities shouldn't be null and should be of type list"

    for index, entity in enumerate(entities):
        assert isinstance(
            entity, dict
        ), f"Entity at index {index} in entities: Should be of type dict"

        name = entity.get("name")
        assert name and isinstance(
            name, str
        ), f"Entity at index {index} in entities: name shouldn't be null and should be of type str"

        type_id = entity.get("type_id")
        assert type_id and isinstance(
            type_id, str
        ), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str"

        # bool is a subclass of int: reject it explicitly so that True/False
        # are never sent to the API as an offset or a length.
        offset = entity.get("offset")
        assert (
            isinstance(offset, int) and not isinstance(offset, bool) and offset >= 0
        ), f"Entity at index {index} in entities: offset shouldn't be null and should be a positive integer"

        length = entity.get("length")
        assert (
            isinstance(length, int) and not isinstance(length, bool) and length > 0
        ), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer"

        confidence = entity.get("confidence")
        assert confidence is None or (
            isinstance(confidence, float) and 0 <= confidence <= 1
        ), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range"

    if self.is_read_only:
        logger.warning(
            "Cannot create transcription entities in bulk as this worker is in read-only mode"
        )
        return None

    created_ids = self.request(
        "CreateTranscriptionEntities",
        id=transcription.id,
        body={
            "worker_run_id": self.worker_run_id,
            "entities": entities,
        },
    )

    # Pair every submitted entity with the IDs returned for it; this relies on
    # the response listing created objects in the same order as the request.
    for entity, created_objects in zip(entities, created_ids["entities"]):
        # Report entity creation
        self.report.add_entity(
            transcription.element.id,
            created_objects["entity_id"],
            entity.get("type_id"),
            entity.get("name"),
        )

        # Report transcription entity creation
        self.report.add_transcription_entity(
            created_objects["entity_id"],
            transcription,
            created_objects["transcription_entity_id"],
        )

    return created_ids["entities"]
def list_transcription_entities(
self,
transcription: Transcription,
......
......@@ -893,3 +893,239 @@ def test_check_required_entity_types_no_creation_allowed(
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
@pytest.mark.parametrize("transcription", (None, "not a transcription", 1))
def test_create_transcription_entities_wrong_transcription(
    mock_elements_worker, transcription
):
    """A value that is not a Transcription instance triggers an AssertionError."""
    expected_message = (
        "transcription shouldn't be null and should be of type Transcription"
    )

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.create_transcription_entities(
            transcription=transcription,
            entities=[],
        )

    assert str(exc_info.value) == expected_message
def test_create_transcription_entities_no_transcription_element(mock_elements_worker):
    """A transcription built without an element triggers an AssertionError."""
    orphan_transcription = Transcription(id="transcription_id")

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.create_transcription_entities(
            transcription=orphan_transcription,
            entities=[],
        )

    assert (
        str(exc_info.value) == "No element linked to Transcription (transcription_id)"
    )
@pytest.mark.parametrize("entities", (None, "not a list of entities", 1))
def test_create_transcription_entities_wrong_entities(mock_elements_worker, entities):
    """An `entities` argument that is not a list triggers an AssertionError."""
    transcription = Transcription(id="transcription_id", element={"id": "element_id"})

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.create_transcription_entities(
            transcription=transcription,
            entities=entities,
        )

    assert (
        str(exc_info.value) == "entities shouldn't be null and should be of type list"
    )
def test_create_transcription_entities_wrong_entities_subtype(mock_elements_worker):
    """A list item that is not a dict triggers an AssertionError naming its index."""
    transcription = Transcription(id="transcription_id", element={"id": "element_id"})

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.create_transcription_entities(
            transcription=transcription,
            entities=["not a dict"],
        )

    assert (
        str(exc_info.value) == "Entity at index 0 in entities: Should be of type dict"
    )
@pytest.mark.parametrize(
    "entity, error",
    (
        # Missing name
        (
            {
                "name": None,
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: name shouldn't be null and should be of type str",
        ),
        # Missing type_id
        (
            {"name": "A", "type_id": None, "offset": 0, "length": 1, "confidence": 0.5},
            "Entity at index 0 in entities: type_id shouldn't be null and should be of type str",
        ),
        # type_id of the wrong type
        (
            {"name": "A", "type_id": 0, "offset": 0, "length": 1, "confidence": 0.5},
            "Entity at index 0 in entities: type_id shouldn't be null and should be of type str",
        ),
        # Missing offset
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": None,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: offset shouldn't be null and should be a positive integer",
        ),
        # Negative offset
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": -2,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: offset shouldn't be null and should be a positive integer",
        ),
        # Missing length
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": None,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer",
        ),
        # Zero length
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 0,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer",
        ),
        # Confidence of the wrong type
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 1,
                "confidence": "not None or a float",
            },
            "Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
        ),
        # Confidence out of range
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 1,
                "confidence": 1.3,
            },
            "Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
        ),
    ),
)
def test_create_transcription_entities_wrong_entity(
    mock_elements_worker, entity, error
):
    """Each invalid entity payload triggers an AssertionError with its own message."""
    transcription = Transcription(id="transcription_id", element={"id": "element_id"})

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.create_transcription_entities(
            transcription=transcription,
            entities=[entity],
        )

    assert str(exc_info.value) == error
def test_create_transcription_entities(responses, mock_elements_worker):
    """
    Nominal case: one entity is sent to the bulk endpoint, the created IDs are
    returned and both creations are recorded in the ML report.
    """
    element_id = "element_id"
    entity_type_id = "22222222-2222-2222-2222-222222222222"
    bulk_url = "http://testserver/api/v1/transcription/transcription-id/entities/bulk/"
    transcription = Transcription(id="transcription-id", element={"id": element_id})

    def teklia_entity():
        # Fresh dict on every call, so the expected request body and the
        # payload handed to the worker never share the same object.
        return {
            "name": "Teklia",
            "type_id": entity_type_id,
            "offset": 0,
            "length": 6,
            "confidence": 1.0,
        }

    # Call to Transcription entities creation in bulk
    expected_body = {
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
        "entities": [teklia_entity()],
    }
    responses.add(
        responses.POST,
        bulk_url,
        status=201,
        match=[matchers.json_params_matcher(expected_body)],
        json={
            "entities": [
                {
                    "transcription_entity_id": "transc-entity-id",
                    "entity_id": "entity-id",
                }
            ]
        },
    )

    # Store entity type/slug correspondence on the worker
    mock_elements_worker.entity_types = {entity_type_id: "organization"}

    created_objects = mock_elements_worker.create_transcription_entities(
        transcription=transcription,
        entities=[teklia_entity()],
    )
    assert len(created_objects) == 1

    report_elements = mock_elements_worker.report.report_data["elements"]
    assert element_id in report_elements
    ml_report = report_elements[element_id]

    # The start timestamp is not predictable: check it exists, then drop it
    assert "started" in ml_report
    del ml_report["started"]

    # Check reporting
    expected_report = {
        "elements": {},
        "transcriptions": 0,
        "classifications": {},
        "entities": [
            {
                "id": "entity-id",
                "type": entity_type_id,
                "name": "Teklia",
            }
        ],
        "transcription_entities": [
            {
                "transcription_id": "transcription-id",
                "entity_id": "entity-id",
                "transcription_entity_id": "transc-entity-id",
            }
        ],
        "metadata": [],
        "errors": [],
    }
    assert ml_report == expected_report

    # Exactly one request on top of the base API calls
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", bulk_url)]
......@@ -167,7 +167,10 @@ def test_add_transcription_count():
def test_add_entity():
reporter = Reporter("worker")
reporter.add_entity(
"myelement", "12341234-1234-1234-1234-123412341234", "person", "Bob Bob"
"myelement",
"12341234-1234-1234-1234-123412341234",
"person-entity-type-id",
"Bob Bob",
)
assert "myelement" in reporter.report_data["elements"]
element_data = reporter.report_data["elements"]["myelement"]
......@@ -179,7 +182,7 @@ def test_add_entity():
"entities": [
{
"id": "12341234-1234-1234-1234-123412341234",
"type": "person",
"type": "person-entity-type-id",
"name": "Bob Bob",
}
],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment