Skip to content
Snippets Groups Projects

Bulk endpoint for transcription entities

Merged Yoann Schneider requested to merge bulk-transcription-entities into master
3 files
+ 351
4
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -3,7 +3,7 @@
ElementsWorker methods for entities.
"""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, TypedDict, Union
from peewee import IntegrityError
@@ -11,6 +11,10 @@ from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element, Transcription
Entity = TypedDict(
"Entity", name=str, type_id=str, length=int, offset=int, confidence=Optional[float]
)
class MissingEntityType(Exception):
"""
@@ -108,7 +112,7 @@ class EntityMixin(object):
"worker_run_id": self.worker_run_id,
},
)
self.report.add_entity(element.id, entity["id"], type, name)
self.report.add_entity(element.id, entity["id"], entity_type_id, name)
if self.use_cache:
# Store entity in local cache
@@ -204,6 +208,110 @@ class EntityMixin(object):
)
return transcription_ent
def create_transcription_entities(
self,
transcription: Transcription,
entities: List[Entity],
) -> List[Dict[str, str]]:
"""
Create multiple entities attached to a transcription in a single API request.
:param transcription: Transcription to create the entity on.
:param entities: List of dicts, one per element. Each dict can have the following keys:
name (str)
Required. Name of the entity.
type_id (str)
Required. ID of the EntityType of the entity.
length (int)
Required. Length of the entity in the transcription's text.
offset (int)
Required. Starting position of the entity in the transcription's text, as a 0-based index.
confidence (float or None)
Optional confidence score, between 0.0 and 1.0.
:return: List of dicts, with each dict having a two keys, `transcription_entity_id` and `entity_id`, holding the UUID of each created object.
"""
assert transcription and isinstance(
transcription, Transcription
), "transcription shouldn't be null and should be of type Transcription"
# Needed for MLreport
assert (
hasattr(transcription, "element") and transcription.element
), f"No element linked to {transcription}"
assert entities and isinstance(
entities, list
), "entities shouldn't be null and should be of type list"
for index, entity in enumerate(entities):
assert isinstance(
entity, dict
), f"Entity at index {index} in entities: Should be of type dict"
name = entity.get("name")
assert name and isinstance(
name, str
), f"Entity at index {index} in entities: name shouldn't be null and should be of type str"
type_id = entity.get("type_id")
assert type_id and isinstance(
type_id, str
), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str"
offset = entity.get("offset")
assert (
offset is not None and isinstance(offset, int) and offset >= 0
), f"Entity at index {index} in entities: offset shouldn't be null and should be a positive integer"
length = entity.get("length")
assert (
length is not None and isinstance(length, int) and length > 0
), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer"
confidence = entity.get("confidence")
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range"
if self.is_read_only:
logger.warning(
"Cannot create transcription entities in bulk as this worker is in read-only mode"
)
return
created_ids = self.request(
"CreateTranscriptionEntities",
id=transcription.id,
body={
"worker_run_id": self.worker_run_id,
"entities": entities,
},
)
for entity, created_objects in zip(entities, created_ids["entities"]):
# Report entity creation
self.report.add_entity(
transcription.element.id,
created_objects["entity_id"],
entity.get("type_id"),
entity.get("name"),
)
# Report transcription entity creation
self.report.add_transcription_entity(
created_objects["entity_id"],
transcription,
created_objects["transcription_entity_id"],
)
return created_ids["entities"]
def list_transcription_entities(
self,
transcription: Transcription,
Loading